Compare commits


33 Commits

Author SHA1 Message Date
sinu
a56f2dd02b rename ProvePayload to ProveRequest 2025-09-11 12:51:23 -07:00
sinu
38a1ec3f72 import order 2025-09-11 09:43:38 -07:00
sinu
1501bc661f consolidate encoding stuff 2025-09-11 09:42:03 -07:00
sinu
0e2c2cb045 revert additional derives 2025-09-11 09:37:44 -07:00
sinu
a662fb7511 fix import order 2025-09-11 09:32:50 -07:00
th4s
be2e1ab95a adapt tests to new base 2025-09-11 18:15:14 +02:00
th4s
d34d135bfe adapt to new base 2025-09-11 10:55:26 +02:00
th4s
091d26bb63 fmt nightly 2025-09-11 10:13:08 +02:00
th4s
f031fe9a8d allow encoding protocol to be executed only once 2025-09-11 10:13:08 +02:00
th4s
12c9a5eb34 refactor(tlsn): improve proving flow prover and verifier
- add `TlsTranscript` fixture for testing
- refactor and simplify the commit flow
- change how the TLS transcript is committed
  - tag verification is still done for the whole transcript
  - plaintext authentication is only done where needed
  - encoding adjustments are only transmitted for needed ranges
  - makes `prove` and `verify` functions more efficient and callable more than once
- decoding now supports server-write-key decoding without authentication
- add more tests
2025-09-11 10:13:06 +02:00
sinu.eth
b4380f021e refactor: decouple ProveConfig from PartialTranscript (#991) 2025-09-11 09:13:52 +02:00
sinu.eth
8a823d18ec refactor(core): replace Idx with RangeSet (#988)
* refactor(core): replace Idx with RangeSet

* clippy
2025-09-10 15:44:40 -07:00
sinu.eth
7bcfc56bd8 fix(tls-core): remove deprecated webpki error variants (#992)
* fix(tls-core): remove deprecated webpki error variants

* clippy
2025-09-10 15:24:07 -07:00
sinu.eth
2909d5ebaa chore: bump mpz to 3d90b6c (#990) 2025-09-10 14:38:48 -07:00
sinu.eth
7918494ccc fix(core): fix dev dependencies (#989) 2025-09-10 14:25:04 -07:00
sinu.eth
92dd47b376 fix(core): enable zeroize derive (#987) 2025-09-10 14:11:41 -07:00
th4s
5474a748ce feat(core): Add transcript fixture (#983)
* feat(core): add transcript fixture for testing

* add feedback

* remove packages from dev dependencies
2025-09-10 22:58:10 +02:00
yuroitaki
92da5adc24 chore: update attestation example (#966)
* Add attestation example.

* Apply fmt.

* Apply clippy fix.

* Rebase.

* Improved readme + more default logging in prove example

* Removed wrong AI generated "learn more" links

* re-export ContentType in tlsn-core

* remove unnecessary checks from example

---------

Co-authored-by: yuroitaki <>
Co-authored-by: Hendrik Eeckhaut <hendrik@eeckhaut.org>
Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>
2025-09-10 09:37:17 -07:00
Hendrik Eeckhaut
e0ce1ad31a build: Update to unpatched ws_stream_wasm crate (#975) 2025-09-01 16:33:00 +02:00
Hendrik Eeckhaut
3b76877920 build: reduce wasm size (#977) 2025-09-01 11:28:12 +02:00
Hendrik Eeckhaut
783355772a docs: corrected commands in docker.md of the harness (#976) 2025-08-28 17:00:18 +02:00
dan
e5c59da90b chore: fix tests (#974) 2025-08-26 08:42:48 +00:00
dan
f059c53c2d use zk config; bump mpz (#973) 2025-08-26 08:23:24 +00:00
sinu.eth
a1367b5428 refactor(tlsn): change network setting default to reduce data transfer (#971) 2025-08-22 14:00:23 -07:00
sinu.eth
9d8124ac9d chore: bump mpz to 1b00912 (#970) 2025-08-21 09:46:29 -07:00
dan
5034366c72 fix(hmac-sha256): compute PHash and AHash concurrently (#969)
---------

Co-authored-by: th4s <th4s@metavoid.xyz>
2025-08-21 06:41:59 +00:00
sinu.eth
afd8f44261 feat(tlsn): serializable config (#968) 2025-08-18 09:03:04 -07:00
sinu.eth
21086d2883 refactor: clean up web pki (#967)
* refactor: clean up web pki

* fix time import

* clippy

* fix wasm
2025-08-18 08:36:04 -07:00
dan
cca9a318a4 fix(harness): improve harness stability (#962) 2025-08-15 09:17:20 +00:00
dan
cb804a6025 fix(harness): disable tracing events (#961) 2025-08-15 07:13:12 +00:00
th4s
9f849e7c18 fix(encoding): set correct frame limit (#963)
* fix(encoding): set correct frame limit

* bugfix for `TranscriptRefs::len`

* use current frame limit as cushion room
2025-08-13 09:57:00 +02:00
th4s
389bceddef chore: bump rust version, fix lints and satisfy clippy (#964)
* chore(lints): fix lints and satisfy clippy

* bump rust version in ci
2025-08-12 10:50:31 -07:00
th4s
657838671a chore: remove notarize methods for prover and verifier (#952)
* feat: remove notarize methods for prover and verifier

* clean up imports

* remove remaining notarize methods

* clean up imports

* remove wasm attestation bindings

---------

Co-authored-by: sinu <65924192+sinui0@users.noreply.github.com>
2025-08-06 09:38:43 -07:00
107 changed files with 6518 additions and 4601 deletions

View File

@@ -21,7 +21,7 @@ env:
# - https://github.com/privacy-scaling-explorations/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.88.0
RUST_VERSION: 1.89.0
jobs:
clippy:
@@ -198,4 +198,4 @@ jobs:
draft: true
tag_name: ${{ github.ref_name }}
prerelease: true
generate_release_notes: true
generate_release_notes: true

Cargo.lock (generated, 1138 changed lines): file diff suppressed because it is too large.

View File

@@ -39,6 +39,8 @@ opt-level = 1
[profile.wasm]
inherits = "release"
lto = true
panic = "abort"
codegen-units = 1
[workspace.dependencies]
tls-server-fixture = { path = "crates/tls/server-fixture" }
@@ -64,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
@@ -84,6 +86,7 @@ spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
aead = { version = "0.4" }
aes = { version = "0.8" }
aes-gcm = { version = "0.9" }
anyhow = { version = "1.0" }
@@ -97,7 +100,7 @@ bytes = { version = "1.4" }
cfg-if = { version = "1" }
chromiumoxide = { version = "0.7" }
chrono = { version = "0.4" }
cipher = { version = "0.4" }
cipher-crypto = { package = "cipher", version = "0.4" }
clap = { version = "4.5" }
criterion = { version = "0.5" }
ctr = { version = "0.9" }
@@ -120,6 +123,7 @@ inventory = { version = "0.3" }
itybity = { version = "0.2" }
js-sys = { version = "0.3" }
k256 = { version = "0.13" }
lipsum = { version = "0.9" }
log = { version = "0.4" }
once_cell = { version = "1.19" }
opaque-debug = { version = "0.3" }
@@ -137,6 +141,8 @@ rs_merkle = { git = "https://github.com/tlsnotary/rs-merkle.git", rev = "85f3e82
rstest = { version = "0.17" }
rustls = { version = "0.21" }
rustls-pemfile = { version = "1.0" }
rustls-webpki = { version = "0.103" }
rustls-pki-types = { version = "1.12" }
sct = { version = "0.7" }
semver = { version = "1.0" }
serde = { version = "1.0" }
@@ -157,7 +163,7 @@ wasm-bindgen = { version = "0.2" }
wasm-bindgen-futures = { version = "0.4" }
web-spawn = { version = "0.2" }
web-time = { version = "0.2" }
webpki = { version = "0.22" }
webpki-roots = { version = "0.26" }
# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
ws_stream_wasm = { git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51" }
webpki-roots = { version = "1.0" }
webpki-root-certs = { version = "1.0" }
ws_stream_wasm = { version = "0.7.5" }
zeroize = { version = "1.8" }

View File

@@ -21,7 +21,6 @@ rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
webpki-roots = { workspace = true }
[dev-dependencies]
alloy-primitives = { version = "0.8.22", default-features = false }

View File

@@ -242,7 +242,7 @@ impl std::fmt::Display for AttestationBuilderError {
mod test {
use rstest::{fixture, rstest};
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::{ConnectionFixture, encoding_provider},
hash::Blake3,
transcript::Transcript,
@@ -399,10 +399,10 @@ mod test {
server_cert_data, ..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
};
@@ -470,10 +470,10 @@ mod test {
..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
};

View File

@@ -22,7 +22,7 @@
use serde::{Deserialize, Serialize};
use tlsn_core::{
connection::{CertificateVerificationError, ServerCertData, ServerEphemKey, ServerName},
connection::{HandshakeData, HandshakeVerificationError, ServerEphemKey, ServerName},
hash::{Blinded, HashAlgorithm, HashProviderError, TypedHash},
};
@@ -30,14 +30,14 @@ use crate::{CryptoProvider, hash::HashAlgorithmExt, serialize::impl_domain_separ
/// Opens a [`ServerCertCommitment`].
#[derive(Clone, Serialize, Deserialize)]
pub struct ServerCertOpening(Blinded<ServerCertData>);
pub struct ServerCertOpening(Blinded<HandshakeData>);
impl_domain_separator!(ServerCertOpening);
opaque_debug::implement!(ServerCertOpening);
impl ServerCertOpening {
pub(crate) fn new(data: ServerCertData) -> Self {
pub(crate) fn new(data: HandshakeData) -> Self {
Self(Blinded::new(data))
}
@@ -49,7 +49,7 @@ impl ServerCertOpening {
}
/// Returns the server identity data.
pub fn data(&self) -> &ServerCertData {
pub fn data(&self) -> &HandshakeData {
self.0.data()
}
}
@@ -122,8 +122,8 @@ impl From<HashProviderError> for ServerIdentityProofError {
}
}
impl From<CertificateVerificationError> for ServerIdentityProofError {
fn from(err: CertificateVerificationError) -> Self {
impl From<HandshakeVerificationError> for ServerIdentityProofError {
fn from(err: HandshakeVerificationError) -> Self {
Self {
kind: ErrorKind::Certificate,
message: err.to_string(),

View File

@@ -1,7 +1,7 @@
//! Attestation fixtures.
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
hash::HashAlgorithm,
transcript::{
@@ -67,7 +67,7 @@ pub fn request_fixture(
let mut request_builder = Request::builder(&request_config);
request_builder
.server_name(server_name)
.server_cert_data(server_cert_data)
.handshake_data(server_cert_data)
.transcript(transcript);
let (request, _) = request_builder.build(&provider).unwrap();
@@ -91,12 +91,12 @@ pub fn attestation_fixture(
..
} = connection;
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake
}) = server_cert_data.binding
else {
panic!("expected v1.2 handshake data");
panic!("expected v1.2 binding data");
};
let mut provider = CryptoProvider::default();

View File

@@ -1,8 +1,4 @@
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
verify::WebPkiVerifier,
};
use tlsn_core::hash::HashProvider;
use tlsn_core::{hash::HashProvider, webpki::ServerCertVerifier};
use crate::signing::{SignatureVerifierProvider, SignerProvider};
@@ -28,7 +24,7 @@ pub struct CryptoProvider {
/// This is used to verify the server's certificate chain.
///
/// The default verifier uses the Mozilla root certificates.
pub cert: WebPkiVerifier,
pub cert: ServerCertVerifier,
/// Signer provider.
///
/// This is used for signing attestations.
@@ -45,21 +41,9 @@ impl Default for CryptoProvider {
fn default() -> Self {
Self {
hash: Default::default(),
cert: default_cert_verifier(),
cert: ServerCertVerifier::mozilla(),
signer: Default::default(),
signature: Default::default(),
}
}
}
pub(crate) fn default_cert_verifier() -> WebPkiVerifier {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
WebPkiVerifier::new(root_store, None)
}
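
As an aside, a minimal sketch of how the new `ServerCertVerifier` might be wired into a `CryptoProvider` with a custom root store. The `RootCertStore { roots }`, `CertificateDer`, and `ServerCertVerifier::new`/`mozilla` constructors all appear elsewhere in this change set; the crate path and the `custom_provider` helper are assumptions.

```rust
// Sketch only: the `CryptoProvider` import path is assumed; webpki types per tlsn-core.
use tlsn_attestation::CryptoProvider; // path assumed
use tlsn_core::webpki::{CertificateDer, RootCertStore, ServerCertVerifier};

fn custom_provider(ca_der: &[u8]) -> CryptoProvider {
    // Trust a single custom CA instead of the default Mozilla roots.
    let roots = RootCertStore {
        roots: vec![CertificateDer(ca_der.to_vec())],
    };
    CryptoProvider {
        cert: ServerCertVerifier::new(&roots).expect("valid root store"),
        ..Default::default()
    }
}
```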

View File

@@ -36,7 +36,7 @@ pub struct Request {
impl Request {
/// Returns a new request builder.
pub fn builder(config: &RequestConfig) -> RequestBuilder {
pub fn builder(config: &RequestConfig) -> RequestBuilder<'_> {
RequestBuilder::new(config)
}

View File

@@ -1,5 +1,5 @@
use tlsn_core::{
connection::{ServerCertData, ServerName},
connection::{HandshakeData, ServerName},
transcript::{Transcript, TranscriptCommitment, TranscriptSecret},
};
@@ -13,7 +13,7 @@ use crate::{
pub struct RequestBuilder<'a> {
config: &'a RequestConfig,
server_name: Option<ServerName>,
server_cert_data: Option<ServerCertData>,
handshake_data: Option<HandshakeData>,
transcript: Option<Transcript>,
transcript_commitments: Vec<TranscriptCommitment>,
transcript_commitment_secrets: Vec<TranscriptSecret>,
@@ -25,7 +25,7 @@ impl<'a> RequestBuilder<'a> {
Self {
config,
server_name: None,
server_cert_data: None,
handshake_data: None,
transcript: None,
transcript_commitments: Vec::new(),
transcript_commitment_secrets: Vec::new(),
@@ -38,9 +38,9 @@ impl<'a> RequestBuilder<'a> {
self
}
/// Sets the server identity data.
pub fn server_cert_data(&mut self, data: ServerCertData) -> &mut Self {
self.server_cert_data = Some(data);
/// Sets the handshake data.
pub fn handshake_data(&mut self, data: HandshakeData) -> &mut Self {
self.handshake_data = Some(data);
self
}
@@ -69,7 +69,7 @@ impl<'a> RequestBuilder<'a> {
let Self {
config,
server_name,
server_cert_data,
handshake_data: server_cert_data,
transcript,
transcript_commitments,
transcript_commitment_secrets,

View File

@@ -46,7 +46,7 @@ pub(crate) use impl_domain_separator;
impl_domain_separator!(tlsn_core::connection::ServerEphemKey);
impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
impl_domain_separator!(tlsn_core::connection::HandshakeData);
impl_domain_separator!(tlsn_core::connection::CertBinding);
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);

View File

@@ -5,7 +5,7 @@ use tlsn_attestation::{
signing::SignatureAlgId,
};
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2},
connection::{CertBinding, CertBindingV1_2},
fixtures::{self, ConnectionFixture, encoder_secret},
hash::Blake3,
transcript::{
@@ -36,10 +36,10 @@ fn test_api() {
server_cert_data,
} = ConnectionFixture::tlsnotary(transcript.length());
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = server_cert_data.handshake.clone()
}) = server_cert_data.binding.clone()
else {
unreachable!()
};
@@ -72,7 +72,7 @@ fn test_api() {
request_builder
.server_name(server_name.clone())
.server_cert_data(server_cert_data)
.handshake_data(server_cert_data)
.transcript(transcript)
.transcript_commitments(
vec![TranscriptSecret::Encoding(encoding_tree)],

View File

@@ -31,4 +31,4 @@ mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }
cipher-crypto = { workspace = true }

View File

@@ -344,8 +344,8 @@ mod tests {
start_ctr: usize,
msg: Vec<u8>,
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let mut full_iv = [0u8; 16];
@@ -365,7 +365,7 @@ mod tests {
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
use cipher_crypto::{BlockEncrypt, KeyInit};
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());

View File

@@ -391,7 +391,7 @@ mod tests {
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
};
use mpz_zk::{Prover, Verifier};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
@@ -408,8 +408,8 @@ mod tests {
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta_zk, rcot_send);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
@@ -488,8 +488,8 @@ mod tests {
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta_zk, rcot_send);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
@@ -574,8 +574,8 @@ mod tests {
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta_zk, rcot_send);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);

View File

@@ -40,7 +40,6 @@ enum PrfState {
inner_partial: [u32; 8],
a_output: DecodeFutureTyped<BitVec, [u8; 32]>,
},
FinishLastP,
Done,
}
@@ -137,16 +136,18 @@ impl PrfFunction {
assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?;
if *iter == self.iterations {
self.state = PrfState::FinishLastP;
self.state = PrfState::Done;
} else {
self.state = PrfState::ComputeA {
iter: *iter + 1,
inner_partial: *inner_partial,
msg: output.to_vec(),
}
};
};
// We recurse, so that this PHash and the next AHash could
// be computed in a single VM execute call.
self.flush(vm)?;
}
}
PrfState::FinishLastP => self.state = PrfState::Done,
_ => (),
}

View File

@@ -13,7 +13,13 @@ workspace = true
[features]
default = []
fixtures = ["dep:hex", "dep:tlsn-data-fixtures"]
fixtures = [
"dep:hex",
"dep:tlsn-data-fixtures",
"dep:aead",
"dep:aes-gcm",
"dep:generic-array",
]
[dependencies]
tlsn-data-fixtures = { workspace = true, optional = true }
@@ -21,6 +27,9 @@ tlsn-tls-core = { workspace = true, features = ["serde"] }
tlsn-utils = { workspace = true }
rangeset = { workspace = true, features = ["serde"] }
aead = { workspace = true, features = ["alloc"], optional = true }
aes-gcm = { workspace = true, optional = true }
generic-array = { workspace = true, optional = true }
bimap = { version = "0.6", features = ["serde"] }
blake3 = { workspace = true }
hex = { workspace = true, optional = true }
@@ -36,10 +45,17 @@ thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
web-time = { workspace = true }
webpki-roots = { workspace = true }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }
itybity = { workspace = true }
zeroize = { workspace = true, features = ["zeroize_derive"] }
[dev-dependencies]
aead = { workspace = true, features = ["alloc"] }
aes-gcm = { workspace = true }
generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -2,16 +2,14 @@
use std::fmt;
use rustls_pki_types as webpki_types;
use serde::{Deserialize, Serialize};
use tls_core::{
msgs::{
codec::Codec,
enums::NamedGroup,
handshake::{DigitallySignedStruct, ServerECDHParams},
},
verify::{ServerCertVerifier as _, WebPkiVerifier},
use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams};
use crate::{
transcript::TlsTranscript,
webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError},
};
use web_time::{Duration, UNIX_EPOCH};
/// TLS version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -35,40 +33,82 @@ impl TryFrom<tls_core::msgs::enums::ProtocolVersion> for TlsVersion {
}
}
/// Server's name, a.k.a. the DNS name.
/// Server's name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ServerName(String);
pub enum ServerName {
/// DNS name.
Dns(DnsName),
}
impl ServerName {
/// Creates a new server name.
pub fn new(name: String) -> Self {
Self(name)
}
/// Returns the name as a string.
pub fn as_str(&self) -> &str {
&self.0
}
}
impl From<&str> for ServerName {
fn from(name: &str) -> Self {
Self(name.to_string())
}
}
impl AsRef<str> for ServerName {
fn as_ref(&self) -> &str {
&self.0
pub(crate) fn to_webpki(&self) -> webpki_types::ServerName<'static> {
match self {
ServerName::Dns(name) => webpki_types::ServerName::DnsName(
webpki_types::DnsName::try_from(name.0.as_str())
.expect("name was validated")
.to_owned(),
),
}
}
}
impl fmt::Display for ServerName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ServerName::Dns(name) => write!(f, "{name}"),
}
}
}
/// DNS name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "String")]
pub struct DnsName(String);
impl DnsName {
/// Returns the DNS name as a string.
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for DnsName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl AsRef<str> for DnsName {
fn as_ref(&self) -> &str {
&self.0
}
}
/// Error returned when a DNS name is invalid.
#[derive(Debug, thiserror::Error)]
#[error("invalid DNS name")]
pub struct InvalidDnsNameError {}
impl TryFrom<&str> for DnsName {
type Error = InvalidDnsNameError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
// Borrow validation from rustls
match webpki_types::DnsName::try_from_str(value) {
Ok(_) => Ok(DnsName(value.to_string())),
Err(_) => Err(InvalidDnsNameError {}),
}
}
}
impl TryFrom<String> for DnsName {
type Error = InvalidDnsNameError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}
/// Type of a public key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
@@ -98,6 +138,25 @@ pub enum SignatureScheme {
ED25519 = 0x0807,
}
impl fmt::Display for SignatureScheme {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SignatureScheme::RSA_PKCS1_SHA1 => write!(f, "RSA_PKCS1_SHA1"),
SignatureScheme::ECDSA_SHA1_Legacy => write!(f, "ECDSA_SHA1_Legacy"),
SignatureScheme::RSA_PKCS1_SHA256 => write!(f, "RSA_PKCS1_SHA256"),
SignatureScheme::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureScheme::RSA_PKCS1_SHA384 => write!(f, "RSA_PKCS1_SHA384"),
SignatureScheme::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureScheme::RSA_PKCS1_SHA512 => write!(f, "RSA_PKCS1_SHA512"),
SignatureScheme::ECDSA_NISTP521_SHA512 => write!(f, "ECDSA_NISTP521_SHA512"),
SignatureScheme::RSA_PSS_SHA256 => write!(f, "RSA_PSS_SHA256"),
SignatureScheme::RSA_PSS_SHA384 => write!(f, "RSA_PSS_SHA384"),
SignatureScheme::RSA_PSS_SHA512 => write!(f, "RSA_PSS_SHA512"),
SignatureScheme::ED25519 => write!(f, "ED25519"),
}
}
}
impl TryFrom<tls_core::msgs::enums::SignatureScheme> for SignatureScheme {
type Error = &'static str;
@@ -142,16 +201,6 @@ impl From<SignatureScheme> for tls_core::msgs::enums::SignatureScheme {
}
}
/// X.509 certificate, DER encoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Certificate(pub Vec<u8>);
impl From<tls_core::key::Certificate> for Certificate {
fn from(cert: tls_core::key::Certificate) -> Self {
Self(cert.0)
}
}
/// Server's signature of the key exchange parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSignature {
@@ -220,9 +269,9 @@ pub struct TranscriptLength {
pub received: u32,
}
/// TLS 1.2 handshake data.
/// TLS 1.2 certificate binding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeDataV1_2 {
pub struct CertBindingV1_2 {
/// Client random.
pub client_random: [u8; 32],
/// Server random.
@@ -231,13 +280,18 @@ pub struct HandshakeDataV1_2 {
pub server_ephemeral_key: ServerEphemKey,
}
/// TLS handshake data.
/// TLS certificate binding.
///
/// This is the data that the server signs with the private key matching the
/// certificate it presents during the TLS handshake. This provides a binding
/// between the server's identity and the ephemeral keys used to authenticate
/// the TLS session.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum HandshakeData {
/// TLS 1.2 handshake data.
V1_2(HandshakeDataV1_2),
pub enum CertBinding {
/// TLS 1.2 certificate binding.
V1_2(CertBindingV1_2),
}
/// Verify data from the TLS handshake finished messages.
@@ -249,19 +303,38 @@ pub struct VerifyData {
pub server_finished: Vec<u8>,
}
/// Server certificate and handshake data.
/// TLS handshake data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerCertData {
/// Certificate chain.
pub certs: Vec<Certificate>,
/// Server signature of the key exchange parameters.
pub struct HandshakeData {
/// Server certificate chain.
pub certs: Vec<CertificateDer>,
/// Server certificate signature over the binding message.
pub sig: ServerSignature,
/// TLS handshake data.
pub handshake: HandshakeData,
/// Certificate binding.
pub binding: CertBinding,
}
impl ServerCertData {
/// Verifies the server certificate data.
impl HandshakeData {
/// Creates a new instance.
///
/// # Arguments
///
/// * `transcript` - The TLS transcript.
pub fn new(transcript: &TlsTranscript) -> Self {
Self {
certs: transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: transcript.certificate_binding().clone(),
}
}
/// Verifies the handshake data.
///
/// # Arguments
///
@@ -271,53 +344,35 @@ impl ServerCertData {
/// * `server_name` - The server name.
pub fn verify(
&self,
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
time: u64,
server_ephemeral_key: &ServerEphemKey,
server_name: &ServerName,
) -> Result<(), CertificateVerificationError> {
) -> Result<(), HandshakeVerificationError> {
#[allow(irrefutable_let_patterns)]
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
client_random,
server_random,
server_ephemeral_key: expected_server_ephemeral_key,
}) = &self.handshake
}) = &self.binding
else {
unreachable!("only TLS 1.2 is implemented")
};
if server_ephemeral_key != expected_server_ephemeral_key {
return Err(CertificateVerificationError::InvalidServerEphemeralKey);
return Err(HandshakeVerificationError::InvalidServerEphemeralKey);
}
// Verify server name.
let server_name = tls_core::dns::ServerName::try_from(server_name.as_ref())
.map_err(|_| CertificateVerificationError::InvalidIdentity(server_name.clone()))?;
// Verify server certificate.
let cert_chain = self
let (end_entity, intermediates) = self
.certs
.clone()
.into_iter()
.map(|cert| tls_core::key::Certificate(cert.0))
.collect::<Vec<_>>();
let (end_entity, intermediates) = cert_chain
.split_first()
.ok_or(CertificateVerificationError::MissingCerts)?;
.ok_or(HandshakeVerificationError::MissingCerts)?;
// Verify the end entity cert is valid for the provided server name
// and that it chains to at least one of the roots we trust.
verifier
.verify_server_cert(
end_entity,
intermediates,
&server_name,
&mut [].into_iter(),
&[],
UNIX_EPOCH + Duration::from_secs(time),
)
.map_err(|_| CertificateVerificationError::InvalidCert)?;
.verify_server_cert(end_entity, intermediates, server_name, time)
.map_err(HandshakeVerificationError::ServerCert)?;
// Verify the signature matches the certificate and key exchange parameters.
let mut message = Vec::new();
@@ -325,11 +380,31 @@ impl ServerCertData {
message.extend_from_slice(server_random);
message.extend_from_slice(&server_ephemeral_key.kx_params());
let dss = DigitallySignedStruct::new(self.sig.scheme.into(), self.sig.sig.clone());
use webpki::ring as alg;
let sig_alg = match self.sig.scheme {
SignatureScheme::RSA_PKCS1_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureScheme::RSA_PSS_SHA256 => alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA384 => alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA512 => alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
SignatureScheme::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureScheme::ED25519 => alg::ED25519,
scheme => {
return Err(HandshakeVerificationError::UnsupportedSignatureScheme(
scheme,
))
}
};
verifier
.verify_tls12_signature(&message, end_entity, &dss)
.map_err(|_| CertificateVerificationError::InvalidServerSignature)?;
let end_entity = webpki_types::CertificateDer::from(end_entity.0.as_slice());
let end_entity = webpki::EndEntityCert::try_from(&end_entity)
.map_err(|_| HandshakeVerificationError::InvalidEndEntityCertificate)?;
end_entity
.verify_signature(sig_alg, &message, &self.sig.sig)
.map_err(|_| HandshakeVerificationError::InvalidServerSignature)?;
Ok(())
}
@@ -338,58 +413,51 @@ impl ServerCertData {
/// Errors that can occur when verifying a certificate chain or signature.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum CertificateVerificationError {
#[error("invalid server identity: {0}")]
InvalidIdentity(ServerName),
pub enum HandshakeVerificationError {
#[error("invalid end entity certificate")]
InvalidEndEntityCertificate,
#[error("missing server certificates")]
MissingCerts,
#[error("invalid server certificate")]
InvalidCert,
#[error("invalid server signature")]
InvalidServerSignature,
#[error("invalid server ephemeral key")]
InvalidServerEphemeralKey,
#[error("server certificate verification failed: {0}")]
ServerCert(ServerCertVerifierError),
#[error("unsupported signature scheme: {0}")]
UnsupportedSignatureScheme(SignatureScheme),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{fixtures::ConnectionFixture, transcript::Transcript};
use crate::{fixtures::ConnectionFixture, transcript::Transcript, webpki::RootCertStore};
use hex::FromHex;
use rstest::*;
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
verify::WebPkiVerifier,
};
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
#[fixture]
#[once]
fn verifier() -> WebPkiVerifier {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
fn verifier() -> ServerCertVerifier {
let mut root_store = RootCertStore {
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
.iter()
.map(|c| CertificateDer(c.to_vec()))
.collect(),
};
// Add a cert which is no longer included in the Mozilla root store.
let cert = tls_core::key::Certificate(
root_store.roots.push(
appliedzkp()
.server_cert_data
.certs
.last()
.expect("chain is valid")
.0
.clone(),
);
root_store.add(&cert).unwrap();
WebPkiVerifier::new(root_store, None)
ServerCertVerifier::new(&root_store).unwrap()
}
fn tlsnotary() -> ConnectionFixture {
@@ -405,7 +473,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_sucess_ca_implicit(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the CA cert
@@ -417,7 +485,7 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
)
.is_ok());
}
@@ -428,7 +496,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_success_ca_explicit(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
assert!(data
@@ -437,7 +505,7 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
)
.is_ok());
}
@@ -447,7 +515,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_bad_time(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
// unix time when the cert chain was NOT valid
@@ -457,12 +525,12 @@ mod tests {
verifier,
bad_time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -471,7 +539,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_interm_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the CA cert
@@ -483,12 +551,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -498,7 +566,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_interm_cert_with_ca_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Remove the intermediate cert
@@ -508,12 +576,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -522,24 +590,24 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_bad_ee_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
let ee: &[u8] = include_bytes!("./fixtures/data/unknown/ee.der");
// Change the end entity cert
data.server_cert_data.certs[0] = Certificate(ee.to_vec());
data.server_cert_data.certs[0] = CertificateDer(ee.to_vec());
let err = data.server_cert_data.verify(
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -548,23 +616,23 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_sig_ke_params_fail_bad_client_random(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
let HandshakeData::V1_2(HandshakeDataV1_2 { client_random, .. }) =
&mut data.server_cert_data.handshake;
let CertBinding::V1_2(CertBindingV1_2 { client_random, .. }) =
&mut data.server_cert_data.binding;
client_random[31] = client_random[31].wrapping_add(1);
let err = data.server_cert_data.verify(
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerSignature
HandshakeVerificationError::InvalidServerSignature
));
}
@@ -573,7 +641,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_sig_ke_params_fail_bad_sig(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
data.server_cert_data.sig.sig[31] = data.server_cert_data.sig.sig[31].wrapping_add(1);
@@ -582,12 +650,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerSignature
HandshakeVerificationError::InvalidServerSignature
));
}
@@ -596,10 +664,10 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_check_dns_name_present_in_cert_fail_bad_host(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] data: ConnectionFixture,
) {
let bad_name = ServerName::from("badhost.com");
let bad_name = ServerName::Dns(DnsName::try_from("badhost.com").unwrap());
let err = data.server_cert_data.verify(
verifier,
@@ -610,7 +678,7 @@ mod tests {
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidCert
HandshakeVerificationError::ServerCert(_)
));
}
@@ -618,7 +686,7 @@ mod tests {
#[rstest]
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_invalid_ephemeral_key(verifier: &WebPkiVerifier, #[case] data: ConnectionFixture) {
fn test_invalid_ephemeral_key(verifier: &ServerCertVerifier, #[case] data: ConnectionFixture) {
let wrong_ephemeral_key = ServerEphemKey {
typ: KeyType::SECP256R1,
key: Vec::<u8>::from_hex(include_bytes!("./fixtures/data/unknown/pubkey")).unwrap(),
@@ -628,12 +696,12 @@ mod tests {
verifier,
data.connection_info.time,
&wrong_ephemeral_key,
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::InvalidServerEphemeralKey
HandshakeVerificationError::InvalidServerEphemeralKey
));
}
@@ -642,7 +710,7 @@ mod tests {
#[case::tlsnotary(tlsnotary())]
#[case::appliedzkp(appliedzkp())]
fn test_verify_cert_chain_fail_no_cert(
verifier: &WebPkiVerifier,
verifier: &ServerCertVerifier,
#[case] mut data: ConnectionFixture,
) {
// Empty certs
@@ -652,12 +720,12 @@ mod tests {
verifier,
data.connection_info.time,
data.server_ephemeral_key(),
&ServerName::from(data.server_name.as_ref()),
&data.server_name,
);
assert!(matches!(
err.unwrap_err(),
CertificateVerificationError::MissingCerts
HandshakeVerificationError::MissingCerts
));
}
}
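
To tie the renamed types together, a hedged sketch of verifying `HandshakeData` for a connection, following the `verify` signature and the `ServerCertVerifier::mozilla` constructor introduced in this change set; the wrapper function and the hard-coded server name are illustrative.

```rust
use tlsn_core::{
    connection::{DnsName, HandshakeData, HandshakeVerificationError, ServerEphemKey, ServerName},
    webpki::ServerCertVerifier,
};

// Illustrative only.
fn verify_handshake(
    handshake: &HandshakeData,
    time: u64,
    server_ephemeral_key: &ServerEphemKey,
) -> Result<(), HandshakeVerificationError> {
    // DNS names are now validated at construction (borrowing rustls' rules).
    let name = ServerName::Dns(DnsName::try_from("tlsnotary.org").expect("valid DNS name"));

    // Checks the certificate chain against the Mozilla roots, matches the
    // server name, and verifies the signature over the certificate binding.
    let verifier = ServerCertVerifier::mozilla();
    handshake.verify(&verifier, time, server_ephemeral_key, &name)
}
```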

View File

@@ -0,0 +1,16 @@
use rangeset::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter_ranges() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;
}
}
f.write_str("}")
}
}
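
A small illustrative check of the helper above (it is crate-internal, so this is sketched as a module test); the expected output shape follows directly from the `Display` impl.

```rust
// Sketch of a test for this module.
use rangeset::RangeSet;

#[test]
fn fmt_range_set_single_range() {
    // A single contiguous range renders without a trailing separator.
    let idx: RangeSet<usize> = RangeSet::from(2..5);
    assert_eq!(FmtRangeSet(&idx).to_string(), "{2..5}");
}
```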

View File

@@ -1,6 +1,7 @@
//! Fixtures for testing
mod provider;
pub mod transcript;
pub use provider::FixtureEncodingProvider;
@@ -8,13 +9,14 @@ use hex::FromHex;
use crate::{
connection::{
Certificate, ConnectionInfo, HandshakeData, HandshakeDataV1_2, KeyType, ServerCertData,
CertBinding, CertBindingV1_2, ConnectionInfo, DnsName, HandshakeData, KeyType,
ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength,
},
transcript::{
encoding::{EncoderSecret, EncodingProvider},
Transcript,
},
webpki::CertificateDer,
};
/// A fixture containing various TLS connection data.
@@ -23,24 +25,26 @@ use crate::{
pub struct ConnectionFixture {
pub server_name: ServerName,
pub connection_info: ConnectionInfo,
pub server_cert_data: ServerCertData,
pub server_cert_data: HandshakeData,
}
impl ConnectionFixture {
/// Returns a connection fixture for tlsnotary.org.
pub fn tlsnotary(transcript_length: TranscriptLength) -> Self {
ConnectionFixture {
server_name: ServerName::new("tlsnotary.org".to_string()),
server_name: ServerName::Dns(DnsName::try_from("tlsnotary.org").unwrap()),
connection_info: ConnectionInfo {
time: 1671637529,
version: TlsVersion::V1_2,
transcript_length,
},
server_cert_data: ServerCertData {
server_cert_data: HandshakeData {
certs: vec![
Certificate(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()),
Certificate(include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec()),
Certificate(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ee.der").to_vec()),
CertificateDer(
include_bytes!("fixtures/data/tlsnotary.org/inter.der").to_vec(),
),
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
],
sig: ServerSignature {
scheme: SignatureScheme::RSA_PKCS1_SHA256,
@@ -49,7 +53,7 @@ impl ConnectionFixture {
))
.unwrap(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
binding: CertBinding::V1_2(CertBindingV1_2 {
client_random: <[u8; 32]>::from_hex(include_bytes!(
"fixtures/data/tlsnotary.org/client_random"
))
@@ -73,17 +77,19 @@ impl ConnectionFixture {
/// Returns a connection fixture for appliedzkp.org.
pub fn appliedzkp(transcript_length: TranscriptLength) -> Self {
ConnectionFixture {
server_name: ServerName::new("appliedzkp.org".to_string()),
server_name: ServerName::Dns(DnsName::try_from("appliedzkp.org").unwrap()),
connection_info: ConnectionInfo {
time: 1671637529,
version: TlsVersion::V1_2,
transcript_length,
},
server_cert_data: ServerCertData {
server_cert_data: HandshakeData {
certs: vec![
Certificate(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()),
Certificate(include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec()),
Certificate(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ee.der").to_vec()),
CertificateDer(
include_bytes!("fixtures/data/appliedzkp.org/inter.der").to_vec(),
),
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
],
sig: ServerSignature {
scheme: SignatureScheme::ECDSA_NISTP256_SHA256,
@@ -92,7 +98,7 @@ impl ConnectionFixture {
))
.unwrap(),
},
handshake: HandshakeData::V1_2(HandshakeDataV1_2 {
binding: CertBinding::V1_2(CertBindingV1_2 {
client_random: <[u8; 32]>::from_hex(include_bytes!(
"fixtures/data/appliedzkp.org/client_random"
))
@@ -115,10 +121,10 @@ impl ConnectionFixture {
/// Returns the server_ephemeral_key fixture.
pub fn server_ephemeral_key(&self) -> &ServerEphemKey {
let HandshakeData::V1_2(HandshakeDataV1_2 {
let CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) = &self.server_cert_data.handshake;
}) = &self.server_cert_data.binding;
server_ephemeral_key
}
}

View File

@@ -0,0 +1,199 @@
//! Transcript fixtures for testing.
use aead::Payload as AeadPayload;
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
use generic_array::GenericArray;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_core::msgs::{
base::Payload,
codec::Codec,
enums::{ContentType, HandshakeType, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
message::{OpaqueMessage, PlainMessage},
};
use crate::{
connection::{TranscriptLength, VerifyData},
fixtures::ConnectionFixture,
transcript::{Record, TlsTranscript},
};
/// The key used for encryption of the sent and received transcript.
pub const KEY: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
/// The iv used for encryption of the sent and received transcript.
pub const IV: [u8; 4] = [1, 3, 3, 7];
/// The record size in bytes.
pub const RECORD_SIZE: usize = 512;
/// Creates a transcript fixture for testing.
pub fn transcript_fixture(sent: &[u8], recv: &[u8]) -> TlsTranscript {
TranscriptGenerator::new(KEY, IV).generate(sent, recv)
}
struct TranscriptGenerator {
key: [u8; 16],
iv: [u8; 4],
}
impl TranscriptGenerator {
fn new(key: [u8; 16], iv: [u8; 4]) -> Self {
Self { key, iv }
}
fn generate(&self, sent: &[u8], recv: &[u8]) -> TlsTranscript {
let mut rng = StdRng::from_seed([1; 32]);
let transcript_len = TranscriptLength {
sent: sent.len() as u32,
received: recv.len() as u32,
};
let tlsn = ConnectionFixture::tlsnotary(transcript_len);
let time = tlsn.connection_info.time;
let version = tlsn.connection_info.version;
let server_cert_chain = tlsn.server_cert_data.certs;
let server_signature = tlsn.server_cert_data.sig;
let cert_binding = tlsn.server_cert_data.binding;
let cf_vd: [u8; 12] = rng.random();
let sf_vd: [u8; 12] = rng.random();
let verify_data = VerifyData {
client_finished: cf_vd.to_vec(),
server_finished: sf_vd.to_vec(),
};
let sent = self.gen_records(cf_vd, sent);
let recv = self.gen_records(sf_vd, recv);
TlsTranscript::new(
time,
version,
Some(server_cert_chain),
Some(server_signature),
cert_binding,
verify_data,
sent,
recv,
)
.unwrap()
}
fn gen_records(&self, vd: [u8; 12], plaintext: &[u8]) -> Vec<Record> {
let mut records = Vec::new();
let handshake = self.gen_handshake(vd);
records.push(handshake);
for (seq, msg) in (1_u64..).zip(plaintext.chunks(RECORD_SIZE)) {
let record = self.gen_app_data(seq, msg);
records.push(record);
}
records
}
fn gen_app_data(&self, seq: u64, plaintext: &[u8]) -> Record {
assert!(
plaintext.len() <= 1 << 14,
"plaintext len per record must be smaller than 2^14 bytes"
);
let explicit_nonce: [u8; 8] = seq.to_be_bytes();
let msg = PlainMessage {
typ: ContentType::ApplicationData,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext),
};
let opaque = aes_gcm_encrypt(self.key, self.iv, seq, explicit_nonce, &msg);
let mut payload = opaque.payload.0;
let mut ciphertext = payload.split_off(8);
let tag = ciphertext.split_off(ciphertext.len() - 16);
Record {
seq,
typ: ContentType::ApplicationData,
plaintext: Some(plaintext.to_vec()),
explicit_nonce: explicit_nonce.to_vec(),
ciphertext,
tag: Some(tag),
}
}
fn gen_handshake(&self, vd: [u8; 12]) -> Record {
let seq = 0_u64;
let explicit_nonce = seq.to_be_bytes();
let mut plaintext = Vec::new();
let payload = Payload(vd.to_vec());
let hs_payload = HandshakePayload::Finished(payload);
let handshake_message = HandshakeMessagePayload {
typ: HandshakeType::Finished,
payload: hs_payload,
};
handshake_message.encode(&mut plaintext);
let msg = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext.clone()),
};
let opaque = aes_gcm_encrypt(self.key, self.iv, seq, explicit_nonce, &msg);
let mut payload = opaque.payload.0;
let mut ciphertext = payload.split_off(8);
let tag = ciphertext.split_off(ciphertext.len() - 16);
Record {
seq,
typ: ContentType::Handshake,
plaintext: Some(plaintext),
explicit_nonce: explicit_nonce.to_vec(),
ciphertext,
tag: Some(tag),
}
}
}
fn aes_gcm_encrypt(
key: [u8; 16],
iv: [u8; 4],
seq: u64,
explicit_nonce: [u8; 8],
msg: &PlainMessage,
) -> OpaqueMessage {
let mut aad = [0u8; 13];
aad[..8].copy_from_slice(&seq.to_be_bytes());
aad[8] = msg.typ.get_u8();
aad[9..11].copy_from_slice(&msg.version.get_u16().to_be_bytes());
aad[11..13].copy_from_slice(&(msg.payload.0.len() as u16).to_be_bytes());
let payload = AeadPayload {
msg: &msg.payload.0,
aad: &aad,
};
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();
// ciphertext will have the MAC appended
let ciphertext = cipher.encrypt(nonce, payload).unwrap();
// prepend the explicit nonce
let mut nonce_ct_mac = vec![0u8; 0];
nonce_ct_mac.extend(explicit_nonce.iter());
nonce_ct_mac.extend(ciphertext.iter());
OpaqueMessage {
typ: msg.typ,
version: msg.version,
payload: Payload::new(nonce_ct_mac),
}
}
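
A hedged usage sketch of the new fixture: with the `fixtures` feature enabled, a `TlsTranscript` can be generated from arbitrary plaintext. The HTTP constants come from `tlsn-data-fixtures` (already a dev-dependency here); the exact module path is an assumption.

```rust
// Sketch, assuming `transcript_fixture` is reachable at this path behind the
// `fixtures` feature.
use tlsn_core::fixtures::transcript::transcript_fixture;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};

#[test]
fn build_transcript_fixture() {
    // Encrypts the plaintext into 512-byte AES-GCM records under the fixed
    // KEY/IV above and wraps them, plus a Finished record, in a TlsTranscript.
    let transcript = transcript_fixture(GET_WITH_HEADER, OK_JSON);
    let _ = transcript;
}
```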

View File

@@ -95,7 +95,7 @@ impl Display for HashAlgId {
}
/// A typed hash value.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TypedHash {
/// The algorithm of the hash.
pub alg: HashAlgId,

View File

@@ -10,29 +10,32 @@ pub mod fixtures;
pub mod hash;
pub mod merkle;
pub mod transcript;
pub mod webpki;
pub use rangeset;
pub(crate) mod display;
use rangeset::ToRangeSet;
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
connection::{ServerCertData, ServerName},
connection::{HandshakeData, ServerName},
transcript::{
Direction, Idx, PartialTranscript, Transcript, TranscriptCommitConfig,
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
Direction, PartialTranscript, Transcript, TranscriptCommitConfig, TranscriptCommitRequest,
TranscriptCommitment, TranscriptSecret,
},
};
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
server_identity: bool,
transcript: Option<PartialTranscript>,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
/// Creates a new builder.
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder {
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
ProveConfigBuilder::new(transcript)
}
@@ -41,9 +44,9 @@ impl ProveConfig {
self.server_identity
}
/// Returns the transcript to be proven.
pub fn transcript(&self) -> Option<&PartialTranscript> {
self.transcript.as_ref()
/// Returns the ranges of the transcript to be revealed.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
@@ -57,8 +60,7 @@ impl ProveConfig {
pub struct ProveConfigBuilder<'a> {
transcript: &'a Transcript,
server_identity: bool,
reveal_sent: Idx,
reveal_recv: Idx,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
@@ -68,8 +70,7 @@ impl<'a> ProveConfigBuilder<'a> {
Self {
transcript,
server_identity: false,
reveal_sent: Idx::default(),
reveal_recv: Idx::default(),
reveal: None,
transcript_commit: None,
}
}
@@ -92,22 +93,24 @@ impl<'a> ProveConfigBuilder<'a> {
direction: Direction,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
let idx = Idx::new(ranges.to_range_set());
let idx = ranges.to_range_set();
if idx.end() > self.transcript.len_of_direction(direction) {
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(ProveConfigBuilderError(
ProveConfigBuilderErrorRepr::IndexOutOfBounds {
direction,
actual: idx.end(),
actual: idx.end().unwrap_or(0),
len: self.transcript.len_of_direction(direction),
},
));
}
let (sent, recv) = self.reveal.get_or_insert_default();
match direction {
Direction::Sent => self.reveal_sent.union_mut(&idx),
Direction::Received => self.reveal_recv.union_mut(&idx),
Direction::Sent => sent.union_mut(&idx),
Direction::Received => recv.union_mut(&idx),
}
Ok(self)
}
@@ -127,20 +130,20 @@ impl<'a> ProveConfigBuilder<'a> {
self.reveal(Direction::Received, ranges)
}
/// Reveals the full transcript range for a given direction.
pub fn reveal_all(
&mut self,
direction: Direction,
) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(direction);
self.reveal(direction, &(0..len))
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
let transcript = if !self.reveal_sent.is_empty() || !self.reveal_recv.is_empty() {
Some(
self.transcript
.to_partial(self.reveal_sent, self.reveal_recv),
)
} else {
None
};
Ok(ProveConfig {
server_identity: self.server_identity,
transcript,
reveal: self.reveal,
transcript_commit: self.transcript_commit,
})
}
@@ -162,7 +165,7 @@ enum ProveConfigBuilderErrorRepr {
}
/// Configuration to verify information from the prover.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct VerifyConfig {}
impl VerifyConfig {
@@ -196,19 +199,43 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Payload sent to the verifier.
/// Request to prove statements about the connection.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
/// Server identity data.
pub server_identity: Option<(ServerName, ServerCertData)>,
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Transcript commitment configuration.
pub transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
/// Creates a new prove payload.
///
/// # Arguments
///
/// * `config` - The prove config.
/// * `transcript` - The partial transcript.
/// * `handshake` - The server name and handshake data.
pub fn new(
config: &ProveConfig,
transcript: Option<PartialTranscript>,
handshake: Option<(ServerName, HandshakeData)>,
) -> Self {
let transcript_commit = config.transcript_commit().map(|config| config.to_request());
Self {
handshake,
transcript,
transcript_commit,
}
}
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {
/// Transcript commitments.
pub transcript_commitments: Vec<TranscriptCommitment>,
@@ -219,6 +246,7 @@ pub struct ProverOutput {
opaque_debug::implement!(ProverOutput);
/// Verifier output.
#[derive(Serialize, Deserialize)]
pub struct VerifierOutput {
/// Server identity.
pub server_name: Option<ServerName>,

View File

@@ -26,7 +26,7 @@ mod tls;
use std::{fmt, ops::Range};
use rangeset::{Difference, IndexRanges, RangeSet, Subset, ToRangeSet, Union, UnionMut};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -39,6 +39,7 @@ pub use proof::{
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
};
pub use tls::{Record, TlsTranscript};
pub use tls_core::msgs::enums::ContentType;
/// A transcript contains the plaintext of all application data communicated
/// between the Prover and the Server.
@@ -95,18 +96,18 @@ impl Transcript {
/// Returns the subsequence of the transcript with the provided index,
/// returning `None` if the index is out of bounds.
pub fn get(&self, direction: Direction, idx: &Idx) -> Option<Subsequence> {
pub fn get(&self, direction: Direction, idx: &RangeSet<usize>) -> Option<Subsequence> {
let data = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.received,
};
if idx.end() > data.len() {
if idx.end().unwrap_or(0) > data.len() {
return None;
}
Some(
Subsequence::new(idx.clone(), data.index_ranges(&idx.0))
Subsequence::new(idx.clone(), data.index_ranges(idx))
.expect("data is same length as index"),
)
}
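Editorial note: a short sketch of the `RangeSet`-based lookup, mirroring the tests further down in this diff; the chosen ranges are illustrative.

```rust
// Editorial sketch: the lookup now takes a `RangeSet<usize>` directly and
// still returns `None` when a range ends past the transcript length.
use rangeset::RangeSet;

fn first_bytes_sent(transcript: &Transcript) -> Option<Vec<u8>> {
    transcript
        .get(Direction::Sent, &RangeSet::from([0..4, 9..12]))
        .map(|subseq| subseq.into_parts().1)
}
```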
@@ -121,7 +122,11 @@ impl Transcript {
///
/// * `sent_idx` - The indices of the sent data to include.
/// * `recv_idx` - The indices of the received data to include.
pub fn to_partial(&self, sent_idx: Idx, recv_idx: Idx) -> PartialTranscript {
pub fn to_partial(
&self,
sent_idx: RangeSet<usize>,
recv_idx: RangeSet<usize>,
) -> PartialTranscript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
@@ -156,9 +161,9 @@ pub struct PartialTranscript {
/// Data received by the Prover from the Server.
received: Vec<u8>,
/// Index of `sent` which have been authenticated.
sent_authed_idx: Idx,
sent_authed_idx: RangeSet<usize>,
/// Index of `received` which have been authenticated.
received_authed_idx: Idx,
received_authed_idx: RangeSet<usize>,
}
/// `PartialTranscript` in a compressed form.
@@ -170,9 +175,9 @@ pub struct CompressedPartialTranscript {
/// Received data which has been authenticated.
received_authed: Vec<u8>,
/// Index of `sent_authed`.
sent_idx: Idx,
sent_idx: RangeSet<usize>,
/// Index of `received_authed`.
recv_idx: Idx,
recv_idx: RangeSet<usize>,
/// Total bytelength of sent data in the original partial transcript.
sent_total: usize,
/// Total bytelength of received data in the original partial transcript.
@@ -184,10 +189,10 @@ impl From<PartialTranscript> for CompressedPartialTranscript {
Self {
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx.0),
.index_ranges(&uncompressed.sent_authed_idx),
received_authed: uncompressed
.received
.index_ranges(&uncompressed.received_authed_idx.0),
.index_ranges(&uncompressed.received_authed_idx),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -237,8 +242,8 @@ impl PartialTranscript {
Self {
sent: vec![0; sent_len],
received: vec![0; received_len],
sent_authed_idx: Idx::default(),
received_authed_idx: Idx::default(),
sent_authed_idx: RangeSet::default(),
received_authed_idx: RangeSet::default(),
}
}
@@ -259,10 +264,10 @@ impl PartialTranscript {
}
/// Returns whether the index is in bounds of the transcript.
pub fn contains(&self, direction: Direction, idx: &Idx) -> bool {
pub fn contains(&self, direction: Direction, idx: &RangeSet<usize>) -> bool {
match direction {
Direction::Sent => idx.end() <= self.sent.len(),
Direction::Received => idx.end() <= self.received.len(),
Direction::Sent => idx.end().unwrap_or(0) <= self.sent.len(),
Direction::Received => idx.end().unwrap_or(0) <= self.received.len(),
}
}
@@ -289,23 +294,23 @@ impl PartialTranscript {
}
/// Returns the index of sent data which have been authenticated.
pub fn sent_authed(&self) -> &Idx {
pub fn sent_authed(&self) -> &RangeSet<usize> {
&self.sent_authed_idx
}
/// Returns the index of received data which have been authenticated.
pub fn received_authed(&self) -> &Idx {
pub fn received_authed(&self) -> &RangeSet<usize> {
&self.received_authed_idx
}
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> Idx {
Idx(RangeSet::from(0..self.sent.len()).difference(&self.sent_authed_idx.0))
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len()).difference(&self.sent_authed_idx)
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> Idx {
Idx(RangeSet::from(0..self.received.len()).difference(&self.received_authed_idx.0))
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len()).difference(&self.received_authed_idx)
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -315,7 +320,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.0.iter().map(|i| data[i])
authed.iter().map(|i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -337,8 +342,7 @@ impl PartialTranscript {
for range in other
.sent_authed_idx
.0
.difference(&self.sent_authed_idx.0)
.difference(&self.sent_authed_idx)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
@@ -346,8 +350,7 @@ impl PartialTranscript {
for range in other
.received_authed_idx
.0
.difference(&self.received_authed_idx.0)
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
@@ -399,12 +402,12 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for range in range.difference(&self.sent_authed_idx.0).iter_ranges() {
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
}
}
Direction::Received => {
for range in range.difference(&self.received_authed_idx.0).iter_ranges() {
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
}
}
@@ -433,130 +436,19 @@ impl fmt::Display for Direction {
}
}
/// Transcript index.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Idx(RangeSet<usize>);
impl Idx {
/// Creates a new index builder.
pub fn builder() -> IdxBuilder {
IdxBuilder::default()
}
/// Creates an empty index.
pub fn empty() -> Self {
Self(RangeSet::default())
}
/// Creates a new transcript index.
pub fn new(ranges: impl Into<RangeSet<usize>>) -> Self {
Self(ranges.into())
}
/// Returns the start of the index.
pub fn start(&self) -> usize {
self.0.min().unwrap_or_default()
}
/// Returns the end of the index, non-inclusive.
pub fn end(&self) -> usize {
self.0.end().unwrap_or_default()
}
/// Returns an iterator over the values in the index.
pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
self.0.iter()
}
/// Returns an iterator over the ranges of the index.
pub fn iter_ranges(&self) -> impl Iterator<Item = Range<usize>> + '_ {
self.0.iter_ranges()
}
/// Returns the number of values in the index.
pub fn len(&self) -> usize {
self.0.len()
}
/// Returns whether the index is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Returns the number of disjoint ranges in the index.
pub fn count(&self) -> usize {
self.0.len_ranges()
}
pub(crate) fn as_range_set(&self) -> &RangeSet<usize> {
&self.0
}
/// Returns the union of this index with another.
pub(crate) fn union(&self, other: &Idx) -> Idx {
Idx(self.0.union(&other.0))
}
/// Unions this index with another.
pub(crate) fn union_mut(&mut self, other: &Idx) {
self.0.union_mut(&other.0);
}
/// Returns the difference between `self` and `other`.
pub(crate) fn difference(&self, other: &Idx) -> Idx {
Idx(self.0.difference(&other.0))
}
/// Returns `true` if `self` is a subset of `other`.
pub(crate) fn is_subset(&self, other: &Idx) -> bool {
self.0.is_subset(&other.0)
}
}
impl std::fmt::Display for Idx {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Idx([")?;
let count = self.0.len_ranges();
for (i, range) in self.0.iter_ranges().enumerate() {
write!(f, "{}..{}", range.start, range.end)?;
if i < count - 1 {
write!(f, ", ")?;
}
}
f.write_str("])")?;
Ok(())
}
}
/// Builder for [`Idx`].
#[derive(Debug, Default)]
pub struct IdxBuilder(RangeSet<usize>);
impl IdxBuilder {
/// Unions ranges.
pub fn union(self, ranges: &dyn ToRangeSet<usize>) -> Self {
IdxBuilder(self.0.union(&ranges.to_range_set()))
}
/// Builds the index.
pub fn build(self) -> Idx {
Idx(self.0)
}
}
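Editorial note: with `Idx` and `IdxBuilder` removed, downstream code switches to the `rangeset` crate directly. A hedged migration sketch follows; the `RangeSet` calls below are the ones already used elsewhere in this diff, and the variable names are illustrative.

```rust
// Editorial migration sketch, assuming the `rangeset` crate API used in this diff.
use rangeset::{RangeSet, UnionMut};

fn migrate() {
    // Before: Idx::new([0..4, 7..10])
    let idx: RangeSet<usize> = RangeSet::from([0..4, 7..10]);

    // Before: IdxBuilder::default().union(&(0..4)).union(&(7..10)).build()
    let mut built: RangeSet<usize> = RangeSet::default();
    built.union_mut(&RangeSet::from(0..4));
    built.union_mut(&RangeSet::from(7..10));

    // Before: Idx::end() returned 0 for an empty index; RangeSet returns an Option.
    let _end = idx.end().unwrap_or(0);
    let _ = built;
}
```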
/// Transcript subsequence.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "validation::SubsequenceUnchecked")]
pub struct Subsequence {
/// Index of the subsequence.
idx: Idx,
idx: RangeSet<usize>,
/// Data of the subsequence.
data: Vec<u8>,
}
impl Subsequence {
/// Creates a new subsequence.
pub fn new(idx: Idx, data: Vec<u8>) -> Result<Self, InvalidSubsequence> {
pub fn new(idx: RangeSet<usize>, data: Vec<u8>) -> Result<Self, InvalidSubsequence> {
if idx.len() != data.len() {
return Err(InvalidSubsequence(
"index length does not match data length",
@@ -567,7 +459,7 @@ impl Subsequence {
}
/// Returns the index of the subsequence.
pub fn index(&self) -> &Idx {
pub fn index(&self) -> &RangeSet<usize> {
&self.idx
}
@@ -583,7 +475,7 @@ impl Subsequence {
}
/// Returns the inner parts of the subsequence.
pub fn into_parts(self) -> (Idx, Vec<u8>) {
pub fn into_parts(self) -> (RangeSet<usize>, Vec<u8>) {
(self.idx, self.data)
}
@@ -611,7 +503,7 @@ mod validation {
#[derive(Debug, Deserialize)]
pub(super) struct SubsequenceUnchecked {
idx: Idx,
idx: RangeSet<usize>,
data: Vec<u8>,
}
@@ -633,8 +525,8 @@ mod validation {
pub(super) struct CompressedPartialTranscriptUnchecked {
sent_authed: Vec<u8>,
received_authed: Vec<u8>,
sent_idx: Idx,
recv_idx: Idx,
sent_idx: RangeSet<usize>,
recv_idx: RangeSet<usize>,
sent_total: usize,
recv_total: usize,
}
@@ -651,8 +543,8 @@ mod validation {
));
}
if unchecked.sent_idx.end() > unchecked.sent_total
|| unchecked.recv_idx.end() > unchecked.recv_total
if unchecked.sent_idx.end().unwrap_or(0) > unchecked.sent_total
|| unchecked.recv_idx.end().unwrap_or(0) > unchecked.recv_total
{
return Err(InvalidCompressedPartialTranscript(
"ranges are not in bounds of the data",
@@ -681,8 +573,8 @@ mod validation {
CompressedPartialTranscriptUnchecked {
received_authed: vec![1, 2, 3, 11, 12, 13],
sent_authed: vec![4, 5, 6, 14, 15, 16],
recv_idx: Idx(RangeSet::new(&[1..4, 11..14])),
sent_idx: Idx(RangeSet::new(&[4..7, 14..17])),
recv_idx: RangeSet::from([1..4, 11..14]),
sent_idx: RangeSet::from([4..7, 14..17]),
sent_total: 20,
recv_total: 20,
}
@@ -721,7 +613,6 @@ mod validation {
// Change the total to be less than the last range's end bound.
let end = partial_transcript
.sent_idx
.0
.iter_ranges()
.next_back()
.unwrap()
@@ -753,31 +644,25 @@ mod tests {
#[fixture]
fn partial_transcript() -> PartialTranscript {
transcript().to_partial(
Idx::new(RangeSet::new(&[1..4, 6..9])),
Idx::new(RangeSet::new(&[2..5, 7..10])),
)
transcript().to_partial(RangeSet::from([1..4, 6..9]), RangeSet::from([2..5, 7..10]))
}
#[rstest]
fn test_transcript_get_subsequence(transcript: Transcript) {
let subseq = transcript
.get(Direction::Received, &Idx(RangeSet::from([0..4, 7..10])))
.get(Direction::Received, &RangeSet::from([0..4, 7..10]))
.unwrap();
assert_eq!(subseq.data, vec![0, 1, 2, 3, 7, 8, 9]);
let subseq = transcript
.get(Direction::Sent, &Idx(RangeSet::from([0..4, 9..12])))
.get(Direction::Sent, &RangeSet::from([0..4, 9..12]))
.unwrap();
assert_eq!(subseq.data, vec![0, 1, 2, 3, 9, 10, 11]);
let subseq = transcript.get(
Direction::Received,
&Idx(RangeSet::from([0..4, 7..10, 11..13])),
);
let subseq = transcript.get(Direction::Received, &RangeSet::from([0..4, 7..10, 11..13]));
assert_eq!(subseq, None);
let subseq = transcript.get(Direction::Sent, &Idx(RangeSet::from([0..4, 7..10, 11..13])));
let subseq = transcript.get(Direction::Sent, &RangeSet::from([0..4, 7..10, 11..13]));
assert_eq!(subseq, None);
}
@@ -790,7 +675,7 @@ mod tests {
#[rstest]
fn test_transcript_to_partial_success(transcript: Transcript) {
let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));
let partial = transcript.to_partial(RangeSet::from(0..2), RangeSet::from(3..7));
assert_eq!(partial.sent_unsafe(), [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(
partial.received_unsafe(),
@@ -801,29 +686,30 @@ mod tests {
#[rstest]
#[should_panic]
fn test_transcript_to_partial_failure(transcript: Transcript) {
let _ = transcript.to_partial(Idx::new(0..14), Idx::new(3..7));
let _ = transcript.to_partial(RangeSet::from(0..14), RangeSet::from(3..7));
}
#[rstest]
fn test_partial_transcript_contains(transcript: Transcript) {
let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));
assert!(partial.contains(Direction::Sent, &Idx::new([0..5, 7..10])));
assert!(!partial.contains(Direction::Received, &Idx::new([4..6, 7..13])))
let partial = transcript.to_partial(RangeSet::from(0..2), RangeSet::from(3..7));
assert!(partial.contains(Direction::Sent, &RangeSet::from([0..5, 7..10])));
assert!(!partial.contains(Direction::Received, &RangeSet::from([4..6, 7..13])))
}
#[rstest]
fn test_partial_transcript_unauthed(transcript: Transcript) {
let partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));
assert_eq!(partial.sent_unauthed(), Idx::new(2..12));
assert_eq!(partial.received_unauthed(), Idx::new([0..3, 7..12]));
let partial = transcript.to_partial(RangeSet::from(0..2), RangeSet::from(3..7));
assert_eq!(partial.sent_unauthed(), RangeSet::from(2..12));
assert_eq!(partial.received_unauthed(), RangeSet::from([0..3, 7..12]));
}
#[rstest]
fn test_partial_transcript_union_success(transcript: Transcript) {
// Non overlapping ranges.
let mut simple_partial = transcript.to_partial(Idx::new(0..2), Idx::new(3..7));
let mut simple_partial = transcript.to_partial(RangeSet::from(0..2), RangeSet::from(3..7));
let other_simple_partial = transcript.to_partial(Idx::new(3..5), Idx::new(1..2));
let other_simple_partial =
transcript.to_partial(RangeSet::from(3..5), RangeSet::from(1..2));
simple_partial.union_transcript(&other_simple_partial);
@@ -835,12 +721,16 @@ mod tests {
simple_partial.received_unsafe(),
[0, 1, 0, 3, 4, 5, 6, 0, 0, 0, 0, 0]
);
assert_eq!(simple_partial.sent_authed(), &Idx::new([0..2, 3..5]));
assert_eq!(simple_partial.received_authed(), &Idx::new([1..2, 3..7]));
assert_eq!(simple_partial.sent_authed(), &RangeSet::from([0..2, 3..5]));
assert_eq!(
simple_partial.received_authed(),
&RangeSet::from([1..2, 3..7])
);
// Overwrite with another partial transcript.
let another_simple_partial = transcript.to_partial(Idx::new(1..4), Idx::new(6..9));
let another_simple_partial =
transcript.to_partial(RangeSet::from(1..4), RangeSet::from(6..9));
simple_partial.union_transcript(&another_simple_partial);
@@ -852,13 +742,17 @@ mod tests {
simple_partial.received_unsafe(),
[0, 1, 0, 3, 4, 5, 6, 7, 8, 0, 0, 0]
);
assert_eq!(simple_partial.sent_authed(), &Idx::new(0..5));
assert_eq!(simple_partial.received_authed(), &Idx::new([1..2, 3..9]));
assert_eq!(simple_partial.sent_authed(), &RangeSet::from(0..5));
assert_eq!(
simple_partial.received_authed(),
&RangeSet::from([1..2, 3..9])
);
// Overlapping ranges.
let mut overlap_partial = transcript.to_partial(Idx::new(4..6), Idx::new(3..7));
let mut overlap_partial = transcript.to_partial(RangeSet::from(4..6), RangeSet::from(3..7));
let other_overlap_partial = transcript.to_partial(Idx::new(3..5), Idx::new(5..9));
let other_overlap_partial =
transcript.to_partial(RangeSet::from(3..5), RangeSet::from(5..9));
overlap_partial.union_transcript(&other_overlap_partial);
@@ -870,13 +764,16 @@ mod tests {
overlap_partial.received_unsafe(),
[0, 0, 0, 3, 4, 5, 6, 7, 8, 0, 0, 0]
);
assert_eq!(overlap_partial.sent_authed(), &Idx::new([3..5, 4..6]));
assert_eq!(overlap_partial.received_authed(), &Idx::new([3..7, 5..9]));
assert_eq!(overlap_partial.sent_authed(), &RangeSet::from([3..5, 4..6]));
assert_eq!(
overlap_partial.received_authed(),
&RangeSet::from([3..7, 5..9])
);
// Equal ranges.
let mut equal_partial = transcript.to_partial(Idx::new(4..6), Idx::new(3..7));
let mut equal_partial = transcript.to_partial(RangeSet::from(4..6), RangeSet::from(3..7));
let other_equal_partial = transcript.to_partial(Idx::new(4..6), Idx::new(3..7));
let other_equal_partial = transcript.to_partial(RangeSet::from(4..6), RangeSet::from(3..7));
equal_partial.union_transcript(&other_equal_partial);
@@ -888,13 +785,15 @@ mod tests {
equal_partial.received_unsafe(),
[0, 0, 0, 3, 4, 5, 6, 0, 0, 0, 0, 0]
);
assert_eq!(equal_partial.sent_authed(), &Idx::new(4..6));
assert_eq!(equal_partial.received_authed(), &Idx::new(3..7));
assert_eq!(equal_partial.sent_authed(), &RangeSet::from(4..6));
assert_eq!(equal_partial.received_authed(), &RangeSet::from(3..7));
// Subset ranges.
let mut subset_partial = transcript.to_partial(Idx::new(4..10), Idx::new(3..11));
let mut subset_partial =
transcript.to_partial(RangeSet::from(4..10), RangeSet::from(3..11));
let other_subset_partial = transcript.to_partial(Idx::new(6..9), Idx::new(5..6));
let other_subset_partial =
transcript.to_partial(RangeSet::from(6..9), RangeSet::from(5..6));
subset_partial.union_transcript(&other_subset_partial);
@@ -906,30 +805,32 @@ mod tests {
subset_partial.received_unsafe(),
[0, 0, 0, 3, 4, 5, 6, 7, 8, 9, 10, 0]
);
assert_eq!(subset_partial.sent_authed(), &Idx::new(4..10));
assert_eq!(subset_partial.received_authed(), &Idx::new(3..11));
assert_eq!(subset_partial.sent_authed(), &RangeSet::from(4..10));
assert_eq!(subset_partial.received_authed(), &RangeSet::from(3..11));
}
#[rstest]
#[should_panic]
fn test_partial_transcript_union_failure(transcript: Transcript) {
let mut partial = transcript.to_partial(Idx::new(4..10), Idx::new(3..11));
let mut partial = transcript.to_partial(RangeSet::from(4..10), RangeSet::from(3..11));
let other_transcript = Transcript::new(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
);
let other_partial = other_transcript.to_partial(Idx::new(6..9), Idx::new(5..6));
let other_partial = other_transcript.to_partial(RangeSet::from(6..9), RangeSet::from(5..6));
partial.union_transcript(&other_partial);
}
#[rstest]
fn test_partial_transcript_union_subseq_success(transcript: Transcript) {
let mut partial = transcript.to_partial(Idx::new(4..10), Idx::new(3..11));
let sent_seq = Subsequence::new(Idx::new([0..3, 5..7]), [0, 1, 2, 5, 6].into()).unwrap();
let recv_seq = Subsequence::new(Idx::new([0..4, 5..7]), [0, 1, 2, 3, 5, 6].into()).unwrap();
let mut partial = transcript.to_partial(RangeSet::from(4..10), RangeSet::from(3..11));
let sent_seq =
Subsequence::new(RangeSet::from([0..3, 5..7]), [0, 1, 2, 5, 6].into()).unwrap();
let recv_seq =
Subsequence::new(RangeSet::from([0..4, 5..7]), [0, 1, 2, 3, 5, 6].into()).unwrap();
partial.union_subsequence(Direction::Sent, &sent_seq);
partial.union_subsequence(Direction::Received, &recv_seq);
@@ -939,30 +840,31 @@ mod tests {
partial.received_unsafe(),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0]
);
assert_eq!(partial.sent_authed(), &Idx::new([0..3, 4..10]));
assert_eq!(partial.received_authed(), &Idx::new(0..11));
assert_eq!(partial.sent_authed(), &RangeSet::from([0..3, 4..10]));
assert_eq!(partial.received_authed(), &RangeSet::from(0..11));
// Overwrite with another subseq.
let other_sent_seq = Subsequence::new(Idx::new(0..3), [3, 2, 1].into()).unwrap();
let other_sent_seq = Subsequence::new(RangeSet::from(0..3), [3, 2, 1].into()).unwrap();
partial.union_subsequence(Direction::Sent, &other_sent_seq);
assert_eq!(partial.sent_unsafe(), [3, 2, 1, 0, 4, 5, 6, 7, 8, 9, 0, 0]);
assert_eq!(partial.sent_authed(), &Idx::new([0..3, 4..10]));
assert_eq!(partial.sent_authed(), &RangeSet::from([0..3, 4..10]));
}
#[rstest]
#[should_panic]
fn test_partial_transcript_union_subseq_failure(transcript: Transcript) {
let mut partial = transcript.to_partial(Idx::new(4..10), Idx::new(3..11));
let mut partial = transcript.to_partial(RangeSet::from(4..10), RangeSet::from(3..11));
let sent_seq = Subsequence::new(Idx::new([0..3, 13..15]), [0, 1, 2, 5, 6].into()).unwrap();
let sent_seq =
Subsequence::new(RangeSet::from([0..3, 13..15]), [0, 1, 2, 5, 6].into()).unwrap();
partial.union_subsequence(Direction::Sent, &sent_seq);
}
#[rstest]
fn test_partial_transcript_set_unauthed_range(transcript: Transcript) {
let mut partial = transcript.to_partial(Idx::new(4..10), Idx::new(3..7));
let mut partial = transcript.to_partial(RangeSet::from(4..10), RangeSet::from(3..7));
partial.set_unauthed_range(7, Direction::Sent, 2..5);
partial.set_unauthed_range(5, Direction::Sent, 0..2);
@@ -979,13 +881,13 @@ mod tests {
#[rstest]
#[should_panic]
fn test_subsequence_new_invalid_len() {
let _ = Subsequence::new(Idx::new([0..3, 5..8]), [0, 1, 2, 5, 6].into()).unwrap();
let _ = Subsequence::new(RangeSet::from([0..3, 5..8]), [0, 1, 2, 5, 6].into()).unwrap();
}
#[rstest]
#[should_panic]
fn test_subsequence_copy_to_invalid_len() {
let seq = Subsequence::new(Idx::new([0..3, 5..7]), [0, 1, 2, 5, 6].into()).unwrap();
let seq = Subsequence::new(RangeSet::from([0..3, 5..7]), [0, 1, 2, 5, 6].into()).unwrap();
let mut data: [u8; 3] = [0, 1, 2];
seq.copy_to(&mut data);


@@ -1,6 +1,6 @@
//! Transcript commitments.
use std::{collections::HashSet, fmt};
use std::fmt;
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
@@ -10,7 +10,7 @@ use crate::{
transcript::{
encoding::{EncodingCommitment, EncodingTree},
hash::{PlaintextHash, PlaintextHashSecret},
Direction, Idx, Transcript,
Direction, RangeSet, Transcript,
},
};
@@ -66,17 +66,15 @@ pub enum TranscriptSecret {
}
/// Configuration for transcript commitments.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, Idx), TranscriptCommitmentKind)>,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
impl TranscriptCommitConfig {
/// Creates a new commit config builder.
pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder {
pub fn builder(transcript: &Transcript) -> TranscriptCommitConfigBuilder<'_> {
TranscriptCommitConfigBuilder::new(transcript)
}
@@ -85,18 +83,8 @@ impl TranscriptCommitConfig {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, Idx)> {
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Encoding => Some(idx),
_ => None,
@@ -104,7 +92,7 @@ impl TranscriptCommitConfig {
}
/// Returns an iterator over the hash commitment indices.
pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, Idx), &HashAlgId)> {
pub fn iter_hash(&self) -> impl Iterator<Item = (&(Direction, RangeSet<usize>), &HashAlgId)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
TranscriptCommitmentKind::Hash { alg } => Some((idx, alg)),
_ => None,
@@ -114,7 +102,10 @@ impl TranscriptCommitConfig {
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding,
encoding: self
.iter_encoding()
.map(|(dir, idx)| (*dir, idx.clone()))
.collect(),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -131,10 +122,8 @@ impl TranscriptCommitConfig {
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, Idx), TranscriptCommitmentKind)>,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
impl<'a> TranscriptCommitConfigBuilder<'a> {
@@ -143,10 +132,8 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
Self {
transcript,
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
commits: Vec::default(),
}
}
@@ -175,27 +162,25 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
direction: Direction,
kind: TranscriptCommitmentKind,
) -> Result<&mut Self, TranscriptCommitConfigBuilderError> {
let idx = Idx::new(ranges.to_range_set());
let idx = ranges.to_range_set();
if idx.end() > self.transcript.len_of_direction(direction) {
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(TranscriptCommitConfigBuilderError::new(
ErrorKind::Index,
format!(
"range is out of bounds of the transcript ({}): {} > {}",
direction,
idx.end(),
idx.end().unwrap_or(0),
self.transcript.len_of_direction(direction)
),
));
}
let value = ((direction, idx), kind);
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
if !self.commits.contains(&value) {
self.commits.push(value);
}
self.commits.insert(((direction, idx), kind));
Ok(self)
}
@@ -241,8 +226,6 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -289,23 +272,18 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: bool,
hash: Vec<(Direction, Idx, HashAlgId)>,
encoding: Vec<(Direction, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn encoding(&self) -> bool {
self.encoding
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()
/// Returns an iterator over the encoding commitments.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.encoding.iter()
}
/// Returns an iterator over the hash commitments.
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, Idx, HashAlgId)> {
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
self.hash.iter()
}
}
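Editorial note: a hedged consumer-side sketch of the reshaped request. Only the two iterator accessors above are taken from the diff; the `Debug` formatting of `RangeSet` and `HashAlgId` is assumed.

```rust
// Editorial sketch: walking the per-range commitment request on the verifier side.
fn inspect(request: &TranscriptCommitRequest) {
    for (direction, ranges) in request.iter_encoding() {
        println!("encoding commitment over {direction}: {ranges:?}");
    }
    for (direction, ranges, alg) in request.iter_hash() {
        println!("hash commitment ({alg:?}) over {direction}: {ranges:?}");
    }
}
```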


@@ -12,7 +12,7 @@ const BIT_ENCODING_SIZE: usize = 16;
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],


@@ -9,7 +9,7 @@ use crate::{
transcript::{
commit::MAX_TOTAL_COMMITTED_DATA,
encoding::{new_encoder, Encoder, EncodingCommitment},
Direction, Idx,
Direction,
},
};
@@ -17,7 +17,7 @@ use crate::{
#[derive(Clone, Serialize, Deserialize)]
pub(super) struct Opening {
pub(super) direction: Direction,
pub(super) idx: Idx,
pub(super) idx: RangeSet<usize>,
pub(super) blinder: Blinder,
}
@@ -51,7 +51,7 @@ impl EncodingProof {
commitment: &EncodingCommitment,
sent: &[u8],
recv: &[u8],
) -> Result<(Idx, Idx), EncodingProofError> {
) -> Result<(RangeSet<usize>, RangeSet<usize>), EncodingProofError> {
let hasher = provider.get(&commitment.root.alg)?;
let encoder = new_encoder(&commitment.secret);
@@ -89,13 +89,13 @@ impl EncodingProof {
};
// Make sure the ranges are within the bounds of the transcript.
if idx.end() > data.len() {
if idx.end().unwrap_or(0) > data.len() {
return Err(EncodingProofError::new(
ErrorKind::Proof,
format!(
"index out of bounds of the transcript ({}): {} > {}",
direction,
idx.end(),
idx.end().unwrap_or(0),
data.len()
),
));
@@ -111,7 +111,7 @@ impl EncodingProof {
// present in the merkle tree.
leaves.push((*id, hasher.hash(&expected_leaf)));
auth.union_mut(idx.as_range_set());
auth.union_mut(idx);
}
// Verify that the expected hashes are present in the merkle tree.
@@ -121,7 +121,7 @@ impl EncodingProof {
// data is authentic.
inclusion_proof.verify(hasher, &commitment.root, leaves)?;
Ok((Idx(auth_sent), Idx(auth_recv)))
Ok((auth_sent, auth_recv))
}
}
@@ -234,7 +234,7 @@ mod test {
hash::Blake3,
transcript::{
encoding::{EncoderSecret, EncodingTree},
Idx, Transcript,
Transcript,
},
};
@@ -249,8 +249,8 @@ mod test {
fn new_encoding_fixture(secret: EncoderSecret) -> EncodingFixture {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len()));
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
let provider = encoding_provider(transcript.sent(), transcript.received());
let tree = EncodingTree::new(&Blake3::default(), [&idx_0, &idx_1], &provider).unwrap();
@@ -317,7 +317,7 @@ mod test {
let Opening { idx, .. } = proof.openings.values_mut().next().unwrap();
*idx = Idx::new([0..3, 13..15]);
*idx = RangeSet::from([0..3, 13..15]);
let err = proof
.verify_with_provider(


@@ -1,6 +1,7 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
@@ -11,7 +12,7 @@ use crate::{
proof::{EncodingProof, Opening},
EncodingProvider,
},
Direction, Idx,
Direction,
},
};
@@ -22,7 +23,7 @@ pub enum EncodingTreeError {
#[error("index is out of bounds of the transcript")]
OutOfBounds {
/// The index.
index: Idx,
index: RangeSet<usize>,
/// The transcript length.
transcript_length: usize,
},
@@ -30,13 +31,13 @@ pub enum EncodingTreeError {
#[error("encoding provider is missing an encoding for an index")]
MissingEncoding {
/// The index which is missing.
index: Idx,
index: RangeSet<usize>,
},
/// Index is missing from the tree.
#[error("index is missing from the tree")]
MissingLeaf {
/// The index which is missing.
index: Idx,
index: RangeSet<usize>,
},
}
@@ -49,11 +50,11 @@ pub struct EncodingTree {
blinders: Vec<Blinder>,
/// Mapping between the index of a leaf and the transcript index it
/// corresponds to.
idxs: BiMap<usize, (Direction, Idx)>,
idxs: BiMap<usize, (Direction, RangeSet<usize>)>,
/// Union of all transcript indices in the sent direction.
sent_idx: Idx,
sent_idx: RangeSet<usize>,
/// Union of all transcript indices in the received direction.
received_idx: Idx,
received_idx: RangeSet<usize>,
}
opaque_debug::implement!(EncodingTree);
@@ -68,15 +69,15 @@ impl EncodingTree {
/// * `provider` - The encoding provider.
pub fn new<'idx>(
hasher: &dyn HashAlgorithm,
idxs: impl IntoIterator<Item = &'idx (Direction, Idx)>,
idxs: impl IntoIterator<Item = &'idx (Direction, RangeSet<usize>)>,
provider: &dyn EncodingProvider,
) -> Result<Self, EncodingTreeError> {
let mut this = Self {
tree: MerkleTree::new(hasher.id()),
blinders: Vec::new(),
idxs: BiMap::new(),
sent_idx: Idx::empty(),
received_idx: Idx::empty(),
sent_idx: RangeSet::default(),
received_idx: RangeSet::default(),
};
let mut leaves = Vec::new();
@@ -138,7 +139,7 @@ impl EncodingTree {
/// * `idxs` - The transcript indices to prove.
pub fn proof<'idx>(
&self,
idxs: impl Iterator<Item = &'idx (Direction, Idx)>,
idxs: impl Iterator<Item = &'idx (Direction, RangeSet<usize>)>,
) -> Result<EncodingProof, EncodingTreeError> {
let mut openings = HashMap::new();
for dir_idx in idxs {
@@ -171,11 +172,11 @@ impl EncodingTree {
}
/// Returns whether the tree contains the given transcript index.
pub fn contains(&self, idx: &(Direction, Idx)) -> bool {
pub fn contains(&self, idx: &(Direction, RangeSet<usize>)) -> bool {
self.idxs.contains_right(idx)
}
pub(crate) fn idx(&self, direction: Direction) -> &Idx {
pub(crate) fn idx(&self, direction: Direction) -> &RangeSet<usize> {
match direction {
Direction::Sent => &self.sent_idx,
Direction::Received => &self.received_idx,
@@ -183,7 +184,7 @@ impl EncodingTree {
}
/// Returns the committed transcript indices.
pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, Idx)> {
pub(crate) fn transcript_indices(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.idxs.right_values()
}
}
@@ -200,7 +201,7 @@ mod tests {
fn new_tree<'seq>(
transcript: &Transcript,
idxs: impl Iterator<Item = &'seq (Direction, Idx)>,
idxs: impl Iterator<Item = &'seq (Direction, RangeSet<usize>)>,
) -> Result<EncodingTree, EncodingTreeError> {
let provider = encoding_provider(transcript.sent(), transcript.received());
@@ -211,8 +212,8 @@ mod tests {
fn test_encoding_tree() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len()));
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
@@ -243,10 +244,10 @@ mod tests {
fn test_encoding_tree_multiple_ranges() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..1));
let idx_1 = (Direction::Sent, Idx::new(1..POST_JSON.len()));
let idx_2 = (Direction::Received, Idx::new(0..1));
let idx_3 = (Direction::Received, Idx::new(1..OK_JSON.len()));
let idx_0 = (Direction::Sent, RangeSet::from(0..1));
let idx_1 = (Direction::Sent, RangeSet::from(1..POST_JSON.len()));
let idx_2 = (Direction::Received, RangeSet::from(0..1));
let idx_3 = (Direction::Received, RangeSet::from(1..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1, &idx_2, &idx_3].into_iter()).unwrap();
@@ -273,11 +274,11 @@ mod tests {
)
.unwrap();
let mut expected_auth_sent = Idx::default();
let mut expected_auth_sent = RangeSet::default();
expected_auth_sent.union_mut(&idx_0.1);
expected_auth_sent.union_mut(&idx_1.1);
let mut expected_auth_recv = Idx::default();
let mut expected_auth_recv = RangeSet::default();
expected_auth_recv.union_mut(&idx_2.1);
expected_auth_recv.union_mut(&idx_3.1);
@@ -289,9 +290,9 @@ mod tests {
fn test_encoding_tree_proof_missing_leaf() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len()));
let idx_1 = (Direction::Received, Idx::new(0..4));
let idx_2 = (Direction::Received, Idx::new(4..OK_JSON.len()));
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
let idx_1 = (Direction::Received, RangeSet::from(0..4));
let idx_2 = (Direction::Received, RangeSet::from(4..OK_JSON.len()));
let tree = new_tree(&transcript, [&idx_0, &idx_1].into_iter()).unwrap();
@@ -305,8 +306,8 @@ mod tests {
fn test_encoding_tree_out_of_bounds() {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, Idx::new(0..POST_JSON.len() + 1));
let idx_1 = (Direction::Received, Idx::new(0..OK_JSON.len() + 1));
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len() + 1));
let idx_1 = (Direction::Received, RangeSet::from(0..OK_JSON.len() + 1));
let result = new_tree(&transcript, [&idx_0].into_iter()).unwrap_err();
assert!(matches!(result, EncodingTreeError::MissingEncoding { .. }));
@@ -321,7 +322,7 @@ mod tests {
let result = EncodingTree::new(
&Blake3::default(),
[(Direction::Sent, Idx::new(0..8))].iter(),
[(Direction::Sent, RangeSet::from(0..8))].iter(),
&provider,
)
.unwrap_err();


@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
hash::{Blinder, HashAlgId, HashAlgorithm, TypedHash},
transcript::{Direction, Idx},
transcript::{Direction, RangeSet},
};
/// Hashes plaintext with a blinder.
@@ -23,7 +23,7 @@ pub struct PlaintextHash {
/// Direction of the plaintext.
pub direction: Direction,
/// Index of plaintext.
pub idx: Idx,
pub idx: RangeSet<usize>,
/// The hash of the data.
pub hash: TypedHash,
}
@@ -34,7 +34,7 @@ pub struct PlaintextHashSecret {
/// Direction of the plaintext.
pub direction: Direction,
/// Index of plaintext.
pub idx: Idx,
pub idx: RangeSet<usize>,
/// The algorithm of the hash.
pub alg: HashAlgId,
/// Blinder for the hash.


@@ -1,17 +1,18 @@
//! Transcript proofs.
use rangeset::{Cover, ToRangeSet};
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
use crate::{
connection::TranscriptLength,
display::FmtRangeSet,
hash::{HashAlgId, HashProvider},
transcript::{
commit::{TranscriptCommitment, TranscriptCommitmentKind},
encoding::{EncodingProof, EncodingProofError, EncodingTree},
hash::{hash_plaintext, PlaintextHash, PlaintextHashSecret},
Direction, Idx, PartialTranscript, Transcript, TranscriptSecret,
Direction, PartialTranscript, RangeSet, Transcript, TranscriptSecret,
},
};
@@ -77,8 +78,8 @@ impl TranscriptProof {
));
}
let mut total_auth_sent = Idx::default();
let mut total_auth_recv = Idx::default();
let mut total_auth_sent = RangeSet::default();
let mut total_auth_recv = RangeSet::default();
// Verify encoding proof.
if let Some(proof) = self.encoding_proof {
@@ -120,7 +121,7 @@ impl TranscriptProof {
Direction::Received => (self.transcript.received_unsafe(), &mut total_auth_recv),
};
if idx.end() > plaintext.len() {
if idx.end().unwrap_or(0) > plaintext.len() {
return Err(TranscriptProofError::new(
ErrorKind::Hash,
"hash opening index is out of bounds",
@@ -215,15 +216,15 @@ impl From<EncodingProofError> for TranscriptProofError {
/// Union of ranges to reveal.
#[derive(Clone, Debug, PartialEq)]
struct QueryIdx {
sent: Idx,
recv: Idx,
sent: RangeSet<usize>,
recv: RangeSet<usize>,
}
impl QueryIdx {
fn new() -> Self {
Self {
sent: Idx::empty(),
recv: Idx::empty(),
sent: RangeSet::default(),
recv: RangeSet::default(),
}
}
@@ -231,7 +232,7 @@ impl QueryIdx {
self.sent.is_empty() && self.recv.is_empty()
}
fn union(&mut self, direction: &Direction, other: &Idx) {
fn union(&mut self, direction: &Direction, other: &RangeSet<usize>) {
match direction {
Direction::Sent => self.sent.union_mut(other),
Direction::Received => self.recv.union_mut(other),
@@ -241,7 +242,12 @@ impl QueryIdx {
impl std::fmt::Display for QueryIdx {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "sent: {}, received: {}", self.sent, self.recv)
write!(
f,
"sent: {}, received: {}",
FmtRangeSet(&self.sent),
FmtRangeSet(&self.recv)
)
}
}
@@ -253,8 +259,8 @@ pub struct TranscriptProofBuilder<'a> {
transcript: &'a Transcript,
encoding_tree: Option<&'a EncodingTree>,
hash_secrets: Vec<&'a PlaintextHashSecret>,
committed_sent: Idx,
committed_recv: Idx,
committed_sent: RangeSet<usize>,
committed_recv: RangeSet<usize>,
query_idx: QueryIdx,
}
@@ -264,8 +270,8 @@ impl<'a> TranscriptProofBuilder<'a> {
transcript: &'a Transcript,
secrets: impl IntoIterator<Item = &'a TranscriptSecret>,
) -> Self {
let mut committed_sent = Idx::empty();
let mut committed_recv = Idx::empty();
let mut committed_sent = RangeSet::default();
let mut committed_recv = RangeSet::default();
let mut encoding_tree = None;
let mut hash_secrets = Vec::new();
@@ -323,15 +329,15 @@ impl<'a> TranscriptProofBuilder<'a> {
ranges: &dyn ToRangeSet<usize>,
direction: Direction,
) -> Result<&mut Self, TranscriptProofBuilderError> {
let idx = Idx::new(ranges.to_range_set());
let idx = ranges.to_range_set();
if idx.end() > self.transcript.len_of_direction(direction) {
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::Index,
format!(
"range is out of bounds of the transcript ({}): {} > {}",
direction,
idx.end(),
idx.end().unwrap_or(0),
self.transcript.len_of_direction(direction)
),
));
@@ -348,7 +354,10 @@ impl<'a> TranscriptProofBuilder<'a> {
let missing = idx.difference(committed);
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!("commitment is missing for ranges in {direction} transcript: {missing}"),
format!(
"commitment is missing for ranges in {direction} transcript: {}",
FmtRangeSet(&missing)
),
));
}
Ok(self)
@@ -403,25 +412,23 @@ impl<'a> TranscriptProofBuilder<'a> {
continue;
};
let (sent_dir_idxs, sent_uncovered) =
uncovered_query_idx.sent.as_range_set().cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Sent),
|(_, idx)| &idx.0,
);
let (sent_dir_idxs, sent_uncovered) = uncovered_query_idx.sent.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Sent),
|(_, idx)| idx,
);
// Uncovered ranges will be checked with ranges of the next
// preferred commitment kind.
uncovered_query_idx.sent = Idx(sent_uncovered);
uncovered_query_idx.sent = sent_uncovered;
let (recv_dir_idxs, recv_uncovered) =
uncovered_query_idx.recv.as_range_set().cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Received),
|(_, idx)| &idx.0,
);
uncovered_query_idx.recv = Idx(recv_uncovered);
let (recv_dir_idxs, recv_uncovered) = uncovered_query_idx.recv.cover_by(
encoding_tree
.transcript_indices()
.filter(|(dir, _)| *dir == Direction::Received),
|(_, idx)| idx,
);
uncovered_query_idx.recv = recv_uncovered;
let dir_idxs = sent_dir_idxs
.into_iter()
@@ -439,25 +446,23 @@ impl<'a> TranscriptProofBuilder<'a> {
}
}
TranscriptCommitmentKind::Hash { alg } => {
let (sent_hashes, sent_uncovered) =
uncovered_query_idx.sent.as_range_set().cover_by(
self.hash_secrets.iter().filter(|hash| {
hash.direction == Direction::Sent && &hash.alg == alg
}),
|hash| &hash.idx.0,
);
let (sent_hashes, sent_uncovered) = uncovered_query_idx.sent.cover_by(
self.hash_secrets.iter().filter(|hash| {
hash.direction == Direction::Sent && &hash.alg == alg
}),
|hash| &hash.idx,
);
// Uncovered ranges will be checked with ranges of the next
// preferred commitment kind.
uncovered_query_idx.sent = Idx(sent_uncovered);
uncovered_query_idx.sent = sent_uncovered;
let (recv_hashes, recv_uncovered) =
uncovered_query_idx.recv.as_range_set().cover_by(
self.hash_secrets.iter().filter(|hash| {
hash.direction == Direction::Received && &hash.alg == alg
}),
|hash| &hash.idx.0,
);
uncovered_query_idx.recv = Idx(recv_uncovered);
let (recv_hashes, recv_uncovered) = uncovered_query_idx.recv.cover_by(
self.hash_secrets.iter().filter(|hash| {
hash.direction == Direction::Received && &hash.alg == alg
}),
|hash| &hash.idx,
);
uncovered_query_idx.recv = recv_uncovered;
transcript_proof.hash_secrets.extend(
sent_hashes
@@ -577,7 +582,7 @@ mod tests {
#[rstest]
fn test_verify_missing_encoding_commitment_root() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let idxs = vec![(Direction::Received, Idx::new(0..transcript.len().1))];
let idxs = vec![(Direction::Received, RangeSet::from(0..transcript.len().1))];
let encoding_tree = EncodingTree::new(
&Blake3::default(),
&idxs,
@@ -638,7 +643,7 @@ mod tests {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let direction = Direction::Sent;
let idx = Idx::new(0..10);
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
@@ -684,7 +689,7 @@ mod tests {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let direction = Direction::Sent;
let idx = Idx::new(0..10);
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
@@ -894,10 +899,10 @@ mod tests {
match kind {
BuilderErrorKind::Cover { uncovered, .. } => {
if !uncovered_sent_rangeset.is_empty() {
assert_eq!(uncovered.sent, Idx(uncovered_sent_rangeset));
assert_eq!(uncovered.sent, uncovered_sent_rangeset);
}
if !uncovered_recv_rangeset.is_empty() {
assert_eq!(uncovered.recv, Idx(uncovered_recv_rangeset));
assert_eq!(uncovered.recv, uncovered_recv_rangeset);
}
}
_ => panic!("unexpected error kind: {kind:?}"),


@@ -2,10 +2,10 @@
use crate::{
connection::{
Certificate, HandshakeData, HandshakeDataV1_2, ServerEphemKey, ServerSignature, TlsVersion,
VerifyData,
CertBinding, CertBindingV1_2, ServerEphemKey, ServerSignature, TlsVersion, VerifyData,
},
transcript::{Direction, Transcript},
webpki::CertificateDer,
};
use tls_core::msgs::{
alert::AlertMessagePayload,
@@ -19,9 +19,9 @@ use tls_core::msgs::{
pub struct TlsTranscript {
time: u64,
version: TlsVersion,
server_cert_chain: Option<Vec<Certificate>>,
server_cert_chain: Option<Vec<CertificateDer>>,
server_signature: Option<ServerSignature>,
handshake_data: HandshakeData,
certificate_binding: CertBinding,
sent: Vec<Record>,
recv: Vec<Record>,
}
@@ -32,9 +32,9 @@ impl TlsTranscript {
pub fn new(
time: u64,
version: TlsVersion,
server_cert_chain: Option<Vec<Certificate>>,
server_cert_chain: Option<Vec<CertificateDer>>,
server_signature: Option<ServerSignature>,
handshake_data: HandshakeData,
certificate_binding: CertBinding,
verify_data: VerifyData,
sent: Vec<Record>,
recv: Vec<Record>,
@@ -198,7 +198,7 @@ impl TlsTranscript {
version,
server_cert_chain,
server_signature,
handshake_data,
certificate_binding,
sent,
recv,
})
@@ -215,7 +215,7 @@ impl TlsTranscript {
}
/// Returns the server certificate chain.
pub fn server_cert_chain(&self) -> Option<&[Certificate]> {
pub fn server_cert_chain(&self) -> Option<&[CertificateDer]> {
self.server_cert_chain.as_deref()
}
@@ -226,17 +226,17 @@ impl TlsTranscript {
/// Returns the server ephemeral key used in the TLS handshake.
pub fn server_ephemeral_key(&self) -> &ServerEphemKey {
match &self.handshake_data {
HandshakeData::V1_2(HandshakeDataV1_2 {
match &self.certificate_binding {
CertBinding::V1_2(CertBindingV1_2 {
server_ephemeral_key,
..
}) => server_ephemeral_key,
}
}
/// Returns the handshake data.
pub fn handshake_data(&self) -> &HandshakeData {
&self.handshake_data
/// Returns the certificate binding data.
pub fn certificate_binding(&self) -> &CertBinding {
&self.certificate_binding
}
/// Returns the sent records.

crates/core/src/webpki.rs (new file, 168 lines)

@@ -0,0 +1,168 @@
//! Web PKI types.
use std::time::Duration;
use rustls_pki_types::{self as webpki_types, pem::PemObject};
use serde::{Deserialize, Serialize};
use crate::connection::ServerName;
/// X.509 certificate, DER encoded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateDer(pub Vec<u8>);
impl CertificateDer {
/// Creates a DER-encoded certificate from a PEM-encoded certificate.
pub fn from_pem_slice(pem: &[u8]) -> Result<Self, PemError> {
let der = webpki_types::CertificateDer::from_pem_slice(pem).map_err(|_| PemError {})?;
Ok(Self(der.to_vec()))
}
}
/// Private key, DER encoded.
#[derive(Debug, Clone, zeroize::ZeroizeOnDrop, Serialize, Deserialize)]
pub struct PrivateKeyDer(pub Vec<u8>);
impl PrivateKeyDer {
/// Creates a DER-encoded private key from a PEM-encoded private key.
pub fn from_pem_slice(pem: &[u8]) -> Result<Self, PemError> {
let der = webpki_types::PrivateKeyDer::from_pem_slice(pem).map_err(|_| PemError {})?;
Ok(Self(der.secret_der().to_vec()))
}
}
/// PEM parsing error.
#[derive(Debug, thiserror::Error)]
#[error("failed to parse PEM object")]
pub struct PemError {}
/// Root certificate store.
///
/// This stores root certificates which are used to verify end-entity
/// certificates presented by a TLS server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RootCertStore {
/// Unvalidated DER-encoded X.509 root certificates.
pub roots: Vec<CertificateDer>,
}
impl RootCertStore {
/// Creates an empty root certificate store.
pub fn empty() -> Self {
Self { roots: Vec::new() }
}
}
/// Server certificate verifier.
#[derive(Debug)]
pub struct ServerCertVerifier {
roots: Vec<webpki_types::TrustAnchor<'static>>,
}
impl ServerCertVerifier {
/// Creates a new server certificate verifier.
pub fn new(roots: &RootCertStore) -> Result<Self, ServerCertVerifierError> {
let roots = roots
.roots
.iter()
.map(|cert| {
webpki::anchor_from_trusted_cert(&webpki_types::CertificateDer::from(
cert.0.as_slice(),
))
.map(|anchor| anchor.to_owned())
.map_err(|err| ServerCertVerifierError::InvalidRootCertificate {
cert: cert.clone(),
reason: err.to_string(),
})
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self { roots })
}
/// Creates a new server certificate verifier with Mozilla root
/// certificates.
pub fn mozilla() -> Self {
Self {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
}
/// Verifies the server certificate was valid at the given time of
/// presentation.
///
/// # Arguments
///
/// * `end_entity` - End-entity certificate to verify.
/// * `intermediates` - Intermediate certificates to a trust anchor.
/// * `server_name` - Server DNS name.
/// * `time` - Unix time the certificate was presented.
pub fn verify_server_cert(
&self,
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
server_name: &ServerName,
time: u64,
) -> Result<(), ServerCertVerifierError> {
let cert = webpki_types::CertificateDer::from(end_entity.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(|e| {
ServerCertVerifierError::InvalidEndEntityCertificate {
cert: end_entity.clone(),
reason: e.to_string(),
}
})?;
let intermediates = intermediates
.iter()
.map(|c| webpki_types::CertificateDer::from(c.0.as_slice()))
.collect::<Vec<_>>();
let server_name = server_name.to_webpki();
let time = webpki_types::UnixTime::since_unix_epoch(Duration::from_secs(time));
cert.verify_for_usage(
webpki::ALL_VERIFICATION_ALGS,
&self.roots,
&intermediates,
time,
webpki::KeyUsage::server_auth(),
None,
None,
)
.map(|_| ())
.map_err(|_| ServerCertVerifierError::InvalidPath)?;
cert.verify_is_valid_for_subject_name(&server_name)
.map_err(|_| ServerCertVerifierError::InvalidServerName)?;
Ok(())
}
}
/// Error for [`ServerCertVerifier`].
#[derive(Debug, thiserror::Error)]
#[error("server certificate verification failed: {0}")]
pub enum ServerCertVerifierError {
/// Root certificate store contains invalid certificate.
#[error("root certificate store contains invalid certificate: {reason}")]
InvalidRootCertificate {
/// Invalid certificate.
cert: CertificateDer,
/// Reason for invalidity.
reason: String,
},
/// End-entity certificate is invalid.
#[error("end-entity certificate is invalid: {reason}")]
InvalidEndEntityCertificate {
/// Invalid certificate.
cert: CertificateDer,
/// Reason for invalidity.
reason: String,
},
/// Failed to verify certificate path to provided trust anchors.
#[error("failed to verify certificate path to provided trust anchors")]
InvalidPath,
/// Failed to verify certificate is valid for provided server name.
#[error("failed to verify certificate is valid for provided server name")]
InvalidServerName,
}
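Editorial note: an illustrative usage sketch for the new module, assumed to live alongside it. The caller-provided certificate chain, server name, and timestamp are placeholders; everything it calls is defined above.

```rust
// Editorial sketch: verifying a presented server certificate chain.
fn check_presented_cert(
    end_entity: &CertificateDer,
    intermediates: &[CertificateDer],
    server_name: &ServerName,
    presented_at_unix: u64,
) -> Result<(), ServerCertVerifierError> {
    // Verify against the bundled Mozilla roots; `ServerCertVerifier::new`
    // with a custom `RootCertStore` works the same way.
    let verifier = ServerCertVerifier::mozilla();
    verifier.verify_server_cert(end_entity, intermediates, server_name, presented_at_unix)
}
```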


@@ -8,10 +8,8 @@ version = "0.0.0"
workspace = true
[dependencies]
tlsn-core = { workspace = true }
tlsn = { workspace = true }
tlsn-formats = { workspace = true }
tlsn-tls-core = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
@@ -20,7 +18,6 @@ spansy = { workspace = true }
bincode = { workspace = true }
chrono = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
dotenv = { version = "0.15.0" }
futures = { workspace = true }
http-body-util = { workspace = true }
hex = { workspace = true }
@@ -43,3 +40,15 @@ tracing-subscriber = { workspace = true }
[[example]]
name = "interactive"
path = "interactive/interactive.rs"
[[example]]
name = "attestation_prove"
path = "attestation/prove.rs"
[[example]]
name = "attestation_present"
path = "attestation/present.rs"
[[example]]
name = "attestation_verify"
path = "attestation/verify.rs"


@@ -5,4 +5,4 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
* [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
* [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
Refer to <https://docs.tlsnotary.org/quick_start/index.html> for a quick start guide to using TLSNotary with these examples.
Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.


@@ -0,0 +1,164 @@
# Attestation Example
This example demonstrates a **TLSNotary attestation workflow**: notarizing data from a server with a trusted third party (Notary), then creating verifiable presentations with selective disclosure of sensitive information to a Verifier.
## 🔍 How It Works
```mermaid
sequenceDiagram
participant P as Prover
participant N as MPC-TLS<br/>Verifier
participant S as Server<br/>Fixture
participant V as Attestation<br/>Verifier
Note over P,S: 1. Notarization Phase
P->>N: Establish MPC-TLS connection
P->>S: Request (MPC-TLS)
S->>P: Response (MPC-TLS)
N->>P: Issue signed attestation
Note over P: 2. Presentation Phase
P->>P: Create redacted presentation
Note over P,V: 3. Verification Phase
P->>V: Share presentation
V->>V: Verify attestation signature
```
### The Three-Step Process
1. **🔐 Notarize**: Prover collaborates with Notary to create an authenticated TLS session and obtain a signed attestation
2. **✂️ Present**: Prover creates a selective presentation, choosing which data to reveal or redact
3. **✅ Verify**: Anyone can verify the presentation's authenticity using the Notary's public key
## 🚀 Quick Start
### Step 1: Notarize Data
**Start the test server** (from repository root):
```bash
RUST_LOG=info PORT=4000 cargo run --bin tlsn-server-fixture
```
**Run the notarization** (in a new terminal):
```bash
RUST_LOG=info SERVER_PORT=4000 cargo run --release --example attestation_prove
```
**Expected output:**
```
Notarization completed successfully!
The attestation has been written to `example-json.attestation.tlsn` and the corresponding secrets to `example-json.secrets.tlsn`.
```
### Step 2: Create Verifiable Presentation
**Generate a redacted presentation:**
```bash
cargo run --release --example attestation_present
```
**Expected output:**
```
Presentation built successfully!
The presentation has been written to `example-json.presentation.tlsn`.
```
> 💡 **Tip**: You can create multiple presentations from the same attestation, each with different redactions!
### Step 3: Verify the Presentation
**Verify the presentation:**
```bash
cargo run --release --example attestation_verify
```
**Expected output:**
```
Verifying presentation with {key algorithm} key: { hex encoded key }
**Ask yourself, do you trust this key?**
-------------------------------------------------------------------
Successfully verified that the data below came from a session with test-server.io at { time }.
Note that the data which the Prover chose not to disclose are shown as X.
Data sent:
GET /formats/json HTTP/1.1
host: test-server.io
accept: */*
accept-encoding: identity
connection: close
user-agent: XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
Data received:
HTTP/1.1 200 OK
content-type: application/json
content-length: 722
connection: close
date: Mon, 08 Sep 2025 09:18:29 GMT
XXXXXX1234567890XXXXXXXXXXXXXXXXXXXXXXXXJohn DoeXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX1.2XX
```
## 🎯 Use Cases & Examples
### JSON Data (Default)
Perfect for API responses, configuration data, or structured information:
```bash
# All three steps use JSON by default
SERVER_PORT=4000 cargo run --release --example attestation_prove
cargo run --release --example attestation_present
cargo run --release --example attestation_verify
```
### HTML Content
Ideal for web pages, forms, or any HTML-based data:
```bash
# Notarize HTML content
SERVER_PORT=4000 cargo run --release --example attestation_prove -- html
cargo run --release --example attestation_present -- html
cargo run --release --example attestation_verify -- html
```
### Authenticated/Private Data
For APIs requiring authentication tokens, cookies, or private access:
```bash
# Notarize private data with authentication
SERVER_PORT=4000 cargo run --release --example attestation_prove -- authenticated
cargo run --release --example attestation_present -- authenticated
cargo run --release --example attestation_verify -- authenticated
```
### Debug Mode
For detailed logging and troubleshooting:
```bash
RUST_LOG=debug,yamux=info,uid_mux=info SERVER_PORT=4000 cargo run --release --example attestation_prove
```
### Generated Files
After running the examples, you'll find:
- **`*.attestation.tlsn`**: The cryptographically signed attestation from the Notary
- **`*.secrets.tlsn`**: Cryptographic secrets needed to create presentations
- **`*.presentation.tlsn`**: The verifiable presentation with your chosen redactions
## 🔐 Security Considerations
### Trust Model
- ✅ **Notary Key**: The presentation includes the Notary's verifying key; the verifier must trust this key
- ✅ **Data Authenticity**: Cryptographically guarantees that the data came from the specified server
- ✅ **Tamper Evidence**: Any modification to the presentation will fail verification
- ⚠️ **Notary Trust**: The verifier must trust the Notary not to collude with the Prover
### Production Deployment
- 🏭 **Independent Notary**: Use a trusted third-party Notary service (not a local one)
- 🔒 **Key Management**: Implement proper Notary key distribution and verification
- 📋 **Audit Trail**: Maintain logs of notarization and verification events
- 🔄 **Key Rotation**: Plan for Notary key updates and migration
> ⚠️ **Demo Notice**: This example uses a local test server and local Notary for demonstration. In production, use trusted third-party Notary services and real server endpoints.
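For production-style verification, the comments in `verify.rs` suggest using the default crypto provider rather than one configured with the fixture's self-signed CA. A minimal sketch, with the key check and richer error handling elided:
```rust
// Production-style sketch: default crypto provider, no fixture root store.
use tlsn::attestation::{presentation::Presentation, CryptoProvider};

fn verify_presentation_bytes(raw: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
    let presentation: Presentation = bincode::deserialize(raw)?;
    // Default provider, as recommended for production in `verify.rs`; this
    // fails if the presentation was tampered with or the attestation is invalid.
    let output = presentation.verify(&CryptoProvider::default())?;
    if let Some(server_name) = output.server_name {
        println!("verified a session with {server_name}");
    }
    Ok(())
}
```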

View File

@@ -0,0 +1,117 @@
// This example demonstrates how to build a verifiable presentation from an
// attestation and the corresponding connection secrets. See the `prove.rs`
// example to learn how to acquire an attestation from a Notary.
use clap::Parser;
use hyper::header;
use tlsn::attestation::{presentation::Presentation, Attestation, CryptoProvider, Secrets};
use tlsn_examples::ExampleType;
use tlsn_formats::http::HttpTranscript;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// What data to notarize
#[clap(default_value_t, value_enum)]
example_type: ExampleType,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
create_presentation(&args.example_type).await
}
async fn create_presentation(example_type: &ExampleType) -> Result<(), Box<dyn std::error::Error>> {
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
let secrets_path = tlsn_examples::get_file_path(example_type, "secrets");
// Read attestation from disk.
let attestation: Attestation = bincode::deserialize(&std::fs::read(attestation_path)?)?;
// Read secrets from disk.
let secrets: Secrets = bincode::deserialize(&std::fs::read(secrets_path)?)?;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(secrets.transcript())?;
// Build a transcript proof.
let mut builder = secrets.transcript_proof_builder();
// Here is where we reveal all or some of the parts we committed in `prove.rs`
// previously.
let request = &transcript.requests[0];
// Reveal the structure of the request without the headers or body.
builder.reveal_sent(&request.without_data())?;
// Reveal the request target.
builder.reveal_sent(&request.request.target)?;
// Reveal all request headers except the values of User-Agent and Authorization.
for header in &request.headers {
if !(header
.name
.as_str()
.eq_ignore_ascii_case(header::USER_AGENT.as_str())
|| header
.name
.as_str()
.eq_ignore_ascii_case(header::AUTHORIZATION.as_str()))
{
builder.reveal_sent(header)?;
} else {
builder.reveal_sent(&header.without_value())?;
}
}
// Reveal only parts of the response.
let response = &transcript.responses[0];
// Reveal the structure of the response without the headers or body.
builder.reveal_recv(&response.without_data())?;
// Reveal all response headers.
for header in &response.headers {
builder.reveal_recv(header)?;
}
let content = &response.body.as_ref().unwrap().content;
match content {
tlsn_formats::http::BodyContent::Json(json) => {
// For experimentation, reveal the entire response or just a selection.
let reveal_all = false;
if reveal_all {
builder.reveal_recv(response)?;
} else {
builder.reveal_recv(json.get("id").unwrap())?;
builder.reveal_recv(json.get("information.name").unwrap())?;
builder.reveal_recv(json.get("meta.version").unwrap())?;
}
}
tlsn_formats::http::BodyContent::Unknown(span) => {
builder.reveal_recv(span)?;
}
_ => {}
}
let transcript_proof = builder.build()?;
// Use default crypto provider to build the presentation.
let provider = CryptoProvider::default();
let mut builder = attestation.presentation_builder(&provider);
builder
.identity_proof(secrets.identity_proof())
.transcript_proof(transcript_proof);
let presentation: Presentation = builder.build()?;
let presentation_path = tlsn_examples::get_file_path(example_type, "presentation");
// Write the presentation to disk.
std::fs::write(&presentation_path, bincode::serialize(&presentation)?)?;
println!("Presentation built successfully!");
println!("The presentation has been written to `{presentation_path}`.");
Ok(())
}

View File

@@ -0,0 +1,396 @@
// This example demonstrates how to use the Prover to acquire an attestation for
// an HTTP request sent to a server fixture. The attestation and secrets are
// saved to disk.
use std::env;
use clap::Parser;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use spansy::Spanned;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::oneshot::{self, Receiver, Sender},
};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::info;
use tlsn::{
attestation::{
request::{Request as AttestationRequest, RequestConfig},
signing::Secp256k1Signer,
Attestation, AttestationConfig, CryptoProvider, Secrets,
},
config::{
CertificateDer, PrivateKeyDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore,
},
connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength},
prover::{state::Committed, ProveConfig, Prover, ProverConfig, ProverOutput, TlsConfig},
transcript::{ContentType, TranscriptCommitConfig},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_examples::ExampleType;
use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, CLIENT_CERT_DER, CLIENT_KEY_DER, SERVER_DOMAIN};
// Settings for the application server.
const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36";
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// What data to notarize.
#[clap(default_value_t, value_enum)]
example_type: ExampleType,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let args = Args::parse();
let (uri, extra_headers) = match args.example_type {
ExampleType::Json => ("/formats/json", vec![]),
ExampleType::Html => ("/formats/html", vec![]),
ExampleType::Authenticated => ("/protected", vec![("Authorization", "random_auth_token")]),
};
let (notary_socket, prover_socket) = tokio::io::duplex(1 << 23);
let (request_tx, request_rx) = oneshot::channel();
let (attestation_tx, attestation_rx) = oneshot::channel();
tokio::spawn(async move {
notary(notary_socket, request_rx, attestation_tx)
.await
.unwrap()
});
prover(
prover_socket,
request_tx,
attestation_rx,
uri,
extra_headers,
&args.example_type,
)
.await?;
Ok(())
}
async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S,
req_tx: Sender<AttestationRequest>,
resp_rx: Receiver<Attestation>,
uri: &str,
extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType,
) -> Result<(), Box<dyn std::error::Error>> {
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
));
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
// We must configure the amount of data we expect to exchange beforehand, which will
// be preprocessed prior to the connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create a new prover and perform necessary setup.
let prover = Prover::new(prover_config).setup(socket.compat()).await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
// Bind the prover to the server connection.
// The returned `mpc_tls_connection` is an MPC TLS connection to the server: all
// data written to/read from it will be encrypted/decrypted using MPC with
// the notary.
let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?;
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover_fut);
// Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
// Spawn the HTTP task to be run concurrently in the background.
tokio::spawn(connection);
// Build a simple HTTP request with common headers.
let request_builder = Request::builder()
.uri(uri)
.header("Host", SERVER_DOMAIN)
.header("Accept", "*/*")
// Using "identity" instructs the Server not to use compression for its HTTP response.
// TLSNotary tooling does not support compression.
.header("Accept-Encoding", "identity")
.header("Connection", "close")
.header("User-Agent", USER_AGENT);
let mut request_builder = request_builder;
for (key, value) in extra_headers {
request_builder = request_builder.header(key, value);
}
let request = request_builder.body(Empty::<Bytes>::new())?;
info!("Starting an MPC TLS connection with the server");
// Send the request to the server and wait for the response.
let response = request_sender.send_request(request).await?;
info!("Got a response from the server: {}", response.status());
assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it.
let mut prover = prover_task.await??;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?;
let body_content = &transcript.responses[0].body.as_ref().unwrap().content;
let body = String::from_utf8_lossy(body_content.span().as_bytes());
match body_content {
tlsn_formats::http::BodyContent::Json(_json) => {
let parsed = serde_json::from_str::<serde_json::Value>(&body)?;
info!("{}", serde_json::to_string_pretty(&parsed)?);
}
tlsn_formats::http::BodyContent::Unknown(_span) => {
info!("{}", &body);
}
_ => {}
}
// Commit to the transcript.
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
// This commits to various parts of the transcript separately (e.g. request
// headers, response headers, response body and more). See https://docs.tlsnotary.org//protocol/commit_strategy.html
// for other strategies that can be used to generate commitments.
DefaultHttpCommitter::default().commit_transcript(&mut builder, &transcript)?;
let transcript_commit = builder.build()?;
// Build an attestation request.
let mut builder = RequestConfig::builder();
builder.transcript_commit(transcript_commit);
// Optionally, add an extension to the attestation if the notary supports it.
// builder.extension(Extension {
// id: b"example.name".to_vec(),
// value: b"Bobert".to_vec(),
// });
let request_config = builder.build()?;
let (attestation, secrets) = notarize(&mut prover, &request_config, req_tx, resp_rx).await?;
// Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
let secrets_path = tlsn_examples::get_file_path(example_type, "secrets");
tokio::fs::write(&attestation_path, bincode::serialize(&attestation)?).await?;
// Write the secrets to disk.
tokio::fs::write(&secrets_path, bincode::serialize(&secrets)?).await?;
println!("Notarization completed successfully!");
println!(
"The attestation has been written to `{attestation_path}` and the \
corresponding secrets to `{secrets_path}`."
);
Ok(())
}
async fn notarize(
prover: &mut Prover<Committed>,
config: &RequestConfig,
request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets), Box<dyn std::error::Error>> {
let mut builder = ProveConfig::builder(prover.transcript());
if let Some(config) = config.transcript_commit() {
builder.transcript_commit(config.clone());
}
let disclosure_config = builder.build()?;
let ProverOutput {
transcript_commitments,
transcript_secrets,
..
} = prover.prove(disclosure_config).await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);
builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.handshake_data(HandshakeData {
certs: prover
.tls_transcript()
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: prover
.tls_transcript()
.server_signature()
.expect("server signature is present")
.clone(),
binding: prover.tls_transcript().certificate_binding().clone(),
})
.transcript(prover.transcript().clone())
.transcript_commitments(transcript_secrets, transcript_commitments);
let (request, secrets) = builder.build(&CryptoProvider::default())?;
// Send attestation request to notary.
request_tx
.send(request.clone())
.map_err(|_| "notary is not receiving attestation request".to_string())?;
// Receive attestation from notary.
let attestation = attestation_rx
.await
.map_err(|err| format!("notary did not respond with attestation: {err}"))?;
// Check the attestation is consistent with the Prover's view.
request.validate(&attestation)?;
Ok((attestation, secrets))
}
async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S,
request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>,
) -> Result<(), Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()
.unwrap();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
let mut verifier = Verifier::new(verifier_config)
.setup(socket.compat())
.await?
.run()
.await?;
let VerifierOutput {
transcript_commitments,
..
} = verifier.verify(&VerifyConfig::default()).await?;
let tls_transcript = verifier.tls_transcript().clone();
verifier.close().await?;
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Receive attestation request from prover.
let request = request_rx.await?;
// Load a dummy signing key.
let signing_key = k256::ecdsa::SigningKey::from_bytes(&[1u8; 32].into())?;
let signer = Box::new(Secp256k1Signer::new(&signing_key.to_bytes())?);
let mut provider = CryptoProvider::default();
provider.signer.set_signer(signer);
// Build an attestation.
let mut att_config_builder = AttestationConfig::builder();
att_config_builder.supported_signature_algs(Vec::from_iter(provider.signer.supported_algs()));
let att_config = att_config_builder.build()?;
let mut builder = Attestation::builder(&att_config).accept_request(request)?;
builder
.connection_info(ConnectionInfo {
time: tls_transcript.time(),
version: (*tls_transcript.version()),
transcript_length: TranscriptLength {
sent: sent_len as u32,
received: recv_len as u32,
},
})
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments);
let attestation = builder.build(&provider)?;
// Send attestation to prover.
attestation_tx
.send(attestation)
.map_err(|_| "prover is not receiving attestation".to_string())?;
Ok(())
}

View File

@@ -0,0 +1,96 @@
// This example demonstrates how to verify a presentation. See `present.rs` for
// an example of how to build a presentation from an attestation and connection
// secrets.
use std::time::Duration;
use clap::Parser;
use tlsn::{
attestation::{
presentation::{Presentation, PresentationOutput},
signing::VerifyingKey,
CryptoProvider,
},
config::{CertificateDer, RootCertStore},
verifier::ServerCertVerifier,
};
use tlsn_examples::ExampleType;
use tlsn_server_fixture_certs::CA_CERT_DER;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// What data to notarize.
#[clap(default_value_t, value_enum)]
example_type: ExampleType,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
verify_presentation(&args.example_type).await
}
async fn verify_presentation(example_type: &ExampleType) -> Result<(), Box<dyn std::error::Error>> {
// Read the presentation from disk.
let presentation_path = tlsn_examples::get_file_path(example_type, "presentation");
let presentation: Presentation = bincode::deserialize(&std::fs::read(presentation_path)?)?;
// Create a crypto provider accepting the server-fixture's self-signed
// root certificate.
//
// This is only required for offline testing with the server-fixture. In
// production, use `CryptoProvider::default()` instead.
let root_cert_store = RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
};
let crypto_provider = CryptoProvider {
cert: ServerCertVerifier::new(&root_cert_store)?,
..Default::default()
};
let VerifyingKey {
alg,
data: key_data,
} = presentation.verifying_key();
println!(
"Verifying presentation with {alg} key: {}\n\n**Ask yourself, do you trust this key?**\n",
hex::encode(key_data)
);
// Verify the presentation.
let PresentationOutput {
server_name,
connection_info,
transcript,
// extensions, // Optionally, verify any custom extensions from prover/notary.
..
} = presentation.verify(&crypto_provider).unwrap();
// The time at which the connection was started.
let time = chrono::DateTime::UNIX_EPOCH + Duration::from_secs(connection_info.time);
let server_name = server_name.unwrap();
let mut partial_transcript = transcript.unwrap();
// Set the unauthenticated bytes so they are distinguishable.
partial_transcript.set_unauthed(b'X');
let sent = String::from_utf8_lossy(partial_transcript.sent_unsafe());
let recv = String::from_utf8_lossy(partial_transcript.received_unsafe());
println!("-------------------------------------------------------------------");
println!(
"Successfully verified that the data below came from a session with {server_name} at {time}.",
);
println!("Note that the data which the Prover chose not to disclose are shown as X.\n");
println!("Data sent:\n");
println!("{sent}\n");
println!("Data received:\n");
println!("{recv}\n");
println!("-------------------------------------------------------------------");
Ok(())
}

View File

@@ -10,15 +10,15 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::PartialTranscript,
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::SERVER_DOMAIN;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
const SECRET: &str = "TLSNotary's private key 🤡";
@@ -72,18 +72,16 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(server_domain)
.server_name(ServerName::Dns(server_domain.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -176,7 +174,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}
@@ -194,13 +192,10 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let verifier_config = VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
@@ -234,6 +229,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap_or_else(|| panic!("Expected valid data from {SERVER_DOMAIN}"));
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
transcript

View File

@@ -26,7 +26,7 @@ pub enum Id {
One,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum IoMode {
Client,
Server,

View File

@@ -7,12 +7,12 @@ docker build --pull -t tlsn-bench . -f ./crates/harness/harness.Dockerfile
Next run the benches with:
```
docker run -it --privileged -v ./crates/harness/:/benches tlsn-bench bash -c "runner setup; runner bench"
docker run -it --privileged -v $(pwd)/crates/harness/:/benches tlsn-bench bash -c "runner setup; runner bench"
```
The `--privileged` flag is required because this test bench needs permission to create networks with specific parameters.
To run the benches in a browser run:
```
docker run -it --privileged -v ./crates/harness/:/benches tlsn-bench bash -c "cd /; runner setup; runner --target browser bench"
docker run -it --privileged -v $(pwd)/crates/harness/:/benches tlsn-bench bash -c "runner setup; runner --target browser bench"
```

View File

@@ -8,14 +8,9 @@ publish = false
name = "harness_executor"
crate-type = ["cdylib", "rlib"]
[package.metadata.wasm-pack.profile.custom]
wasm-opt = ["-O3"]
[dependencies]
tlsn-harness-core = { workspace = true }
tlsn = { workspace = true }
tlsn-core = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
inventory = { workspace = true }
@@ -33,6 +28,8 @@ tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true, features = ["compat"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
# Disable tracing events as a workaround for issue 959.
tracing = { workspace = true, features = ["release_max_level_off"] }
wasm-bindgen = { workspace = true }
tlsn-wasm = { workspace = true }
js-sys = { workspace = true }

View File

@@ -5,7 +5,8 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
use harness_core::bench::{Bench, ProverMetrics};
use tlsn::{
config::ProtocolConfig,
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
};
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -32,20 +33,17 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let protocol_config = builder.build()?;
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
let prover = Prover::new(
ProverConfig::builder()
.tls_config(tls_config)
.protocol_config(protocol_config)
.server_name(SERVER_DOMAIN)
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.build()?,
);
@@ -95,7 +93,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -2,7 +2,7 @@ use anyhow::Result;
use harness_core::bench::Bench;
use tlsn::{
config::ProtocolConfigValidator,
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifyConfig},
};
use tlsn_server_fixture_certs::CA_CERT_DER;
@@ -17,14 +17,11 @@ pub async fn bench_verifier(provider: &IoProvider, config: &Bench) -> Result<()>
let protocol_config = builder.build()?;
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(protocol_config)
.build()?,
);

View File

@@ -81,7 +81,11 @@ mod native {
mod wasm {
use super::IoProvider;
use crate::io::Io;
use anyhow::Result;
use anyhow::{Result, anyhow};
use std::time::Duration;
const CHECK_WS_OPEN_DELAY_MS: usize = 50;
const MAX_RETRIES: usize = 50;
impl IoProvider {
/// Provides a connection to the server.
@@ -107,7 +111,27 @@ mod wasm {
&self.config.proto_1.0,
self.config.proto_1.1,
);
let (_, io) = ws_stream_wasm::WsMeta::connect(url, None).await?;
let mut retries = 0;
let io = loop {
// Connect to the websocket relay.
let (_, io) = ws_stream_wasm::WsMeta::connect(url.clone(), None).await?;
// Allow some time for the relay to initiate a connection to
// the verifier.
std::thread::sleep(Duration::from_millis(CHECK_WS_OPEN_DELAY_MS as u64));
// If the relay didn't close the io, most likely the verifier
// accepted the connection.
if io.ready_state() == ws_stream_wasm::WsState::Open {
break io;
}
retries += 1;
if retries > MAX_RETRIES {
return Err(anyhow!("verifier did not accept connection"));
}
};
Ok(io.into_io())
}

View File

@@ -1,6 +1,6 @@
use tls_core::anchors::RootCertStore;
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind},
@@ -21,19 +21,17 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) {
let mut root_store = RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(SERVER_DOMAIN)
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -109,16 +107,11 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}
async fn verifier(provider: &IoProvider) {
let mut root_store = RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let config = VerifierConfig::builder()
.protocol_config_validator(
ProtocolConfigValidator::builder()
@@ -127,7 +120,9 @@ async fn verifier(provider: &IoProvider) {
.build()
.unwrap(),
)
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap();
@@ -145,7 +140,9 @@ async fn verifier(provider: &IoProvider) {
.await
.unwrap();
assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN);
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(
transcript_commitments
.iter()

View File

@@ -8,6 +8,7 @@ use chromiumoxide::{
network::{EnableParams, SetCacheDisabledParams},
page::ReloadParams,
},
handler::HandlerConfig,
};
use futures::StreamExt;
use harness_core::{
@@ -126,8 +127,18 @@ impl Executor {
const TIMEOUT: usize = 10000;
const DELAY: usize = 100;
let mut retries = 0;
let config = HandlerConfig {
// Bump the timeout for long-running benches.
request_timeout: Duration::from_secs(120),
..Default::default()
};
let (browser, mut handler) = loop {
match Browser::connect(format!("http://{}:{}", rpc_addr.0, PORT_BROWSER)).await
match Browser::connect_with_config(
format!("http://{}:{}", rpc_addr.0, PORT_BROWSER),
config.clone(),
)
.await
{
Ok(browser) => break browser,
Err(e) => {
@@ -143,6 +154,14 @@ impl Executor {
tokio::spawn(async move {
while let Some(res) = handler.next().await {
if let Err(e) = res {
if e.to_string()
== "data did not match any variant of untagged enum Message"
{
// Do not log this error. It appears to be
// caused by a bug upstream.
// https://github.com/mattsse/chromiumoxide/issues/167
continue;
}
eprintln!("chromium error: {e:?}");
}
}

View File

@@ -72,3 +72,5 @@ tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["compat"] }
tracing-subscriber = { workspace = true }
uid-mux = { workspace = true, features = ["serio", "test-utils"] }
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }

View File

@@ -22,7 +22,7 @@ use serio::stream::IoStreamExt;
use std::mem;
use tls_core::msgs::enums::NamedGroup;
use tlsn_core::{
connection::{HandshakeData, HandshakeDataV1_2, TlsVersion, VerifyData},
connection::{CertBinding, CertBindingV1_2, TlsVersion, VerifyData},
transcript::TlsTranscript,
};
use tracing::{debug, instrument};
@@ -405,7 +405,7 @@ impl MpcTlsFollower {
let cf_vd = cf_vd.ok_or(MpcTlsError::hs("client finished VD not computed"))?;
let sf_vd = sf_vd.ok_or(MpcTlsError::hs("server finished VD not computed"))?;
let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 {
let handshake_data = CertBinding::V1_2(CertBindingV1_2 {
client_random,
server_random,
server_ephemeral_key: server_key

View File

@@ -43,10 +43,9 @@ use tls_core::{
suites::SupportedCipherSuite,
};
use tlsn_core::{
connection::{
Certificate, HandshakeData, HandshakeDataV1_2, ServerSignature, TlsVersion, VerifyData,
},
connection::{CertBinding, CertBindingV1_2, ServerSignature, TlsVersion, VerifyData},
transcript::TlsTranscript,
webpki::CertificateDer,
};
use tracing::{debug, instrument, trace, warn};
@@ -325,7 +324,7 @@ impl MpcTlsLeader {
let server_cert_chain = server_cert_details
.cert_chain()
.iter()
.map(|cert| Certificate(cert.0.clone()))
.map(|cert| CertificateDer(cert.0.clone()))
.collect();
let server_signature = ServerSignature {
@@ -337,7 +336,7 @@ impl MpcTlsLeader {
sig: server_kx_details.kx_sig().sig.0.clone(),
};
let handshake_data = HandshakeData::V1_2(HandshakeDataV1_2 {
let handshake_data = CertBinding::V1_2(CertBindingV1_2 {
client_random: client_random.0,
server_random: server_random.0,
server_ephemeral_key: server_key

View File

@@ -12,11 +12,15 @@ use mpz_ot::{
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_client::Certificate;
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN};
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "expensive"]
@@ -48,11 +52,11 @@ async fn leader_task(mut leader: MpcTlsLeader) {
let (leader_ctrl, leader_fut) = leader.run();
tokio::spawn(async { leader_fut.await.unwrap() });
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_root_certificates(RootCertStore {
roots: vec![anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()],
})
.with_no_client_auth();
let server_name = SERVER_DOMAIN.try_into().unwrap();

View File

@@ -6,7 +6,7 @@ pub static SERVER_CERT_DER: &[u8] = include_bytes!("tls/test_server_cert.der");
pub static SERVER_KEY_DER: &[u8] = include_bytes!("tls/test_server_private_key.der");
/// The domain name bound to the server certificate.
pub static SERVER_DOMAIN: &str = "test-server.io";
/// A client certificate fixture PEM-encoded.
pub static CLIENT_CERT: &[u8] = include_bytes!("tls/client_cert.pem");
/// A client private key fixture PEM-encoded.
pub static CLIENT_KEY: &[u8] = include_bytes!("tls/client_cert.key");
/// A client certificate fixture.
pub static CLIENT_CERT_DER: &[u8] = include_bytes!("tls/client_cert.der");
/// A client private key fixture.
pub static CLIENT_KEY_DER: &[u8] = include_bytes!("tls/client_cert_private_key.der");

View File

@@ -33,5 +33,8 @@ openssl req -new -key client_cert.key -out client_cert.csr -subj "/C=US/ST=State
# Sign the CSR with the root CA to create the end entity certificate (100 years validity)
openssl x509 -req -in client_cert.csr -CA root_ca.crt -CAkey root_ca.key -CAcreateserial -out client_cert.crt -days 36525 -sha256 -extfile openssl.cnf -extensions v3_req
# Convert the end entity certificate to PEM format
openssl x509 -in client_cert.crt -outform pem -out client_cert.pem
# Convert the end entity certificate to DER format
openssl x509 -in client_cert.crt -outform der -out client_cert.der
# Convert the end entity certificate private key to DER format
openssl pkcs8 -topk8 -inform PEM -outform DER -in client_cert.key -out client_cert_private_key.der -nocrypt

Binary file not shown.

View File

@@ -1,23 +0,0 @@
-----BEGIN CERTIFICATE-----
MIID2jCCAsKgAwIBAgIUG5JKIz/fbUDdpX1+TAw33mS+mWwwDQYJKoZIhvcNAQEL
BQAwZTELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5
MRIwEAYDVQQKDAl0bHNub3RhcnkxCzAJBgNVBAsMAklUMRYwFAYDVQQDDA10bHNu
b3Rhcnkub3JnMCAXDTI1MDYxMDA3MTYxOVoYDzIxMjUwNjExMDcxNjE5WjBwMQsw
CQYDVQQGEwJVUzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxEjAQBgNV
BAoMCXRsc25vdGFyeTELMAkGA1UECwwCSVQxITAfBgNVBAMMGGNsaWVudC1hdXRo
ZW50aWNhdGlvbi5pbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANsx
Tf3JqWdAMGFzOwbO64vJ5fV/IPSrdBwKY/Fjef0REZC1Z/gGzmp0nnlaHZzZLtLS
Z9kyfdUrL6PuG3HfP6wxhiaBpUay+1O9KZsuhkKSif4KMPjlYKm+oZLvD12Qj62r
TFlui4+1wKgPrTGUUO6SQdoRxKU4nzuzRYRLyzDi0pO5YD9RLaruBj+IDEOVRW7d
1uleheVMg61lbQle5Fo0c4I0Sif96Z+7aotj3j9F2lK52jaLpA1kvC3oLajfAT30
BzpNLZTnWa1b5PRRxkuOYUXeNr+aNO90fL80K1YeIlea0f7qmKL9uDLtQbrqIJv5
tBaf8Uf0UghtBm//kx8CAwEAAaN1MHMwCQYDVR0TBAIwADALBgNVHQ8EBAMCBeAw
GQYDVR0RBBIwEIIOdGVzdC1zZXJ2ZXIuaW8wHQYDVR0OBBYEFH1qCgl04Y5i75aF
cT0V3fn9423iMB8GA1UdIwQYMBaAFMmBciQ/DZlWROxwXH8IplmuHKbNMA0GCSqG
SIb3DQEBCwUAA4IBAQB8Gvj3dsENAn0u6PS9uTFm46MaA9Dm+Fa+KbXuEHp3ADs2
7m4Hb3eojM3yae93/v/stYn8IVcB5zWmMvg6WA6obe86muuB+SZeMC/AnSD8P4pm
AzO3eTSR1s5Dr4O0qVPd2VP36e7NWXfojQg4W9t9UQtC64bVOaCDQvbe0xeWT+AR
w0y7GwnuCr/8bisqQZS8+Er1JU3zxBEjQwMiMxlOWHnYtjGeA6pdWaeLp0E6Ss3x
ecsTjmrLt6oY+BdfRSyWU4qVEOpuZLCeikUWXFzpxRX7NWYRtJUfVnoRWwuD2lzG
LybzCW2qxwHJe4biGIfWKQ7Ne7DrwQwFxVRJxCm0
-----END CERTIFICATE-----

View File

@@ -35,3 +35,5 @@ hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
rstest = { workspace = true }
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
rustls-webpki = { workspace = true }
rustls-pki-types = { workspace = true }

View File

@@ -6,7 +6,8 @@ use http_body_util::{BodyExt as _, Full};
use hyper::{body::Bytes, Request, StatusCode};
use hyper_util::rt::TokioIo;
use rstest::{fixture, rstest};
use tls_client::{Certificate, ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use rustls_pki_types::CertificateDer;
use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName};
use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection};
use tls_server_fixture::{
bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY,
@@ -14,6 +15,9 @@ use tls_server_fixture::{
};
use tokio::task::JoinHandle;
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
// An established client TLS connection
struct TlsFixture {
@@ -30,7 +34,9 @@ async fn set_up_tls() -> TlsFixture {
let _server_task = tokio::spawn(bind_test_server(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
@@ -75,7 +81,9 @@ async fn test_hyper_ok() {
let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat()));
let mut root_store = tls_client::RootCertStore::empty();
root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap();
root_store
.roots
.push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned());
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)

View File

@@ -23,7 +23,8 @@ async-trait = { workspace = true }
log = { workspace = true, optional = true }
ring = { workspace = true }
sct = { workspace = true }
webpki = { workspace = true, features = ["alloc", "std"] }
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }
aes-gcm = { workspace = true }
p256 = { workspace = true, features = ["ecdh"] }
rand = { workspace = true }

View File

@@ -7,7 +7,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::HandshakeHashBuffer,
msgs::persist,
ticketer::TimeBase,
};
use tls_core::{
@@ -42,35 +41,6 @@ pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
pub(super) type NextStateOrError = Result<NextState, Error>;
pub(super) type ClientContext<'a> = crate::conn::Context<'a>;
fn find_session(
server_name: &ServerName,
config: &ClientConfig,
) -> Option<persist::Retrieved<persist::ClientSessionValue>> {
let key = persist::ClientSessionKey::session_for_server_name(server_name);
let key_buf = key.get_encoding();
let value = config.session_storage.get(&key_buf).or_else(|| {
debug!("No cached session for {:?}", server_name);
None
})?;
#[allow(unused_mut)]
let mut reader = Reader::init(&value[2..]);
#[allow(clippy::bind_instead_of_map)] // https://github.com/rust-lang/rust-clippy/issues/8082
CipherSuite::read_bytes(&value[..2])
.and_then(|suite| {
persist::ClientSessionValue::read(&mut reader, suite, &config.cipher_suites)
})
.and_then(|resuming| {
let retrieved = persist::Retrieved::new(resuming, TimeBase::now().ok()?);
match retrieved.has_expired() {
false => Some(retrieved),
true => None,
}
})
.and_then(Some)
}
pub(super) async fn start_handshake(
server_name: ServerName,
extra_exts: Vec<ClientExtension>,
@@ -123,7 +93,6 @@ pub(super) async fn start_handshake(
emit_client_hello_for_retry(
config,
cx,
None,
random,
false,
transcript_buffer,
@@ -142,7 +111,6 @@ pub(super) async fn start_handshake(
struct ExpectServerHello {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
server_name: ServerName,
random: Random,
using_ems: bool,
@@ -162,7 +130,6 @@ struct ExpectServerHelloOrHelloRetryRequest {
async fn emit_client_hello_for_retry(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
resuming_session: Option<persist::Retrieved<persist::ClientSessionValue>>,
random: Random,
using_ems: bool,
mut transcript_buffer: HandshakeHashBuffer,
@@ -176,25 +143,6 @@ async fn emit_client_hello_for_retry(
may_send_sct_list: bool,
suite: Option<SupportedCipherSuite>,
) -> Result<NextState, Error> {
// For now we do not support session resumption
//
// Do we have a SessionID or ticket cached for this host?
// let (ticket, resume_version) = if let Some(resuming) = &resuming_session {
// match &resuming.value {
// persist::ClientSessionValue::Tls13(inner) => {
// (inner.ticket().to_vec(), ProtocolVersion::TLSv1_3)
// }
// #[cfg(feature = "tls12")]
// persist::ClientSessionValue::Tls12(inner) => {
// (inner.ticket().to_vec(), ProtocolVersion::TLSv1_2)
// }
// }
// } else {
// (Vec::new(), ProtocolVersion::Unknown(0))
// };
// let (ticket, resume_version) = (Vec::new(), ProtocolVersion::Unknown(0));
let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2);
let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);
@@ -256,48 +204,6 @@ async fn emit_client_hello_for_retry(
// Extra extensions must be placed before the PSK extension
exts.extend(extra_exts.iter().cloned());
// let fill_in_binder = if support_tls13
// && config.enable_tickets
// && resume_version == ProtocolVersion::TLSv1_3
// && !ticket.is_empty()
// {
// let resuming =
// resuming_session
// .as_ref()
// .and_then(|resuming| match (suite, resuming.tls13()) {
// (Some(suite), Some(resuming)) => {
// suite.tls13()?.can_resume_from(resuming.suite())?;
// Some(resuming)
// }
// (None, Some(resuming)) => Some(resuming),
// _ => None,
// });
// if let Some(ref resuming) = resuming {
// tls13::prepare_resumption(
// &config,
// cx,
// ticket,
// &resuming,
// &mut exts,
// retryreq.is_some(),
// )
// .await;
// }
// resuming
// } else if config.enable_tickets {
// // If we have a ticket, include it. Otherwise, request one.
// if ticket.is_empty() {
// exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Request));
// } else {
// exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Offer(
// Payload::new(ticket),
// )));
// }
// None
// } else {
// None
// };
// Note what extensions we sent.
hello.sent_extensions = exts.iter().map(ClientExtension::get_type).collect();
@@ -319,8 +225,8 @@ async fn emit_client_hello_for_retry(
};
// let early_key_schedule = if let Some(resuming) = fill_in_binder {
// let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer, &mut chp);
// Some((resuming.suite(), schedule))
// let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer,
// &mut chp); Some((resuming.suite(), schedule))
// } else {
// None
// };
@@ -350,7 +256,6 @@ async fn emit_client_hello_for_retry(
let next = ExpectServerHello {
config,
resuming_session,
server_name,
random,
using_ems,
@@ -551,19 +456,10 @@ impl State<ClientConnectionData> for ExpectServerHello {
// handshake_traffic_secret.
match suite {
SupportedCipherSuite::Tls13(suite) => {
let resuming_session =
self.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls13(inner) => Some(inner),
#[cfg(feature = "tls12")]
persist::ClientSessionValue::Tls12(_) => None,
});
tls13::handle_server_hello(
self.config,
cx,
server_hello,
resuming_session,
self.server_name,
randoms,
suite,
@@ -577,16 +473,8 @@ impl State<ClientConnectionData> for ExpectServerHello {
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(suite) => {
let resuming_session =
self.resuming_session
.and_then(|resuming| match resuming.value {
persist::ClientSessionValue::Tls12(inner) => Some(inner),
persist::ClientSessionValue::Tls13(_) => None,
});
tls12::CompleteServerHelloHandling {
config: self.config,
resuming_session,
server_name: self.server_name,
randoms,
using_ems: self.using_ems,
@@ -723,7 +611,6 @@ impl ExpectServerHelloOrHelloRetryRequest {
emit_client_hello_for_retry(
self.next.config,
cx,
self.next.resuming_session,
self.next.random,
self.next.using_ems,
transcript_buffer,

View File

@@ -10,7 +10,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::HandshakeHash,
msgs::persist,
sign::Signer,
ticketer::TimeBase,
verify,
@@ -49,7 +48,6 @@ mod server_hello {
pub(in crate::client) struct CompleteServerHelloHandling {
pub(in crate::client) config: Arc<ClientConfig>,
pub(in crate::client) resuming_session: Option<persist::Tls12ClientSessionValue>,
pub(in crate::client) server_name: ServerName,
pub(in crate::client) randoms: ConnectionRandoms,
pub(in crate::client) using_ems: bool,
@@ -113,76 +111,8 @@ mod server_hello {
None
};
// See if we're successfully resuming.
if let Some(ref _resuming) = self.resuming_session {
return Err(Error::General(
"client does not support resumption".to_string(),
));
// if resuming.session_id == server_hello.session_id {
// debug!("Server agreed to resume");
// // Is the server telling lies about the ciphersuite?
// if resuming.suite() != suite {
// let error_msg =
// "abbreviated handshake offered, but with varied cs".to_string();
// return Err(Error::PeerMisbehavedError(error_msg));
// }
// // And about EMS support?
// if resuming.extended_ms() != self.using_ems {
// let error_msg = "server varied ems support over resume".to_string();
// return Err(Error::PeerMisbehavedError(error_msg));
// }
// let secrets =
// ConnectionSecrets::new_resume(self.randoms, suite, resuming.secret());
// self.config.key_log.log(
// "CLIENT_RANDOM",
// &secrets.randoms.client,
// &secrets.master_secret,
// );
// cx.common.start_encryption_tls12(&secrets, Side::Client);
// // Since we're resuming, we verified the certificate and
// // proof of possession in the prior session.
// cx.common.peer_certificates = Some(resuming.server_cert_chain().to_vec());
// let cert_verified = verify::ServerCertVerified::assertion();
// let sig_verified = verify::HandshakeSignatureValid::assertion();
// return if must_issue_new_ticket {
// Ok(Box::new(ExpectNewTicket {
// config: self.config,
// secrets,
// resuming_session: self.resuming_session,
// session_id: server_hello.session_id,
// server_name: self.server_name,
// using_ems: self.using_ems,
// transcript: self.transcript,
// resuming: true,
// cert_verified,
// sig_verified,
// }))
// } else {
// Ok(Box::new(ExpectCcs {
// config: self.config,
// secrets,
// resuming_session: self.resuming_session,
// session_id: server_hello.session_id,
// server_name: self.server_name,
// using_ems: self.using_ems,
// transcript: self.transcript,
// ticket: None,
// resuming: true,
// cert_verified,
// sig_verified,
// }))
// };
// }
}
Ok(Box::new(ExpectCertificate {
config: self.config,
resuming_session: self.resuming_session,
session_id: server_hello.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -199,7 +129,6 @@ mod server_hello {
struct ExpectCertificate {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -228,7 +157,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
if self.may_send_cert_status {
Ok(Box::new(ExpectCertificateStatusOrServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -250,7 +178,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
Ok(Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -266,7 +193,6 @@ impl State<ClientConnectionData> for ExpectCertificate {
struct ExpectCertificateStatusOrServerKx {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -303,7 +229,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -322,7 +247,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
}) => {
Box::new(ExpectCertificateStatus {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -350,7 +274,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatusOrServerKx {
struct ExpectCertificateStatus {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -395,7 +318,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {
Ok(Box::new(ExpectServerKx {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -410,7 +332,6 @@ impl State<ClientConnectionData> for ExpectCertificateStatus {
struct ExpectServerKx {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -458,7 +379,6 @@ impl State<ClientConnectionData> for ExpectServerKx {
Ok(Box::new(ExpectServerDoneOrCertReq {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -570,7 +490,6 @@ async fn emit_finished(
// client auth. Otherwise we go straight to ServerHelloDone.
struct ExpectServerDoneOrCertReq {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -598,7 +517,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
) {
Box::new(ExpectCertificateRequest {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -616,7 +534,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
Box::new(ExpectServerDone {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -636,7 +553,6 @@ impl State<ClientConnectionData> for ExpectServerDoneOrCertReq {
struct ExpectCertificateRequest {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -679,7 +595,6 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
Ok(Box::new(ExpectServerDone {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
randoms: self.randoms,
@@ -696,7 +611,6 @@ impl State<ClientConnectionData> for ExpectCertificateRequest {
struct ExpectServerDone {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
randoms: ConnectionRandoms,
@@ -745,6 +659,7 @@ impl State<ClientConnectionData> for ExpectServerDone {
// 3. Verify that the top certificate signed their kx.
// 4. If doing client auth, send our Certificate.
// 5. Complete the key exchange:
//
// a) generate our kx pair
// b) emit a ClientKeyExchange containing it
// c) if doing client auth, emit a CertificateVerify
@@ -891,7 +806,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
if st.must_issue_new_ticket {
Ok(Box::new(ExpectNewTicket {
config: st.config,
resuming_session: st.resuming_session,
session_id: st.session_id,
server_name: st.server_name,
using_ems: st.using_ems,
@@ -903,7 +817,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
} else {
Ok(Box::new(ExpectCcs {
config: st.config,
resuming_session: st.resuming_session,
session_id: st.session_id,
server_name: st.server_name,
using_ems: st.using_ems,
@@ -919,7 +832,6 @@ impl State<ClientConnectionData> for ExpectServerDone {
struct ExpectNewTicket {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -946,7 +858,6 @@ impl State<ClientConnectionData> for ExpectNewTicket {
Ok(Box::new(ExpectCcs {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
using_ems: self.using_ems,
@@ -962,7 +873,6 @@ impl State<ClientConnectionData> for ExpectNewTicket {
// -- Waiting for their CCS --
struct ExpectCcs {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -998,7 +908,6 @@ impl State<ClientConnectionData> for ExpectCcs {
Ok(Box::new(ExpectFinished {
config: self.config,
resuming_session: self.resuming_session,
session_id: self.session_id,
server_name: self.server_name,
using_ems: self.using_ems,
@@ -1013,7 +922,6 @@ impl State<ClientConnectionData> for ExpectCcs {
struct ExpectFinished {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls12ClientSessionValue>,
session_id: SessionID,
server_name: ServerName,
using_ems: bool,
@@ -1024,60 +932,6 @@ struct ExpectFinished {
sig_verified: verify::HandshakeSignatureValid,
}
// impl ExpectFinished {
// // -- Waiting for their finished --
// fn save_session(&mut self, cx: &mut ClientContext<'_>) {
// // Save a ticket. If we got a new ticket, save that. Otherwise, save the
// // original ticket again.
// let (mut ticket, lifetime) = match self.ticket.take() {
// Some(nst) => (nst.ticket.0, nst.lifetime_hint),
// None => (Vec::new(), 0),
// };
// if ticket.is_empty() {
// if let Some(resuming_session) = &mut self.resuming_session {
// ticket = resuming_session.take_ticket();
// }
// }
// if self.session_id.is_empty() && ticket.is_empty() {
// debug!("Session not saved: server didn't allocate id or ticket");
// return;
// }
// let time_now = match TimeBase::now() {
// Ok(time_now) => time_now,
// Err(e) => {
// debug!("Session not saved: {}", e);
// return;
// }
// };
// let key = persist::ClientSessionKey::session_for_server_name(&self.server_name);
// let value = persist::Tls12ClientSessionValue::new(
// self.secrets.suite(),
// self.session_id,
// ticket,
// self.secrets.get_master_secret(),
// cx.common.peer_certificates.clone().unwrap_or_default(),
// time_now,
// lifetime,
// self.using_ems,
// );
// let worked = self
// .config
// .session_storage
// .put(key.get_encoding(), value.get_encoding());
// if worked {
// debug!("Session saved");
// } else {
// debug!("Session not saved");
// }
// }
// }
#[async_trait]
impl State<ClientConnectionData> for ExpectFinished {
async fn handle(

View File

@@ -11,7 +11,6 @@ use crate::{
conn::{CommonState, ConnectionRandoms, State},
error::Error,
hash_hs::{HandshakeHash, HandshakeHashBuffer},
msgs::persist,
sign, verify, KeyLog,
};
#[allow(deprecated)]
@@ -60,7 +59,6 @@ pub(super) async fn handle_server_hello(
config: Arc<ClientConfig>,
cx: &mut ClientContext<'_>,
server_hello: &ServerHelloPayload,
resuming_session: Option<persist::Tls13ClientSessionValue>,
server_name: ServerName,
randoms: ConnectionRandoms,
suite: &'static Tls13CipherSuite,
@@ -102,8 +100,8 @@ pub(super) async fn handle_server_hello(
// };
// if server_hello.get_psk_index() != Some(0) {
// return Err(cx.common.illegal_param("server selected invalid psk").await);
// }
// return Err(cx.common.illegal_param("server selected invalid
// psk").await); }
// debug!("Resuming using PSK");
// // The key schedule has been initialized and set in fill_in_psk_binder()
@@ -143,7 +141,6 @@ pub(super) async fn handle_server_hello(
Ok(Box::new(ExpectEncryptedExtensions {
config,
resuming_session,
server_name,
randoms,
suite,
@@ -170,69 +167,6 @@ async fn validate_server_hello(
Ok(())
}
// fn save_kx_hint(config: &ClientConfig, server_name: &ServerName, group: NamedGroup) {
// let key = persist::ClientSessionKey::hint_for_server_name(server_name);
// config
// .session_storage
// .put(key.get_encoding(), group.get_encoding());
// }
// /// This implements the horrifying TLS1.3 hack where PSK binders have a
// /// data dependency on the message they are contained within.
// pub(super) fn fill_in_psk_binder(
// resuming: &persist::Tls13ClientSessionValue,
// transcript: &HandshakeHashBuffer,
// hmp: &mut HandshakeMessagePayload,
// ) -> KeyScheduleEarly {
// // We need to know the hash function of the suite we're trying to resume into.
// let hkdf_alg = &resuming.suite().hkdf_algorithm;
// let suite_hash = resuming.suite().hash_algorithm();
// // The binder is calculated over the clienthello, but doesn't include itself or its
// // length, or the length of its container.
// let binder_plaintext = hmp.get_encoding_for_binder_signing();
// let handshake_hash = transcript.get_hash_given(suite_hash, &binder_plaintext);
// // Run a fake key_schedule to simulate what the server will do if it chooses
// // to resume.
// let key_schedule = KeyScheduleEarly::new(hkdf_alg, resuming.secret());
// let real_binder = key_schedule.resumption_psk_binder_key_and_sign_verify_data(&handshake_hash);
// if let HandshakePayload::ClientHello(ref mut ch) = hmp.payload {
// ch.set_psk_binder(real_binder.as_ref());
// };
// key_schedule
// }
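For reference, the binder described by the commented-out fill_in_psk_binder above is an HMAC over the transcript hash of the ClientHello truncated just before the binders list, keyed by a key derived from the resumption secret. A minimal sketch of that shape, assuming the hmac and sha2 crates (neither is implied by this diff) and eliding the HKDF-Expand-Label step that derives the finished key:

use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

/// Computes a PSK-binder-style MAC over a truncated ClientHello.
fn psk_binder(finished_key: &[u8], truncated_client_hello: &[u8]) -> Vec<u8> {
    // Hash the ClientHello without the binders themselves (they cannot
    // cover their own bytes).
    let transcript_hash = Sha256::digest(truncated_client_hello);
    let mut mac =
        <Hmac<Sha256>>::new_from_slice(finished_key).expect("HMAC accepts any key length");
    mac.update(transcript_hash.as_slice());
    mac.finalize().into_bytes().to_vec()
}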
// pub(super) async fn prepare_resumption(
// config: &ClientConfig,
// cx: &mut ClientContext<'_>,
// ticket: Vec<u8>,
// resuming_session: &persist::Retrieved<&persist::Tls13ClientSessionValue>,
// exts: &mut Vec<ClientExtension>,
// doing_retry: bool,
// ) {
// let resuming_suite = resuming_session.suite();
// cx.common.suite = Some(resuming_suite.into());
// cx.data.resumption_ciphersuite = Some(resuming_suite.into());
// // Finally, and only for TLS1.3 with a ticket resumption, include a binder
// // for our ticket. This must go last.
// //
// // Include an empty binder. It gets filled in below because it depends on
// // the message it's contained in (!!!).
// let obfuscated_ticket_age = resuming_session.obfuscated_ticket_age();
// let binder_len = resuming_suite.hash_algorithm().output_len();
// let binder = vec![0u8; binder_len];
// let psk_identity = PresharedKeyIdentity::new(ticket, obfuscated_ticket_age);
// let psk_ext = PresharedKeyOffer::new(psk_identity, binder);
// exts.push(ClientExtension::PresharedKey(psk_ext));
// }
pub(super) async fn emit_fake_ccs(
sent_tls13_fake_ccs: &mut bool,
common: &mut CommonState,
@@ -287,7 +221,6 @@ async fn validate_encrypted_extensions(
struct ExpectEncryptedExtensions {
config: Arc<ClientConfig>,
resuming_session: Option<persist::Tls13ClientSessionValue>,
server_name: ServerName,
randoms: ConnectionRandoms,
suite: &'static Tls13CipherSuite,
@@ -313,52 +246,19 @@ impl State<ClientConnectionData> for ExpectEncryptedExtensions {
validate_encrypted_extensions(cx.common, &self.hello, exts).await?;
hs::process_alpn_protocol(cx.common, &self.config, exts.get_alpn_protocol()).await?;
if let Some(resuming_session) = self.resuming_session {
let was_early_traffic = cx.common.early_traffic;
if was_early_traffic {
if exts.early_data_extension_offered() {
cx.data.early_data.accepted();
} else {
cx.data.early_data.rejected();
cx.common.early_traffic = false;
}
}
if was_early_traffic && !cx.common.early_traffic {
// If no early traffic, set the encryption key for handshakes
cx.common.record_layer.set_message_encrypter();
}
cx.common.peer_certificates = Some(resuming_session.server_cert_chain().to_vec());
// We *don't* reverify the certificate chain here: resumption is a
// continuation of the previous session in terms of security policy.
let cert_verified = verify::ServerCertVerified::assertion();
let sig_verified = verify::HandshakeSignatureValid::assertion();
Ok(Box::new(ExpectFinished {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
client_auth: None,
cert_verified,
sig_verified,
}))
} else {
if exts.early_data_extension_offered() {
let msg = "server sent early data extension without resumption".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
Ok(Box::new(ExpectCertificateOrCertReq {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
may_send_sct_list: self.hello.server_may_send_sct_list(),
}))
if exts.early_data_extension_offered() {
let msg = "server sent early data extension without resumption".to_string();
return Err(Error::PeerMisbehavedError(msg));
}
Ok(Box::new(ExpectCertificateOrCertReq {
config: self.config,
server_name: self.server_name,
randoms: self.randoms,
suite: self.suite,
transcript: self.transcript,
may_send_sct_list: self.hello.server_may_send_sct_list(),
}))
}
}
@@ -422,9 +322,9 @@ impl State<ClientConnectionData> for ExpectCertificateOrCertReq {
}
}
// TLS1.3 version of CertificateRequest handling. We then move to expecting the server
// Certificate. Unfortunately the CertificateRequest type changed in an annoying way
// in TLS1.3.
// TLS1.3 version of CertificateRequest handling. We then move to expecting the
// server Certificate. Unfortunately the CertificateRequest type changed in an
// annoying way in TLS1.3.
struct ExpectCertificateRequest {
config: Arc<ClientConfig>,
server_name: ServerName,
@@ -787,8 +687,8 @@ impl State<ClientConnectionData> for ExpectFinished {
st.transcript.add_message(&m);
/* The EndOfEarlyData message to server is still encrypted with early data keys,
* but appears in the transcript after the server Finished. */
/* The EndOfEarlyData message to server is still encrypted with early data
* keys, but appears in the transcript after the server Finished. */
if cx.common.early_traffic {
emit_end_of_early_data_tls13(&mut st.transcript, cx.common).await?;
cx.common.early_traffic = false;
@@ -893,42 +793,6 @@ impl ExpectTraffic {
));
}
// let handshake_hash = self.transcript.get_current_hash();
// let secret = self
// .key_schedule
// .resumption_master_secret_and_derive_ticket_psk(&handshake_hash, &nst.nonce.0);
// let time_now = match TimeBase::now() {
// Ok(t) => t,
// #[allow(unused_variables)]
// Err(e) => {
// debug!("Session not saved: {}", e);
// return Ok(());
// }
// };
// let value = persist::Tls13ClientSessionValue::new(
// self.suite,
// nst.ticket.0.clone(),
// secret,
// cx.common.peer_certificates.clone().unwrap_or_default(),
// time_now,
// nst.lifetime,
// nst.age_add,
// nst.get_max_early_data_size().unwrap_or_default(),
// );
// let key = persist::ClientSessionKey::session_for_server_name(&self.server_name);
// #[allow(unused_mut)]
// let mut ticket = value.get_encoding();
// let worked = self.session_storage.put(key.get_encoding(), ticket);
// if worked {
// debug!("Ticket saved");
// } else {
// debug!("Ticket not saved");
// }
Ok(())
}
@@ -948,27 +812,6 @@ impl ExpectTraffic {
Err(Error::General(
"received unsupported key update request from peer".to_string(),
))
// match kur {
// KeyUpdateRequest::UpdateNotRequested => {}
// KeyUpdateRequest::UpdateRequested => {
// self.want_write_key_update = true;
// }
// _ => {
// common
// .send_fatal_alert(AlertDescription::IllegalParameter)
// .await;
// return Err(Error::CorruptMessagePayload(ContentType::Handshake));
// }
// }
// // Update our read-side keys.
// let new_read_key = self.key_schedule.next_server_application_traffic_secret();
// common
// .record_layer
// .set_message_decrypter(self.suite.derive_decrypter(&new_read_key));
// Ok(())
}
}
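The removed branch above implemented the RFC 8446 key-update ratchet, where each new application traffic secret is HKDF-Expand-Label(current_secret, "traffic upd", "", Hash.length). A hedged sketch of that derivation using the hkdf and sha2 crates (assumed here purely for illustration; this fork now rejects key update requests instead):

use hkdf::Hkdf;
use sha2::Sha256;

/// Derives application_traffic_secret_{N+1} from application_traffic_secret_N.
fn next_traffic_secret(current: &[u8; 32]) -> [u8; 32] {
    // HkdfLabel per RFC 8446: uint16 output length, "tls13 " + label with a
    // length prefix, then an empty context.
    let label = b"tls13 traffic upd";
    let mut info = Vec::with_capacity(4 + label.len());
    info.extend_from_slice(&32u16.to_be_bytes());
    info.push(label.len() as u8);
    info.extend_from_slice(label);
    info.push(0);

    let hkdf = Hkdf::<Sha256>::from_prk(current).expect("secret is a valid PRK");
    let mut next = [0u8; 32];
    hkdf.expand(&info, &mut next).expect("32 bytes is a valid output length");
    next
}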
@@ -1022,10 +865,11 @@ impl State<ClientConnectionData> for ExpectTraffic {
// .send_msg_encrypt(Message::build_key_update_notify().into())
// .await;
// let write_key = self.key_schedule.next_client_application_traffic_secret();
// let write_key =
// self.key_schedule.next_client_application_traffic_secret();
// common
// .record_layer
// .set_message_encrypter(self.suite.derive_encrypter(&write_key));
// }
// .set_message_encrypter(self.suite.derive_encrypter(&
// write_key)); }
}
}

View File

@@ -182,7 +182,7 @@ impl ConnectionCommon {
}
/// Returns an object that allows reading plaintext.
pub fn reader(&mut self) -> Reader {
pub fn reader(&mut self) -> Reader<'_> {
Reader {
received_plaintext: &mut self.common_state.received_plaintext,
// Are we done? i.e., have we processed all received messages, and received a

View File

@@ -301,15 +301,11 @@ mod conn;
mod error;
mod hash_hs;
mod limited_cache;
mod msgs;
mod rand;
mod record_layer;
//mod stream;
mod vecbuf;
pub(crate) use tls_core::verify;
#[cfg(test)]
mod verifybench;
pub(crate) use tls_core::x509;
pub(crate) use tls_core::{verify, x509};
#[macro_use]
mod check;
mod bs_debug;
@@ -330,7 +326,7 @@ pub mod internal {
// The public interface is:
pub use crate::{
anchors::{OwnedTrustAnchor, RootCertStore},
anchors::RootCertStore,
builder::{ConfigBuilder, WantsCipherSuites, WantsKxGroups, WantsVerifier, WantsVersions},
conn::{CommonState, ConnectionCommon, IoState, Reader, SideData},
error::Error,

View File

@@ -1,4 +0,0 @@
pub(crate) mod persist;
#[cfg(test)]
mod persist_test;

View File

@@ -1,526 +0,0 @@
use crate::{client::ServerName, ticketer::TimeBase};
use std::cmp;
#[cfg(feature = "tls12")]
use std::mem;
#[cfg(feature = "tls12")]
use tls_core::suites::Tls12CipherSuite;
use tls_core::{
msgs::{
base::{PayloadU16, PayloadU8},
codec::{Codec, Reader},
enums::{CipherSuite, ProtocolVersion},
handshake::{CertificatePayload, SessionID},
},
suites::{SupportedCipherSuite, Tls13CipherSuite},
};
// These are the keys and values we store in session storage.
// --- Client types ---
/// Keys for session resumption and tickets.
/// Matching value is a `ClientSessionValue`.
#[derive(Debug)]
pub struct ClientSessionKey {
kind: &'static [u8],
name: Vec<u8>,
}
impl Codec for ClientSessionKey {
fn encode(&self, bytes: &mut Vec<u8>) {
bytes.extend_from_slice(self.kind);
bytes.extend_from_slice(&self.name);
}
// Don't need to read these.
fn read(_r: &mut Reader) -> Option<Self> {
None
}
}
impl ClientSessionKey {
pub fn session_for_server_name(server_name: &ServerName) -> Self {
Self {
kind: b"session",
name: server_name.encode(),
}
}
pub fn hint_for_server_name(server_name: &ServerName) -> Self {
Self {
kind: b"kx-hint",
name: server_name.encode(),
}
}
}
#[derive(Debug)]
pub enum ClientSessionValue {
Tls13(Tls13ClientSessionValue),
#[cfg(feature = "tls12")]
Tls12(Tls12ClientSessionValue),
}
impl ClientSessionValue {
pub fn read(
reader: &mut Reader<'_>,
suite: CipherSuite,
supported: &[SupportedCipherSuite],
) -> Option<Self> {
match supported.iter().find(|s| s.suite() == suite)? {
SupportedCipherSuite::Tls13(inner) => {
Tls13ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls13)
}
#[cfg(feature = "tls12")]
SupportedCipherSuite::Tls12(inner) => {
Tls12ClientSessionValue::read(inner, reader).map(ClientSessionValue::Tls12)
}
}
}
fn common(&self) -> &ClientSessionCommon {
match self {
Self::Tls13(inner) => &inner.common,
#[cfg(feature = "tls12")]
Self::Tls12(inner) => &inner.common,
}
}
}
impl From<Tls13ClientSessionValue> for ClientSessionValue {
fn from(v: Tls13ClientSessionValue) -> Self {
Self::Tls13(v)
}
}
#[cfg(feature = "tls12")]
impl From<Tls12ClientSessionValue> for ClientSessionValue {
fn from(v: Tls12ClientSessionValue) -> Self {
Self::Tls12(v)
}
}
pub struct Retrieved<T> {
pub value: T,
retrieved_at: TimeBase,
}
impl<T> Retrieved<T> {
pub fn new(value: T, retrieved_at: TimeBase) -> Self {
Self {
value,
retrieved_at,
}
}
}
impl Retrieved<&Tls13ClientSessionValue> {
pub fn obfuscated_ticket_age(&self) -> u32 {
let age_secs = self
.retrieved_at
.as_secs()
.saturating_sub(self.value.common.epoch);
let age_millis = age_secs as u32 * 1000;
age_millis.wrapping_add(self.value.age_add)
}
}
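The obfuscated ticket age computed above follows RFC 8446: the ticket age in milliseconds plus the server-provided age_add, wrapping modulo 2^32. A small standalone example that mirrors the arithmetic rather than calling the type above:

/// Mirrors the computation above: seconds since retrieval, times 1000, plus age_add.
fn obfuscated_ticket_age(age_secs: u64, age_add: u32) -> u32 {
    let age_millis = (age_secs as u32).wrapping_mul(1000);
    age_millis.wrapping_add(age_add)
}

#[test]
fn obfuscated_age_wraps_modulo_2_pow_32() {
    // A ticket retrieved 10 s ago with age_add close to u32::MAX wraps around.
    assert_eq!(obfuscated_ticket_age(10, u32::MAX - 5_000), 4_999);
}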
impl Retrieved<ClientSessionValue> {
pub fn tls13(&self) -> Option<Retrieved<&Tls13ClientSessionValue>> {
match &self.value {
ClientSessionValue::Tls13(value) => Some(Retrieved::new(value, self.retrieved_at)),
#[cfg(feature = "tls12")]
ClientSessionValue::Tls12(_) => None,
}
}
pub fn has_expired(&self) -> bool {
let common = self.value.common();
common.lifetime_secs != 0
&& common.epoch + u64::from(common.lifetime_secs) < self.retrieved_at.as_secs()
}
}
impl<T> std::ops::Deref for Retrieved<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
#[derive(Debug)]
pub struct Tls13ClientSessionValue {
suite: &'static Tls13CipherSuite,
age_add: u32,
max_early_data_size: u32,
pub common: ClientSessionCommon,
}
impl Tls13ClientSessionValue {
pub fn new(
suite: &'static Tls13CipherSuite,
ticket: Vec<u8>,
secret: Vec<u8>,
server_cert_chain: Vec<tls_core::key::Certificate>,
time_now: TimeBase,
lifetime_secs: u32,
age_add: u32,
max_early_data_size: u32,
) -> Self {
Self {
suite,
age_add,
max_early_data_size,
common: ClientSessionCommon::new(
ticket,
secret,
time_now,
lifetime_secs,
server_cert_chain,
),
}
}
/// [`Codec::read()`] with an extra `suite` argument.
///
/// We decode the `suite` argument separately because it allows us to
/// decide whether we're decoding a 1.2 or 1.3 session value.
pub fn read(suite: &'static Tls13CipherSuite, r: &mut Reader) -> Option<Self> {
Some(Self {
suite,
age_add: u32::read(r)?,
max_early_data_size: u32::read(r)?,
common: ClientSessionCommon::read(r)?,
})
}
/// Inherent implementation of the [`Codec::get_encoding()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn get_encoding(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(16);
self.suite.common.suite.encode(&mut bytes);
self.age_add.encode(&mut bytes);
self.max_early_data_size.encode(&mut bytes);
self.common.encode(&mut bytes);
bytes
}
pub fn max_early_data_size(&self) -> u32 {
self.max_early_data_size
}
pub fn suite(&self) -> &'static Tls13CipherSuite {
self.suite
}
}
impl std::ops::Deref for Tls13ClientSessionValue {
type Target = ClientSessionCommon;
fn deref(&self) -> &Self::Target {
&self.common
}
}
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub struct Tls12ClientSessionValue {
suite: &'static Tls12CipherSuite,
pub session_id: SessionID,
extended_ms: bool,
pub common: ClientSessionCommon,
}
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
pub fn new(
suite: &'static Tls12CipherSuite,
session_id: SessionID,
ticket: Vec<u8>,
master_secret: Vec<u8>,
server_cert_chain: Vec<tls_core::key::Certificate>,
time_now: TimeBase,
lifetime_secs: u32,
extended_ms: bool,
) -> Self {
Self {
suite,
session_id,
extended_ms,
common: ClientSessionCommon::new(
ticket,
master_secret,
time_now,
lifetime_secs,
server_cert_chain,
),
}
}
/// [`Codec::read()`] with an extra `suite` argument.
///
/// We decode the `suite` argument separately because it allows us to
/// decide whether we're decoding a 1.2 or 1.3 session value.
fn read(suite: &'static Tls12CipherSuite, r: &mut Reader) -> Option<Self> {
Some(Self {
suite,
session_id: SessionID::read(r)?,
extended_ms: u8::read(r)? == 1,
common: ClientSessionCommon::read(r)?,
})
}
/// Inherent implementation of the [`Codec::get_encoding()`] method.
///
/// (See `read()` for why this is inherent here.)
pub fn get_encoding(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(16);
self.suite.common.suite.encode(&mut bytes);
self.session_id.encode(&mut bytes);
(if self.extended_ms { 1u8 } else { 0u8 }).encode(&mut bytes);
self.common.encode(&mut bytes);
bytes
}
pub fn take_ticket(&mut self) -> Vec<u8> {
mem::take(&mut self.common.ticket.0)
}
pub fn extended_ms(&self) -> bool {
self.extended_ms
}
pub fn suite(&self) -> &'static Tls12CipherSuite {
self.suite
}
}
#[cfg(feature = "tls12")]
impl std::ops::Deref for Tls12ClientSessionValue {
type Target = ClientSessionCommon;
fn deref(&self) -> &Self::Target {
&self.common
}
}
#[derive(Debug)]
pub struct ClientSessionCommon {
ticket: PayloadU16,
secret: PayloadU8,
epoch: u64,
lifetime_secs: u32,
server_cert_chain: CertificatePayload,
}
impl ClientSessionCommon {
fn new(
ticket: Vec<u8>,
secret: Vec<u8>,
time_now: TimeBase,
lifetime_secs: u32,
server_cert_chain: Vec<tls_core::key::Certificate>,
) -> Self {
Self {
ticket: PayloadU16(ticket),
secret: PayloadU8(secret),
epoch: time_now.as_secs(),
lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
server_cert_chain,
}
}
/// [`Codec::read()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn read(r: &mut Reader) -> Option<Self> {
Some(Self {
ticket: PayloadU16::read(r)?,
secret: PayloadU8::read(r)?,
epoch: u64::read(r)?,
lifetime_secs: u32::read(r)?,
server_cert_chain: CertificatePayload::read(r)?,
})
}
/// [`Codec::encode()`] is inherent here to avoid leaking the [`Codec`]
/// implementation through [`Deref`] implementations on
/// [`Tls12ClientSessionValue`] and [`Tls13ClientSessionValue`].
fn encode(&self, bytes: &mut Vec<u8>) {
self.ticket.encode(bytes);
self.secret.encode(bytes);
self.epoch.encode(bytes);
self.lifetime_secs.encode(bytes);
self.server_cert_chain.encode(bytes);
}
pub fn server_cert_chain(&self) -> &[tls_core::key::Certificate] {
self.server_cert_chain.as_ref()
}
pub fn secret(&self) -> &[u8] {
self.secret.0.as_ref()
}
pub fn ticket(&self) -> &[u8] {
self.ticket.0.as_ref()
}
/// Test only: wind back epoch by delta seconds.
pub fn rewind_epoch(&mut self, delta: u32) {
self.epoch -= delta as u64;
}
}
static MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;
/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period. This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
static MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
// --- Server types ---
pub type ServerSessionKey = SessionID;
#[derive(Debug)]
pub struct ServerSessionValue {
pub sni: Option<webpki::DnsName>,
pub version: ProtocolVersion,
pub cipher_suite: CipherSuite,
pub master_secret: PayloadU8,
pub extended_ms: bool,
pub client_cert_chain: Option<CertificatePayload>,
pub alpn: Option<PayloadU8>,
pub application_data: PayloadU16,
pub creation_time_sec: u64,
pub age_obfuscation_offset: u32,
freshness: Option<bool>,
}
impl Codec for ServerSessionValue {
fn encode(&self, bytes: &mut Vec<u8>) {
if let Some(ref sni) = self.sni {
1u8.encode(bytes);
let sni_bytes: &str = sni.as_ref().into();
PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
} else {
0u8.encode(bytes);
}
self.version.encode(bytes);
self.cipher_suite.encode(bytes);
self.master_secret.encode(bytes);
(if self.extended_ms { 1u8 } else { 0u8 }).encode(bytes);
if let Some(ref chain) = self.client_cert_chain {
1u8.encode(bytes);
chain.encode(bytes);
} else {
0u8.encode(bytes);
}
if let Some(ref alpn) = self.alpn {
1u8.encode(bytes);
alpn.encode(bytes);
} else {
0u8.encode(bytes);
}
self.application_data.encode(bytes);
self.creation_time_sec.encode(bytes);
self.age_obfuscation_offset.encode(bytes);
}
fn read(r: &mut Reader) -> Option<Self> {
let has_sni = u8::read(r)?;
let sni = if has_sni == 1 {
let dns_name = PayloadU8::read(r)?;
let dns_name = webpki::DnsNameRef::try_from_ascii(&dns_name.0).ok()?;
Some(dns_name.into())
} else {
None
};
let v = ProtocolVersion::read(r)?;
let cs = CipherSuite::read(r)?;
let ms = PayloadU8::read(r)?;
let ems = u8::read(r)?;
let has_ccert = u8::read(r)? == 1;
let ccert = if has_ccert {
Some(CertificatePayload::read(r)?)
} else {
None
};
let has_alpn = u8::read(r)? == 1;
let alpn = if has_alpn {
Some(PayloadU8::read(r)?)
} else {
None
};
let application_data = PayloadU16::read(r)?;
let creation_time_sec = u64::read(r)?;
let age_obfuscation_offset = u32::read(r)?;
Some(Self {
sni,
version: v,
cipher_suite: cs,
master_secret: ms,
extended_ms: ems == 1u8,
client_cert_chain: ccert,
alpn,
application_data,
creation_time_sec,
age_obfuscation_offset,
freshness: None,
})
}
}
impl ServerSessionValue {
pub fn new(
sni: Option<&webpki::DnsName>,
v: ProtocolVersion,
cs: CipherSuite,
ms: Vec<u8>,
client_cert_chain: Option<CertificatePayload>,
alpn: Option<Vec<u8>>,
application_data: Vec<u8>,
creation_time: TimeBase,
age_obfuscation_offset: u32,
) -> Self {
Self {
sni: sni.cloned(),
version: v,
cipher_suite: cs,
master_secret: PayloadU8::new(ms),
extended_ms: false,
client_cert_chain,
alpn: alpn.map(PayloadU8::new),
application_data: PayloadU16::new(application_data),
creation_time_sec: creation_time.as_secs(),
age_obfuscation_offset,
freshness: None,
}
}
pub fn set_extended_ms_used(&mut self) {
self.extended_ms = true;
}
pub fn set_freshness(mut self, obfuscated_client_age_ms: u32, time_now: TimeBase) -> Self {
let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
let server_age_ms =
(time_now.as_secs().saturating_sub(self.creation_time_sec) as u32).saturating_mul(1000);
let age_difference = if client_age_ms < server_age_ms {
server_age_ms - client_age_ms
} else {
client_age_ms - server_age_ms
};
self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
self
}
pub fn is_fresh(&self) -> bool {
self.freshness.unwrap_or_default()
}
}
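set_freshness above marks the resumption as fresh when the client-reported age (de-obfuscated) and the server-side age differ by at most MAX_FRESHNESS_SKEW_MS. A tiny standalone check of that comparison (values are hypothetical):

/// Returns true when the two age measurements agree within the allowed skew.
fn is_fresh(client_age_ms: u32, server_age_ms: u32, max_skew_ms: u32) -> bool {
    client_age_ms.abs_diff(server_age_ms) <= max_skew_ms
}

#[test]
fn small_clock_skew_counts_as_fresh() {
    // 500 ms of disagreement is well under the 60 000 ms allowance.
    assert!(is_fresh(29_500, 30_000, 60_000));
    // Two minutes of disagreement is not.
    assert!(!is_fresh(0, 120_000, 60_000));
}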

View File

@@ -1,78 +0,0 @@
use super::persist::*;
use crate::ticketer::TimeBase;
use std::convert::TryInto;
use tls_core::{
key::Certificate,
msgs::{
codec::{Codec, Reader},
enums::*,
},
suites::TLS13_AES_128_GCM_SHA256,
};
#[test]
fn clientsessionkey_is_debug() {
let name = "hello".try_into().unwrap();
let csk = ClientSessionKey::session_for_server_name(&name);
println!("{:?}", csk);
}
#[test]
fn clientsessionkey_cannot_be_read() {
let bytes = [0; 1];
let mut rd = Reader::init(&bytes);
assert!(ClientSessionKey::read(&mut rd).is_none());
}
#[test]
fn clientsessionvalue_is_debug() {
let csv = ClientSessionValue::from(Tls13ClientSessionValue::new(
TLS13_AES_128_GCM_SHA256.tls13().unwrap(),
vec![],
vec![1, 2, 3],
vec![Certificate(b"abc".to_vec()), Certificate(b"def".to_vec())],
TimeBase::now().unwrap(),
15,
10,
128,
));
println!("{:?}", csv);
}
#[test]
fn serversessionvalue_is_debug() {
let ssv = ServerSessionValue::new(
None,
ProtocolVersion::TLSv1_3,
CipherSuite::TLS13_AES_128_GCM_SHA256,
vec![1, 2, 3],
None,
None,
vec![4, 5, 6],
TimeBase::now().unwrap(),
0x12345678,
);
println!("{:?}", ssv);
}
#[test]
fn serversessionvalue_no_sni() {
let bytes = [
0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
];
let mut rd = Reader::init(&bytes);
let ssv = ServerSessionValue::read(&mut rd).unwrap();
assert_eq!(ssv.get_encoding(), bytes);
}
#[test]
fn serversessionvalue_with_cert() {
let bytes = [
0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
];
let mut rd = Reader::init(&bytes);
let ssv = ServerSessionValue::read(&mut rd).unwrap();
assert_eq!(ssv.get_encoding(), bytes);
}

View File

@@ -6,6 +6,7 @@ use ring::{
io::der,
signature::{self, EcdsaKeyPair, Ed25519KeyPair, RsaKeyPair},
};
use rustls_pki_types as pki_types;
use std::{convert::TryFrom, error::Error as StdError, fmt, sync::Arc};
use tls_core::{
key,
@@ -71,51 +72,6 @@ impl CertifiedKey {
pub fn end_entity_cert(&self) -> Result<&key::Certificate, SignError> {
self.cert.first().ok_or(SignError(()))
}
/// Check the certificate chain for validity:
/// - it should be a non-empty list
/// - the first certificate should be parsable as an x509v3,
/// - the first certificate should quote the given server name
/// (if provided)
///
/// These checks are not security-sensitive. They are the
/// *server* attempting to detect accidental misconfiguration.
pub(crate) fn cross_check_end_entity_cert(
&self,
name: Option<webpki::DnsNameRef>,
) -> Result<(), Error> {
// Always reject an empty certificate chain.
let end_entity_cert = self.end_entity_cert().map_err(|SignError(())| {
Error::General("No end-entity certificate in certificate chain".to_string())
})?;
// Reject syntactically-invalid end-entity certificates.
let end_entity_cert =
webpki::EndEntityCert::try_from(end_entity_cert.as_ref()).map_err(|_| {
Error::General(
"End-entity certificate in certificate \
chain is syntactically invalid"
.to_string(),
)
})?;
if let Some(name) = name {
// If SNI was offered then the certificate must be valid for
// that hostname. Note that this doesn't fully validate that the
// certificate is valid; it only validates that the name is one
// that the certificate is valid for, if the certificate is
// valid.
if end_entity_cert.verify_is_valid_for_dns_name(name).is_err() {
return Err(Error::General(
"The server certificate is not \
valid for the given name"
.to_string(),
));
}
}
Ok(())
}
}
/// Parse `der` as any supported key encoding/type, returning

View File

@@ -1,223 +0,0 @@
// This program does benchmarking of the functions in verify.rs,
// that do certificate chain validation and signature verification.
//
// Note: we don't use any of the standard 'cargo bench', 'test::Bencher',
// etc. because it's unstable at the time of writing.
use crate::{anchors, verify, verify::ServerCertVerifier, OwnedTrustAnchor};
use std::convert::TryInto;
use web_time::{Duration, Instant, SystemTime};
use webpki_roots;
fn duration_nanos(d: Duration) -> u64 {
((d.as_secs() as f64) * 1e9 + (d.subsec_nanos() as f64)) as u64
}
#[test]
fn test_reddit_cert() {
Context::new(
"reddit",
"reddit.com",
&[
include_bytes!("testdata/cert-reddit.0.der"),
include_bytes!("testdata/cert-reddit.1.der"),
],
)
.bench(100)
}
#[test]
fn test_github_cert() {
Context::new(
"github",
"github.com",
&[
include_bytes!("testdata/cert-github.0.der"),
include_bytes!("testdata/cert-github.1.der"),
],
)
.bench(100)
}
#[test]
fn test_arstechnica_cert() {
Context::new(
"arstechnica",
"arstechnica.com",
&[
include_bytes!("testdata/cert-arstechnica.0.der"),
include_bytes!("testdata/cert-arstechnica.1.der"),
include_bytes!("testdata/cert-arstechnica.2.der"),
include_bytes!("testdata/cert-arstechnica.3.der"),
],
)
.bench(100)
}
#[test]
fn test_twitter_cert() {
Context::new(
"twitter",
"twitter.com",
&[
include_bytes!("testdata/cert-twitter.0.der"),
include_bytes!("testdata/cert-twitter.1.der"),
],
)
.bench(100)
}
#[test]
fn test_wikipedia_cert() {
Context::new(
"wikipedia",
"wikipedia.org",
&[
include_bytes!("testdata/cert-wikipedia.0.der"),
include_bytes!("testdata/cert-wikipedia.1.der"),
],
)
.bench(100)
}
#[test]
fn test_google_cert() {
Context::new(
"google",
"www.google.com",
&[
include_bytes!("testdata/cert-google.0.der"),
include_bytes!("testdata/cert-google.1.der"),
],
)
.bench(100)
}
#[test]
fn test_hn_cert() {
Context::new(
"hn",
"news.ycombinator.com",
&[
include_bytes!("testdata/cert-hn.0.der"),
include_bytes!("testdata/cert-hn.1.der"),
],
)
.bench(100)
}
#[test]
fn test_stackoverflow_cert() {
Context::new(
"stackoverflow",
"stackoverflow.com",
&[
include_bytes!("testdata/cert-stackoverflow.0.der"),
include_bytes!("testdata/cert-stackoverflow.1.der"),
],
)
.bench(100)
}
#[test]
fn test_duckduckgo_cert() {
Context::new(
"duckduckgo",
"duckduckgo.com",
&[
include_bytes!("testdata/cert-duckduckgo.0.der"),
include_bytes!("testdata/cert-duckduckgo.1.der"),
],
)
.bench(100)
}
#[test]
fn test_rustlang_cert() {
Context::new(
"rustlang",
"www.rust-lang.org",
&[
include_bytes!("testdata/cert-rustlang.0.der"),
include_bytes!("testdata/cert-rustlang.1.der"),
include_bytes!("testdata/cert-rustlang.2.der"),
],
)
.bench(100)
}
#[test]
fn test_wapo_cert() {
Context::new(
"wapo",
"www.washingtonpost.com",
&[
include_bytes!("testdata/cert-wapo.0.der"),
include_bytes!("testdata/cert-wapo.1.der"),
],
)
.bench(100)
}
struct Context {
name: &'static str,
domain: &'static str,
roots: anchors::RootCertStore,
chain: Vec<tls_core::key::Certificate>,
now: SystemTime,
}
impl Context {
fn new(name: &'static str, domain: &'static str, certs: &[&'static [u8]]) -> Self {
let mut roots = anchors::RootCertStore::empty();
roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
Self {
name,
domain,
roots,
chain: certs
.iter()
.copied()
.map(|bytes| tls_core::key::Certificate(bytes.to_vec()))
.collect(),
now: SystemTime::UNIX_EPOCH + Duration::from_secs(1640870720),
}
}
fn bench(&self, count: usize) {
let verifier = verify::WebPkiVerifier::new(self.roots.clone(), None);
const SCTS: &[&[u8]] = &[];
const OCSP_RESPONSE: &[u8] = &[];
let mut times = Vec::new();
let (end_entity, intermediates) = self.chain.split_first().unwrap();
for _ in 0..count {
let start = Instant::now();
let server_name = self.domain.try_into().unwrap();
verifier
.verify_server_cert(
end_entity,
intermediates,
&server_name,
&mut SCTS.iter().copied(),
OCSP_RESPONSE,
self.now,
)
.unwrap();
times.push(duration_nanos(Instant::now().duration_since(start)));
}
println!(
"verify_server_cert({}): min {:?}us",
self.name,
times.iter().min().unwrap() / 1000
);
}
}

View File

@@ -780,14 +780,12 @@ async fn client_checks_server_certificate_with_given_name() {
let mut server = ServerConnection::new(Arc::clone(&server_config)).unwrap();
let err = do_handshake_until_error(&mut client, &mut server).await;
assert_eq!(
assert!(matches!(
err,
Err(ErrorFromPeer::Client(Error::CoreError(
tls_core::Error::InvalidCertificateData(
"invalid peer certificate: CertNotValidForName".into(),
)
tls_core::Error::InvalidCertificateData(_)
)))
);
));
}
}
}
@@ -889,6 +887,7 @@ async fn client_error_is_sticky() {
#[tokio::test]
#[allow(clippy::no_effect)]
#[allow(clippy::unnecessary_operation)]
async fn client_is_send() {
let (client, _) = make_pair(KeyType::Rsa).await;
&client as &dyn Send;

View File

@@ -2,6 +2,7 @@
use futures::{AsyncRead, AsyncWrite};
use rustls::{server::AllowAnyAuthenticatedClient, ServerConfig, ServerConnection};
use rustls_pki_types::CertificateDer;
use std::{
convert::{TryFrom, TryInto},
io,
@@ -15,6 +16,7 @@ use tls_client::{
Certificate, ClientConfig, ClientConnection, Error, PrivateKey, RootCertStore,
RustCryptoBackend,
};
use webpki::anchor_from_trusted_cert;
macro_rules! embed_files {
(
@@ -409,9 +411,17 @@ pub fn finish_client_config(
kt: KeyType,
config: tls_client::ConfigBuilder<tls_client::WantsVerifier>,
) -> ClientConfig {
let mut root_store = RootCertStore::empty();
let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap());
let roots = rustls_pemfile::certs(&mut rootbuf)
.unwrap()
.into_iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert);
anchor_from_trusted_cert(&der).unwrap().to_owned()
})
.collect();
let root_store = RootCertStore { roots };
config
.with_root_certificates(root_store)
@@ -422,9 +432,17 @@ pub fn finish_client_config_with_creds(
kt: KeyType,
config: tls_client::ConfigBuilder<tls_client::WantsVerifier>,
) -> ClientConfig {
let mut root_store = RootCertStore::empty();
let mut rootbuf = io::BufReader::new(kt.bytes_for("ca.cert"));
root_store.add_parsable_certificates(&rustls_pemfile::certs(&mut rootbuf).unwrap());
let roots = rustls_pemfile::certs(&mut rootbuf)
.unwrap()
.into_iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert);
anchor_from_trusted_cert(&der).unwrap().to_owned()
})
.collect();
let root_store = RootCertStore { roots };
config
.with_root_certificates(root_store)

View File
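Both helpers above now build the root store from owned rustls_pki_types::TrustAnchor values rather than OwnedTrustAnchor. A condensed sketch of just the conversion step (the helper name is hypothetical; the calls mirror those in the diff):

use rustls_pki_types::{CertificateDer, TrustAnchor};
use webpki::anchor_from_trusted_cert;

/// Converts DER-encoded CA certificates into owned trust anchors.
fn anchors_from_der(ders: &[Vec<u8>]) -> Result<Vec<TrustAnchor<'static>>, webpki::Error> {
    ders.iter()
        .map(|der| {
            let der = CertificateDer::from_slice(der);
            anchor_from_trusted_cert(&der).map(|anchor| anchor.to_owned())
        })
        .collect()
}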

@@ -35,4 +35,5 @@ sha2 = { workspace = true, optional = true }
thiserror = { workspace = true }
tracing = { workspace = true, optional = true }
web-time = { workspace = true }
webpki = { workspace = true, features = ["alloc", "std"] }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }

View File

@@ -1,65 +1,16 @@
use rustls_pki_types::TrustAnchor;
use crate::{
msgs::handshake::{DistinguishedName, DistinguishedNames},
x509,
};
/// A trust anchor, commonly known as a "Root Certificate."
#[derive(Debug, Clone)]
pub struct OwnedTrustAnchor {
subject: Vec<u8>,
spki: Vec<u8>,
name_constraints: Option<Vec<u8>>,
}
impl OwnedTrustAnchor {
/// Get a `webpki::TrustAnchor` by borrowing the owned elements.
pub(crate) fn to_trust_anchor(&self) -> webpki::TrustAnchor {
webpki::TrustAnchor {
subject: &self.subject,
spki: &self.spki,
name_constraints: self.name_constraints.as_deref(),
}
}
/// Constructs an `OwnedTrustAnchor` from its components.
///
/// `subject` is the subject field of the trust anchor.
///
/// `spki` is the `subjectPublicKeyInfo` field of the trust anchor.
///
/// `name_constraints` is the value of a DER-encoded name constraints to
/// apply for this trust anchor, if any.
pub fn from_subject_spki_name_constraints(
subject: impl Into<Vec<u8>>,
spki: impl Into<Vec<u8>>,
name_constraints: Option<impl Into<Vec<u8>>>,
) -> Self {
Self {
subject: subject.into(),
spki: spki.into(),
name_constraints: name_constraints.map(|x| x.into()),
}
}
}
/// Errors that can occur during operations with RootCertStore
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum RootCertStoreError {
#[error(transparent)]
WebpkiError(#[from] webpki::Error),
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("Unexpected PEM certificate count. Expected 1 certificate, got {0}")]
PemCertUnexpectedCount(usize),
}
/// A container for root certificates able to provide a root-of-trust
/// for connection authentication.
#[derive(Debug, Clone)]
pub struct RootCertStore {
/// The list of roots.
pub roots: Vec<OwnedTrustAnchor>,
pub roots: Vec<TrustAnchor<'static>>,
}
impl RootCertStore {
@@ -91,100 +42,4 @@ impl RootCertStore {
r
}
/// Add a single DER-encoded certificate to the store.
pub fn add(&mut self, der: &crate::key::Certificate) -> Result<(), RootCertStoreError> {
let ta = webpki::TrustAnchor::try_from_cert_der(&der.0)?;
let ota = OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
);
self.roots.push(ota);
Ok(())
}
/// Adds a single PEM-encoded certificate to the store.
pub fn add_pem(&mut self, pem: &str) -> Result<(), RootCertStoreError> {
let mut certificates = rustls_pemfile::certs(&mut pem.as_bytes())?;
if certificates.len() != 1 {
return Err(RootCertStoreError::PemCertUnexpectedCount(
certificates.len(),
));
}
self.add(&crate::key::Certificate(certificates.remove(0)))?;
Ok(())
}
/// Adds all the given TrustAnchors `anchors`. This does not
/// fail.
pub fn add_server_trust_anchors(
&mut self,
trust_anchors: impl Iterator<Item = OwnedTrustAnchor>,
) {
self.roots.extend(trust_anchors)
}
/// Parse the given DER-encoded certificates and add all that can be parsed
/// in a best-effort fashion.
///
/// This is because large collections of root certificates often
/// include ancient or syntactically invalid certificates.
///
/// Returns the number of certificates added, and the number that were ignored.
pub fn add_parsable_certificates(&mut self, der_certs: &[Vec<u8>]) -> (usize, usize) {
let mut valid_count = 0;
let mut invalid_count = 0;
for der_cert in der_certs {
match self.add(&crate::key::Certificate(der_cert.clone())) {
Ok(_) => valid_count += 1,
Err(_err) => invalid_count += 1,
}
}
(valid_count, invalid_count)
}
}
#[cfg(test)]
mod tests {
use super::*;
const CA_PEM_CERT: &[u8] = include_bytes!("../testdata/cert-digicert.pem");
#[test]
fn test_add_pem_ok() {
let pem = std::str::from_utf8(CA_PEM_CERT).unwrap();
assert!(RootCertStore::empty().add_pem(pem).is_ok());
}
#[test]
fn test_add_pem_err_bad_cert() {
assert_eq!(
RootCertStore::empty()
.add_pem("bad pem")
.err()
.unwrap()
.to_string(),
"Unexpected PEM certificate count. Expected 1 certificate, got 0"
);
}
#[test]
fn test_add_pem_err_more_than_one_cert() {
let pem1 = std::str::from_utf8(CA_PEM_CERT).unwrap();
let pem2 = pem1;
assert_eq!(
RootCertStore::empty()
.add_pem((pem1.to_owned() + pem2).as_str())
.err()
.unwrap()
.to_string(),
"Unexpected PEM certificate count. Expected 1 certificate, got 2"
);
}
}

View File

@@ -1,6 +1,6 @@
use std::{error::Error as StdError, fmt};
use crate::verify;
use rustls_pki_types as pki_types;
/// Encodes ways a client can know the expected name of the server.
///
@@ -26,20 +26,16 @@ use crate::verify;
/// ```
#[non_exhaustive]
#[derive(Debug, PartialEq, Clone)]
pub enum ServerName {
/// The server is identified by a DNS name. The name
/// is sent in the TLS Server Name Indication (SNI)
/// extension.
DnsName(verify::DnsName),
}
pub struct ServerName(pub(crate) pki_types::ServerName<'static>);
impl ServerName {
/// Return the name that should go in the SNI extension.
/// If [`None`] is returned, the SNI extension is not included
/// in the handshake.
pub fn for_sni(&self) -> Option<webpki::DnsNameRef> {
match self {
Self::DnsName(dns_name) => Some(dns_name.0.as_ref()),
pub fn for_sni(&self) -> Option<pki_types::DnsName<'static>> {
match &self.0 {
pki_types::ServerName::DnsName(dns_name) => Some(dns_name.clone()),
_ => None,
}
}
@@ -49,13 +45,17 @@ impl ServerName {
DnsName = 0x01,
}
let Self::DnsName(dns_name) = self;
let bytes = dns_name.0.as_ref();
let bytes: &[u8] = match &self.0 {
pki_types::ServerName::DnsName(dns_name) => dns_name.as_ref().as_ref(),
pki_types::ServerName::IpAddress(pki_types::IpAddr::V4(ip)) => ip.as_ref(),
pki_types::ServerName::IpAddress(pki_types::IpAddr::V6(ip)) => ip.as_ref(),
_ => unreachable!(),
};
let mut r = Vec::with_capacity(2 + bytes.as_ref().len());
let mut r = Vec::with_capacity(2 + bytes.len());
r.push(UniqueTypeCode::DnsName as u8);
r.push(bytes.as_ref().len() as u8);
r.extend_from_slice(bytes.as_ref());
r.push(bytes.len() as u8);
r.extend_from_slice(bytes);
r
}
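The private encoding above emits a one-byte type code (0x01 for a DNS name), a one-byte length, then the raw name bytes. A standalone illustration of the expected output (the function name is hypothetical):

/// Reproduces the layout above for a DNS-name server name.
fn encode_dns_server_name(name: &str) -> Vec<u8> {
    let bytes = name.as_bytes();
    let mut out = Vec::with_capacity(2 + bytes.len());
    out.push(0x01); // UniqueTypeCode::DnsName
    out.push(bytes.len() as u8);
    out.extend_from_slice(bytes);
    out
}

#[test]
fn encodes_short_host_name() {
    assert_eq!(encode_dns_server_name("hi"), vec![0x01, 0x02, b'h', b'i']);
}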
@@ -66,9 +66,9 @@ impl ServerName {
impl TryFrom<&str> for ServerName {
type Error = InvalidDnsNameError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match webpki::DnsNameRef::try_from_ascii_str(s) {
Ok(dns) => Ok(Self::DnsName(verify::DnsName(dns.into()))),
Err(webpki::InvalidDnsNameError) => Err(InvalidDnsNameError),
match pki_types::DnsName::try_from(s) {
Ok(dns) => Ok(Self(pki_types::ServerName::DnsName(dns.to_owned()))),
Err(_) => Err(InvalidDnsNameError),
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::msgs::enums::{AlertDescription, ContentType, HandshakeType};
use std::{error::Error as StdError, fmt, time::SystemTimeError};
use std::{error::Error as StdError, fmt};
use web_time::SystemTimeError;
/// rustls reports protocol errors using this type.
#[derive(Debug, PartialEq, Clone)]
@@ -41,8 +42,9 @@ pub enum Error {
/// We couldn't decrypt a message. This is invariably fatal.
DecryptError,
/// We couldn't encrypt a message because it was larger than the allowed message size.
/// This should never happen if the application is using valid record sizes.
/// We couldn't encrypt a message because it was larger than the allowed
/// message size. This should never happen if the application is using
/// valid record sizes.
EncryptError,
/// The peer doesn't support a protocol version/feature we require.

View File

@@ -7,7 +7,7 @@ pub struct Reader<'a> {
}
impl Reader<'_> {
pub fn init(bytes: &[u8]) -> Reader {
pub fn init(bytes: &[u8]) -> Reader<'_> {
Reader {
buf: bytes,
offs: 0,
@@ -42,7 +42,7 @@ impl Reader<'_> {
self.offs
}
pub fn sub(&mut self, len: usize) -> Option<Reader> {
pub fn sub(&mut self, len: usize) -> Option<Reader<'_>> {
self.take(len).map(Reader::init)
}
}

View File

@@ -1,3 +1,5 @@
use rustls_pki_types as pki_types;
use crate::{
key,
msgs::{
@@ -249,14 +251,14 @@ impl DecomposedSignatureScheme for SignatureScheme {
#[derive(Clone, Debug)]
pub enum ServerNamePayload {
// Stored twice, bytes so we can round-trip, and DnsName for use
HostName((PayloadU16, webpki::DnsName)),
HostName((PayloadU16, pki_types::DnsName<'static>)),
Unknown(Payload),
}
impl ServerNamePayload {
pub fn new_hostname(hostname: webpki::DnsName) -> Self {
pub fn new_hostname(hostname: pki_types::DnsName<'static>) -> Self {
let raw = {
let s: &str = hostname.as_ref().into();
let s: &str = hostname.as_ref();
PayloadU16::new(s.as_bytes().into())
};
Self::HostName((raw, hostname))
@@ -266,8 +268,8 @@ impl ServerNamePayload {
let raw = PayloadU16::read(r)?;
let dns_name = {
match webpki::DnsNameRef::try_from_ascii(&raw.0) {
Ok(dns_name) => dns_name.into(),
match pki_types::DnsName::try_from(raw.0.as_slice()) {
Ok(dns_name) => dns_name.to_owned(),
Err(_) => {
warn!("Illegal SNI hostname received {:?}", raw.0);
return None;
@@ -313,11 +315,12 @@ declare_u16_vec!(ServerNameRequest, ServerName);
pub trait ConvertServerNameList {
fn has_duplicate_names_for_type(&self) -> bool;
fn get_single_hostname(&self) -> Option<webpki::DnsNameRef>;
fn get_single_hostname(&self) -> Option<pki_types::DnsName<'static>>;
}
impl ConvertServerNameList for ServerNameRequest {
/// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
/// RFC6066: "The ServerNameList MUST NOT contain more than one name of the
/// same name_type."
fn has_duplicate_names_for_type(&self) -> bool {
let mut seen = collections::HashSet::new();
@@ -330,10 +333,10 @@ impl ConvertServerNameList for ServerNameRequest {
false
}
fn get_single_hostname(&self) -> Option<webpki::DnsNameRef> {
fn only_dns_hostnames(name: &ServerName) -> Option<webpki::DnsNameRef> {
fn get_single_hostname(&self) -> Option<pki_types::DnsName<'static>> {
fn only_dns_hostnames(name: &ServerName) -> Option<pki_types::DnsName<'static>> {
if let ServerNamePayload::HostName((_, ref dns)) = name.payload {
Some(dns.as_ref())
Some(dns.clone())
} else {
None
}
@@ -694,16 +697,16 @@ impl Codec for ClientExtension {
}
}
fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::DnsName {
let dns_name_str: &str = dns_name.into();
fn trim_hostname_trailing_dot_for_sni(
dns_name: pki_types::DnsName<'static>,
) -> pki_types::DnsName<'static> {
let dns_name_str: &str = dns_name.as_ref();
// RFC6066: "The hostname is represented as a byte string using
// ASCII encoding without a trailing dot"
if dns_name_str.ends_with('.') {
let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
webpki::DnsNameRef::try_from_ascii_str(trimmed)
.unwrap()
.to_owned()
pki_types::DnsName::try_from(trimmed).unwrap().to_owned()
} else {
dns_name.to_owned()
}
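A short usage sketch of the trimming above, assuming (as the calling code implies) that DnsName::try_from accepts a trailing-dot name:

#[test]
fn trailing_dot_is_dropped_for_sni() {
    let name = rustls_pki_types::DnsName::try_from("example.com.")
        .expect("trailing-dot names are accepted")
        .to_owned();
    let trimmed = trim_hostname_trailing_dot_for_sni(name);
    assert_eq!(trimmed.as_ref(), "example.com");
}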
@@ -711,7 +714,7 @@ fn trim_hostname_trailing_dot_for_sni(dns_name: webpki::DnsNameRef) -> webpki::D
impl ClientExtension {
/// Make a basic SNI ServerNameRequest quoting `hostname`.
pub fn make_sni(dns_name: webpki::DnsNameRef) -> Self {
pub fn make_sni(dns_name: pki_types::DnsName<'static>) -> Self {
let name = ServerName {
typ: ServerNameType::HostName,
payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),

View File

@@ -1,3 +1,5 @@
use rustls_pki_types as pki_types;
use super::{
base::{Payload, PayloadU16, PayloadU24, PayloadU8},
codec::{put_u16, Codec, Reader},
@@ -5,7 +7,6 @@ use super::{
handshake::*,
};
use crate::key::Certificate;
use webpki::DnsNameRef;
#[test]
fn rejects_short_random() {
@@ -186,7 +187,8 @@ fn can_roundtrip_multiname_sni() {
assert!(req.has_duplicate_names_for_type());
let dns_name_str: &str = req.get_single_hostname().unwrap().into();
let dns_name = req.get_single_hostname().unwrap();
let dns_name_str: &str = dns_name.as_ref();
assert_eq!(dns_name_str, "hi");
assert_eq!(req[0].typ, ServerNameType::HostName);
@@ -363,7 +365,7 @@ fn get_sample_clienthellopayload() -> ClientHelloPayload {
ClientExtension::ECPointFormats(ECPointFormatList::supported()),
ClientExtension::NamedGroups(vec![NamedGroup::X25519]),
ClientExtension::SignatureAlgorithms(vec![SignatureScheme::ECDSA_NISTP256_SHA256]),
ClientExtension::make_sni(DnsNameRef::try_from_ascii_str("hello").unwrap()),
ClientExtension::make_sni(pki_types::DnsName::try_from("hello").unwrap().to_owned()),
ClientExtension::SessionTicket(ClientSessionTicket::Request),
ClientExtension::SessionTicket(ClientSessionTicket::Offer(Payload(vec![]))),
ClientExtension::Protocols(vec![PayloadU8(vec![0])]),

View File

@@ -1,5 +1,5 @@
use crate::{
anchors::{OwnedTrustAnchor, RootCertStore},
anchors::RootCertStore,
dns::ServerName,
error::Error,
key::Certificate,
@@ -9,27 +9,10 @@ use crate::{
},
};
use ring::digest::Digest;
use std::convert::TryFrom;
use web_time::SystemTime;
use rustls_pki_types as pki_types;
use web_time::{SystemTime, UNIX_EPOCH};
type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm];
/// Which signature verification mechanisms we support. No particular
/// order.
static SUPPORTED_SIG_ALGS: SignatureAlgorithms = &[
&webpki::ECDSA_P256_SHA256,
&webpki::ECDSA_P256_SHA384,
&webpki::ECDSA_P384_SHA256,
&webpki::ECDSA_P384_SHA384,
&webpki::ED25519,
&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
&webpki::RSA_PKCS1_2048_8192_SHA256,
&webpki::RSA_PKCS1_2048_8192_SHA384,
&webpki::RSA_PKCS1_2048_8192_SHA512,
&webpki::RSA_PKCS1_3072_8192_SHA384,
];
type SignatureAlgorithms = &'static [&'static dyn pki_types::SignatureVerificationAlgorithm];
// Marker types. These are used to bind the fact some verification
// (certificate chain or handshake signature) has taken place into
@@ -170,7 +153,7 @@ pub trait ServerCertVerifier: Send + Sync {
/// A type which encapsulates a string that is a syntactically valid DNS name.
#[derive(Clone, Debug, PartialEq)]
pub struct DnsName(pub(crate) webpki::DnsName);
pub struct DnsName(pub(crate) pki_types::DnsName<'static>);
impl AsRef<str> for DnsName {
fn as_ref(&self) -> &str {
@@ -289,35 +272,33 @@ impl ServerCertVerifier for WebPkiVerifier {
_ocsp_response: &[u8],
now: SystemTime,
) -> Result<ServerCertVerified, Error> {
let (cert, chain, trustroots) = prepare(end_entity, intermediates, &self.roots)?;
// `webpki::Time::try_from` does not work with `web_time::SystemTime`.
// To work around this we convert `SystemTime` to seconds and use
// `webpki::Time::from_seconds_since_unix_epoch` instead.
let duration_since_epoch = now
.duration_since(web_time::UNIX_EPOCH)
.map_err(|_| Error::FailedToGetCurrentTime)?;
let seconds_since_unix_epoch = duration_since_epoch.as_secs();
let webpki_now = webpki::Time::from_seconds_since_unix_epoch(seconds_since_unix_epoch);
let cert = pki_types::CertificateDer::from(end_entity.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
let intermediates = intermediates
.iter()
.map(|c| pki_types::CertificateDer::from(c.0.as_slice()))
.collect::<Vec<_>>();
let time = pki_types::UnixTime::since_unix_epoch(now.duration_since(UNIX_EPOCH)?);
let ServerName::DnsName(dns_name) = server_name;
let cert = cert
.verify_is_valid_tls_server_cert(
SUPPORTED_SIG_ALGS,
&webpki::TlsServerTrustAnchors(&trustroots),
&chain,
webpki_now,
)
.map_err(pki_error)
.map(|_| cert)?;
cert.verify_for_usage(
webpki::ALL_VERIFICATION_ALGS,
&self.roots.roots,
&intermediates,
time,
webpki::KeyUsage::server_auth(),
None,
None,
)
.map(|_| ())
.map_err(pki_error)?;
if let Some(policy) = &self.ct_policy {
policy.verify(end_entity, now, scts)?;
}
cert.verify_is_valid_for_dns_name(dns_name.0.as_ref())
.map_err(pki_error)
cert.verify_is_valid_for_subject_name(&server_name.0)
.map(|_| ServerCertVerified::assertion())
.map_err(pki_error)
}
}
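One detail worth pulling out from the rewrite above: webpki's verify_for_usage takes a pki_types::UnixTime, built here from the duration since the Unix epoch. A minimal helper showing just that conversion (the function name is illustrative):

use rustls_pki_types::UnixTime;
use web_time::{SystemTime, SystemTimeError, UNIX_EPOCH};

/// Converts a wall-clock time into the UnixTime that webpki expects.
fn to_unix_time(now: SystemTime) -> Result<UnixTime, SystemTimeError> {
    Ok(UnixTime::since_unix_epoch(now.duration_since(UNIX_EPOCH)?))
}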
@@ -429,57 +410,37 @@ impl CertificateTransparencyPolicy {
}
}
type CertChainAndRoots<'a, 'b> = (
webpki::EndEntityCert<'a>,
Vec<&'a [u8]>,
Vec<webpki::TrustAnchor<'b>>,
);
fn prepare<'a, 'b>(
end_entity: &'a Certificate,
intermediates: &'a [Certificate],
roots: &'b RootCertStore,
) -> Result<CertChainAndRoots<'a, 'b>, Error> {
// EE cert must appear first.
let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?;
let intermediates: Vec<&'a [u8]> = intermediates.iter().map(|cert| cert.0.as_ref()).collect();
let trustroots: Vec<webpki::TrustAnchor> = roots
.roots
.iter()
.map(OwnedTrustAnchor::to_trust_anchor)
.collect();
Ok((cert, intermediates, trustroots))
}
pub(crate) fn pki_error(error: webpki::Error) -> Error {
use webpki::Error::*;
match error {
BadDer | BadDerTime => Error::InvalidCertificateEncoding,
InvalidSignatureForPublicKey => Error::InvalidCertificateSignature,
UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => {
UnsupportedSignatureAlgorithmContext(_)
| UnsupportedSignatureAlgorithmForPublicKeyContext(_) => {
Error::InvalidCertificateSignatureType
}
e => Error::InvalidCertificateData(format!("invalid peer certificate: {e}")),
}
}
static ECDSA_SHA256: SignatureAlgorithms =
&[&webpki::ECDSA_P256_SHA256, &webpki::ECDSA_P384_SHA256];
static ECDSA_SHA256: SignatureAlgorithms = &[
webpki::ring::ECDSA_P256_SHA256,
webpki::ring::ECDSA_P384_SHA256,
];
static ECDSA_SHA384: SignatureAlgorithms =
&[&webpki::ECDSA_P256_SHA384, &webpki::ECDSA_P384_SHA384];
static ECDSA_SHA384: SignatureAlgorithms = &[
webpki::ring::ECDSA_P256_SHA384,
webpki::ring::ECDSA_P384_SHA384,
];
static ED25519: SignatureAlgorithms = &[&webpki::ED25519];
static ED25519: SignatureAlgorithms = &[webpki::ring::ED25519];
static RSA_SHA256: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA256];
static RSA_SHA384: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA384];
static RSA_SHA512: SignatureAlgorithms = &[&webpki::RSA_PKCS1_2048_8192_SHA512];
static RSA_PSS_SHA256: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY];
static RSA_PSS_SHA384: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY];
static RSA_PSS_SHA512: SignatureAlgorithms = &[&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY];
static RSA_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA256];
static RSA_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA384];
static RSA_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PKCS1_2048_8192_SHA512];
static RSA_PSS_SHA256: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY];
static RSA_PSS_SHA384: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY];
static RSA_PSS_SHA512: SignatureAlgorithms = &[webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY];
fn convert_scheme(scheme: SignatureScheme) -> Result<SignatureAlgorithms, Error> {
match scheme {
@@ -514,13 +475,18 @@ fn verify_sig_using_any_alg(
// webpki::SignatureAlgorithm. Therefore, convert_algs maps to several and
// we try them all.
for alg in algs {
match cert.verify_signature(alg, message, sig) {
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey) => continue,
match cert.verify_signature(*alg, message, sig) {
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKeyContext(_)) => continue,
res => return res,
}
}
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey)
Err(webpki::Error::UnsupportedSignatureAlgorithmContext(
webpki::UnsupportedSignatureAlgorithmContext {
signature_algorithm_id: vec![],
supported_algorithms: algs.iter().map(|alg| alg.signature_alg_id()).collect(),
},
))
}
fn verify_signed_struct(
@@ -529,7 +495,9 @@ fn verify_signed_struct(
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
let possible_algs = convert_scheme(dss.scheme)?;
let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?;
let cert = pki_types::CertificateDer::from(cert.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
verify_sig_using_any_alg(&cert, possible_algs, message, &dss.sig.0)
.map_err(pki_error)
@@ -538,16 +506,16 @@ fn verify_signed_struct(
fn convert_alg_tls13(
scheme: SignatureScheme,
) -> Result<&'static webpki::SignatureAlgorithm, Error> {
) -> Result<&'static dyn pki_types::SignatureVerificationAlgorithm, Error> {
use crate::msgs::enums::SignatureScheme::*;
match scheme {
ECDSA_NISTP256_SHA256 => Ok(&webpki::ECDSA_P256_SHA256),
ECDSA_NISTP384_SHA384 => Ok(&webpki::ECDSA_P384_SHA384),
ED25519 => Ok(&webpki::ED25519),
RSA_PSS_SHA256 => Ok(&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY),
RSA_PSS_SHA384 => Ok(&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY),
RSA_PSS_SHA512 => Ok(&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY),
ECDSA_NISTP256_SHA256 => Ok(webpki::ring::ECDSA_P256_SHA256),
ECDSA_NISTP384_SHA384 => Ok(webpki::ring::ECDSA_P384_SHA384),
ED25519 => Ok(webpki::ring::ED25519),
RSA_PSS_SHA256 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY),
RSA_PSS_SHA384 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY),
RSA_PSS_SHA512 => Ok(webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY),
_ => {
let error_msg = format!("received unsupported sig scheme {scheme:?}");
Err(Error::PeerMisbehavedError(error_msg))
@@ -583,7 +551,8 @@ fn verify_tls13(
) -> Result<HandshakeSignatureValid, Error> {
let alg = convert_alg_tls13(dss.scheme)?;
let cert = webpki::EndEntityCert::try_from(cert.0.as_ref()).map_err(pki_error)?;
let cert = pki_types::CertificateDer::from(cert.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
cert.verify_signature(alg, msg, &dss.sig.0)
.map_err(pki_error)

View File

@@ -12,8 +12,7 @@ workspace = true
[features]
default = ["rayon"]
rayon = ["mpz-common/rayon"]
force-st = ["mpz-common/force-st"]
rayon = ["mpz-zk/rayon", "mpz-garble/rayon"]
web = ["dep:web-spawn"]
[dependencies]
@@ -41,11 +40,15 @@ mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
aes = { workspace = true }
cipher-crypto = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }
rand = { workspace = true }
rustls-pki-types = "1.12.0"
rustls-pki-types = { workspace = true }
rustls-webpki = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
@@ -57,6 +60,8 @@ rangeset = { workspace = true }
webpki-roots = { workspace = true }
[dev-dependencies]
lipsum = { workspace = true }
sha2 = { workspace = true }
rstest = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
@@ -65,3 +70,5 @@ tokio-util = { workspace = true, features = ["compat"] }
hyper = { workspace = true, features = ["client"] }
http-body-util = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tlsn-core = { workspace = true, features = ["fixtures"] }
mpz-ot = { workspace = true, features = ["ideal"] }

File diff suppressed because it is too large

View File

@@ -0,0 +1,708 @@
//! Authentication of the transcript plaintext and creation of the transcript
//! references.
use std::ops::Range;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeError, DecodeFutureTyped, MemoryExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Disjoint, RangeSet, Union, UnionMut};
use tlsn_core::{
hash::HashAlgId,
transcript::{ContentType, Direction, PartialTranscript, Record, TlsTranscript},
};
use crate::{
Role,
commit::transcript::TranscriptRefs,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Transcript Authenticator.
pub(crate) struct Authenticator {
encoding: Index,
hash: Index,
decoding: Index,
proving: Index,
}
impl Authenticator {
/// Creates a new authenticator.
///
/// # Arguments
///
/// * `encoding` - Ranges for encoding commitments.
/// * `hash` - Ranges for hash commitments.
/// * `partial` - The partial transcript.
pub(crate) fn new<'a>(
encoding: impl Iterator<Item = &'a (Direction, RangeSet<usize>)>,
hash: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
partial: Option<&PartialTranscript>,
) -> Self {
// Compute encoding index.
let mut encoding_sent = RangeSet::default();
let mut encoding_recv = RangeSet::default();
for (d, idx) in encoding {
match d {
Direction::Sent => encoding_sent.union_mut(idx),
Direction::Received => encoding_recv.union_mut(idx),
}
}
let encoding = Index::new(encoding_sent, encoding_recv);
// Compute hash index.
let mut hash_sent = RangeSet::default();
let mut hash_recv = RangeSet::default();
for (d, idx, _) in hash {
match d {
Direction::Sent => hash_sent.union_mut(idx),
Direction::Received => hash_recv.union_mut(idx),
}
}
let hash = Index {
sent: hash_sent,
recv: hash_recv,
};
// Compute decoding index.
let mut decoding_sent = RangeSet::default();
let mut decoding_recv = RangeSet::default();
if let Some(partial) = partial {
decoding_sent.union_mut(partial.sent_authed());
decoding_recv.union_mut(partial.received_authed());
}
let decoding = Index::new(decoding_sent, decoding_recv);
// Compute proving index.
let mut proving_sent = RangeSet::default();
let mut proving_recv = RangeSet::default();
proving_sent.union_mut(decoding.sent());
proving_sent.union_mut(encoding.sent());
proving_sent.union_mut(hash.sent());
proving_recv.union_mut(decoding.recv());
proving_recv.union_mut(encoding.recv());
proving_recv.union_mut(hash.recv());
let proving = Index::new(proving_sent, proving_recv);
Self {
encoding,
hash,
decoding,
proving,
}
}
/// Authenticates the sent plaintext, returning a proof of encryption and
/// writes the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_sent` - ZK AES Cipher for sent traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_sent(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_sent: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let missing_index = transcript_refs.compute_missing(Direction::Sent, self.proving.sent());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_sent,
Direction::Sent,
sent,
transcript_refs,
missing_index,
)
}
/// Authenticates the received plaintext, returning a proof of encryption
/// and writing the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_recv` - ZK AES Cipher for received traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_recv(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_recv: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let decoding_recv = self.decoding.recv();
let fully_decoded = decoding_recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
// If we only have decoding ranges, and the parts we are going to decode will
// complete to the full received transcript, then we do not need to
// authenticate, because this will be done by
// `crate::commit::decode::verify_transcript`, as it uses the server write
// key and iv for verification.
if decoding_recv == self.proving.recv() && fully_decoded == full_range {
return Ok(RecordProof::default());
}
let missing_index =
transcript_refs.compute_missing(Direction::Received, self.proving.recv());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_recv,
Direction::Received,
recv,
transcript_refs,
missing_index,
)
}
/// Returns the sent and received encoding ranges.
pub(crate) fn encoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.encoding.sent(), self.encoding.recv())
}
/// Returns the sent and received hash ranges.
pub(crate) fn hash(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.hash.sent(), self.hash.recv())
}
/// Returns the sent and received decoding ranges.
pub(crate) fn decoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.decoding.sent(), self.decoding.recv())
}
}
/// Authenticates parts of the transcript in zk.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes` - ZK AES Cipher.
/// * `direction` - The direction of the application data.
/// * `app_data` - The application data.
/// * `transcript_refs` - The transcript references.
/// * `missing_index` - The index which needs to be proven.
fn authenticate<'a>(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
direction: Direction,
app_data: impl Iterator<Item = &'a Record>,
transcript_refs: &mut TranscriptRefs,
missing_index: RangeSet<usize>,
) -> Result<RecordProof, AuthError> {
let mut record_idx = Range::default();
let mut ciphertexts = Vec::new();
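// Walk the application-data records, tracking each record's byte range within
// the plaintext stream. Records disjoint from the missing index are skipped;
// any record that overlaps it is proven in full.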
for record in app_data {
let record_len = record.ciphertext.len();
record_idx.end += record_len;
if missing_index.is_disjoint(&record_idx) {
record_idx.start += record_len;
continue;
}
let (plaintext_ref, ciphertext_ref) =
zk_aes.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())?;
if let Role::Prover = zk_aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(AuthError(ErrorRepr::MissingPlainText));
};
vm.assign(plaintext_ref, plaintext).map_err(AuthError::vm)?;
}
vm.commit(plaintext_ref).map_err(AuthError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(AuthError::vm)?;
transcript_refs.add(direction, &record_idx, plaintext_ref);
ciphertexts.push((ciphertext, record.ciphertext.clone()));
record_idx.start += record_len;
}
let proof = RecordProof { ciphertexts };
Ok(proof)
}
#[derive(Debug, Clone, Default)]
struct Index {
sent: RangeSet<usize>,
recv: RangeSet<usize>,
}
impl Index {
fn new(sent: RangeSet<usize>, recv: RangeSet<usize>) -> Self {
Self { sent, recv }
}
fn sent(&self) -> &RangeSet<usize> {
&self.sent
}
fn recv(&self) -> &RangeSet<usize> {
&self.recv
}
}
/// Proof of encryption.
#[derive(Debug, Default)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub(crate) fn verify(self) -> Result<(), AuthError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(AuthError::vm)?
.ok_or(AuthError(ErrorRepr::MissingDecoding))?;
if ciphertext != expected {
return Err(AuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}
/// Error for [`Authenticator`].
#[derive(Debug, thiserror::Error)]
#[error("transcript authentication error: {0}")]
pub(crate) struct AuthError(#[source] ErrorRepr);
impl AuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk-aes error: {0}")]
ZkAes(ZkAesCtrError),
#[error("decode error: {0}")]
Decode(DecodeError),
#[error("plaintext is missing in record")]
MissingPlainText,
#[error("decoded value is missing")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
impl From<ZkAesCtrError> for AuthError {
fn from(value: ZkAesCtrError) -> Self {
Self(ErrorRepr::ZkAes(value))
}
}
impl From<DecodeError> for AuthError {
fn from(value: DecodeError) -> Self {
Self(ErrorRepr::Decode(value))
}
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
auth::{Authenticator, ErrorRepr},
transcript::TranscriptRefs,
},
zk_aes_ctr::ZkAesCtr,
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY, RECORD_SIZE},
hash::HashAlgId,
transcript::{ContentType, Direction, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_authenticator_sent(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_sent(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(600..1600));
prove_range.union_mut(&(800..2000));
prove_range.union_mut(&(2600..3700));
let mut expected_ranges = RangeSet::default();
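// Plaintext is authenticated per record, so every proven range is expected to
// be widened to full record boundaries (hypothetically, with a RECORD_SIZE of
// 1024 the range 600..1600 would widen to 0..2048).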
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Sent), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Sent), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_recv(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, RECV_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, RECV_LEN).unwrap();
let _ = auth_prover
.auth_recv(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_recv(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(4000..4200));
prove_range.union_mut(&(5000..5800));
prove_range.union_mut(&(6800..RECV_LEN));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Received), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Received), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_sent_verify_fail(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decoding, recv_decoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decoding, recv_decoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
// Forge verifier transcript to check if verify fails.
// Use an index which is part of the proving range.
let forged = forged();
let proof = auth_verifier
.auth_sent(&mut verifier, &mut zk_verifier, &forged, &mut refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let err = proof.verify().unwrap_err();
assert!(matches!(err.0, ErrorRepr::InvalidCiphertext));
}
fn keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let sent = 600..1600;
let recv = 4000..4200;
(sent.into(), recv.into())
}
#[fixture]
fn encoding() -> Vec<(Direction, RangeSet<usize>)> {
let sent = 800..2000;
let recv = 5000..5800;
let encoding = vec![
(Direction::Sent, sent.into()),
(Direction::Received, recv.into()),
];
encoding
}
#[fixture]
fn hashes() -> Vec<(Direction, RangeSet<usize>, HashAlgId)> {
let sent = 2600..3700;
let recv = 6800..RECV_LEN;
let alg = HashAlgId::SHA256;
let hashes = vec![
(Direction::Sent, sent.into(), alg),
(Direction::Received, recv.into(), alg),
];
hashes
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 610;
let mut sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
sent[WRONG_BYTE_INDEX] = sent[WRONG_BYTE_INDEX].wrapping_add(1);
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
TranscriptRefs::new(sent_len, recv_len)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}


@@ -0,0 +1,615 @@
//! Selective disclosure.
use mpz_memory_core::{
Array, MemoryExt,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{Intersection, RangeSet, Subset, Union};
use tlsn_core::transcript::{ContentType, Direction, PartialTranscript, TlsTranscript};
use crate::commit::TranscriptRefs;
/// Decodes parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `transcript_refs` - The transcript references.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
transcript_refs: &mut TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding_ranges;
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for slice in sent_refs.into_iter() {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Sent, sent);
// If possible use server write key for decoding.
let fully_decoded = recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
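// When the requested ranges plus the already decoded ranges cover the whole
// received transcript, it suffices to decode the server write key and iv:
// `verify_transcript` can then recompute the received plaintext from the
// ciphertexts instead of decoding every byte individually.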
if fully_decoded == full_range {
// Drop the future, we don't need it.
drop(vm.decode(key).map_err(DecodeError::vm)?);
drop(vm.decode(iv).map_err(DecodeError::vm)?);
transcript_refs.mark_decoded(Direction::Received, &full_range.into());
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for slice in recv_refs {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Received, recv);
}
Ok(())
}
/// Verifies parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `partial` - The partial transcript.
/// * `transcript_refs` - The transcript references.
/// * `transcript` - The TLS transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
partial: Option<&PartialTranscript>,
transcript_refs: &mut TranscriptRefs,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let (sent, recv) = decoding_ranges;
let mut authenticated_data = Vec::new();
// Add sent transcript parts.
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for data in sent_refs.into_iter() {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
// Add received transcript parts, if possible using key and iv.
if let (Some(key), Some(iv)) = (
vm.get(key).map_err(DecodeError::vm)?,
vm.get(iv).map_err(DecodeError::vm)?,
) {
let plaintext = verify_with_keys(key, iv, recv, transcript)?;
authenticated_data.extend_from_slice(&plaintext);
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for data in recv_refs {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in sent.iter_ranges() {
purported_data.extend_from_slice(&partial.sent_unsafe()[range]);
}
for range in recv.iter_ranges() {
purported_data.extend_from_slice(&partial.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(DecodeError(ErrorRepr::InconsistentTranscript));
}
Ok(())
}
/// Checks the transcript length.
///
/// # Arguments
///
/// * `partial` - The partial transcript.
/// * `transcript` - The TLS transcript.
pub(crate) fn check_transcript_length(
partial: Option<&PartialTranscript>,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let sent_len: usize = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len: usize = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
// Check ranges.
if partial.len_sent() != sent_len || partial.len_received() != recv_len {
return Err(DecodeError(ErrorRepr::VerifyTranscriptLength));
}
Ok(())
}
fn verify_with_keys(
key: [u8; 16],
iv: [u8; 4],
decoding_ranges: &RangeSet<usize>,
transcript: &TlsTranscript,
) -> Result<Vec<u8>, DecodeError> {
let mut plaintexts = Vec::with_capacity(decoding_ranges.len());
let mut position = 0_usize;
let recv_data = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv_data {
let current = position..position + record.ciphertext.len();
if !current.is_subset(decoding_ranges) {
position += record.ciphertext.len();
continue;
}
let nonce = record
.explicit_nonce
.clone()
.try_into()
.expect("explicit nonce should be 8 bytes");
let plaintext = aes_apply_keystream(key, iv, nonce, &record.ciphertext);
let record_decoding_range = decoding_ranges.intersection(&current);
for r in record_decoding_range.iter_ranges() {
let shifted = r.start - position..r.end - position;
plaintexts.extend_from_slice(&plaintext[shifted]);
}
position += record.ciphertext.len()
}
Ok(plaintexts)
}
fn aes_apply_keystream(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], msg: &[u8]) -> Vec<u8> {
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let start_ctr = 2;
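// In AES-GCM the payload keystream starts at counter block 2; the earlier
// counter blocks are reserved for the authentication tag.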
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(&iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
let mut out = msg.to_vec();
cipher
.try_seek(start_ctr * 16)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut out);
out
}
/// A decoding error.
#[derive(Debug, thiserror::Error)]
#[error("decode error: {0}")]
pub(crate) struct DecodeError(#[source] ErrorRepr);
impl DecodeError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("missing partial transcript")]
MissingPartialTranscript,
#[error("length of partial transcript does not match expected length")]
VerifyTranscriptLength,
#[error("provided transcript does not match exptected")]
InconsistentTranscript,
#[error("trying to get plaintext, but it is missing")]
MissingPlaintext,
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
TranscriptRefs,
decode::{DecodeError, ErrorRepr, decode_transcript, verify_transcript},
},
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY},
transcript::{ContentType, Direction, PartialTranscript, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding.clone());
decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_fail(
decoding: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding.clone());
let err = decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
#[rstest]
#[tokio::test]
async fn test_decode_all(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding_full.clone());
decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_all_fail(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding_full.clone());
let err = decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
async fn decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
partial: PartialTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding;
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut transcript_refs_verifier = transcript_refs.clone();
let mut transcript_refs_prover = transcript_refs;
let key: [u8; 16] = KEY;
let iv: [u8; 4] = IV;
let (key_prover, iv_prover) = assign_keys(&mut prover, key, iv, Role::Prover);
let (key_verifier, iv_verifier) = assign_keys(&mut verifier, key, iv, Role::Verifier);
assign_transcript(
&mut prover,
Role::Prover,
&transcript,
&mut transcript_refs_prover,
);
assign_transcript(
&mut verifier,
Role::Verifier,
&transcript,
&mut transcript_refs_verifier,
);
decode_transcript(
&mut prover,
key_prover,
iv_prover,
(&sent, &recv),
&mut transcript_refs_prover,
)
.unwrap();
decode_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
&mut transcript_refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v),
)
.unwrap();
verify_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
Some(&partial),
&mut transcript_refs_verifier,
&transcript,
)
}
fn assign_keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
fn assign_transcript(
vm: &mut dyn Vm<Binary>,
role: Role,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) {
let mut pos = 0_usize;
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in sent {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Sent, &index, plaintext_ref);
pos += record.ciphertext.len();
}
pos = 0;
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Received, &index, plaintext_ref);
pos += record.ciphertext.len();
}
}
fn partial(
transcript: &TlsTranscript,
decoding: (RangeSet<usize>, RangeSet<usize>),
) -> PartialTranscript {
let (sent, recv) = decoding;
transcript.to_transcript().unwrap().to_partial(sent, recv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
sent.union_mut(&(600..1100));
sent.union_mut(&(3450..4000));
recv.union_mut(&(2000..3000));
recv.union_mut(&(4800..4900));
recv.union_mut(&(6000..7000));
(sent, recv)
}
#[fixture]
fn decoding_full(transcript: TlsTranscript) -> (RangeSet<usize>, RangeSet<usize>) {
let transcript = transcript.to_transcript().unwrap();
let (len_sent, len_recv) = transcript.len();
let sent = (0..len_sent).into();
let recv = (0..len_recv).into();
(sent, recv)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 2200;
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
recv[WRONG_BYTE_INDEX] = recv[WRONG_BYTE_INDEX].wrapping_add(1);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
TranscriptRefs::new(sent_len, recv_len)
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}


@@ -0,0 +1,530 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_memory_core::{
MemoryType, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{RangeSet, Subset, UnionMut};
use serde::{Deserialize, Serialize};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256, TypedHash},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingProvider, EncodingProviderError, EncodingTree,
EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Bytes of encoding material per plaintext byte (8 bits, each carrying a 16-byte MAC or key).
pub(crate) const ENCODING_SIZE: usize = 128;
pub(crate) trait EncodingVm<T: MemoryType>: EncodingMemory<T> + Vm<T> {}
impl<T: MemoryType, U> EncodingVm<T> for U where U: EncodingMemory<T> + Vm<T> {}
pub(crate) trait EncodingMemory<T: MemoryType> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8>;
}
impl<T> EncodingMemory<Binary> for mpz_zk::Prover<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let macs = self.get_macs(v).expect("macs should be available");
encodings.extend(macs.iter().flat_map(|mac| mac.as_bytes()));
}
encodings
}
}
impl<T> EncodingMemory<Binary> for mpz_zk::Verifier<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let keys = self.get_keys(v).expect("keys should be available");
encodings.extend(keys.iter().flat_map(|key| key.as_block().as_bytes()));
}
encodings
}
}
/// The encoding adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Encodings {
pub(crate) sent: Vec<u8>,
pub(crate) recv: Vec<u8>,
}
/// Creates encoding commitments.
#[derive(Debug)]
pub(crate) struct EncodingCreator {
hash_id: Option<HashAlgId>,
sent: RangeSet<usize>,
recv: RangeSet<usize>,
idxs: Vec<(Direction, RangeSet<usize>)>,
}
impl EncodingCreator {
/// Creates a new encoding creator.
///
/// # Arguments
///
/// * `hash_id` - The id of the hash algorithm.
/// * `idxs` - The indices for encoding commitments.
pub(crate) fn new(hash_id: Option<HashAlgId>, idxs: Vec<(Direction, RangeSet<usize>)>) -> Self {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (direction, idx) in idxs.iter() {
for range in idx.iter_ranges() {
match direction {
Direction::Sent => sent.union_mut(&range),
Direction::Received => recv.union_mut(&range),
}
}
}
Self {
hash_id,
sent,
recv,
idxs,
}
}
/// Receives the encodings using the provided MACs from the encoding memory.
///
/// The MACs must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `encodings` - The encoding adjustments.
/// * `transcript_refs` - The transcript references.
pub(crate) fn receive(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
encodings: Encodings,
transcript_refs: &TranscriptRefs,
) -> Result<(TypedHash, EncodingTree), EncodingError> {
let Some(id) = self.hash_id else {
return Err(EncodingError(ErrorRepr::MissingHashId));
};
let hasher: &(dyn HashAlgorithm + Send + Sync) = match id {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(EncodingError(ErrorRepr::UnsupportedHashAlg(alg)));
}
};
let Encodings {
sent: mut sent_adjust,
recv: mut recv_adjust,
} = encodings;
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_adjust, &mut recv_adjust)?;
let provider = Provider::new(sent_adjust, &self.sent, recv_adjust, &self.recv);
let tree = EncodingTree::new(hasher, self.idxs.iter(), &provider)?;
let root = tree.root();
Ok((root, tree))
}
/// Transfers the encodings using the provided secret and keys from the
/// encoding memory.
///
/// The keys must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `secret` - The encoder secret.
/// * `transcript_refs` - The transcript references.
pub(crate) fn transfer(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
secret: EncoderSecret,
transcript_refs: &TranscriptRefs,
) -> Result<Encodings, EncodingError> {
let encoder = new_encoder(&secret);
let mut sent_zero = Vec::with_capacity(self.sent.len() * ENCODING_SIZE);
let mut recv_zero = Vec::with_capacity(self.recv.len() * ENCODING_SIZE);
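// Generate the encoder's encodings of the all-zero plaintext for the committed
// ranges; XOR-ing the locally held encodings into them below yields the
// adjustment bytes returned to the caller.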
for range in self.sent.iter_ranges() {
encoder.encode_range(Direction::Sent, range, &mut sent_zero);
}
for range in self.recv.iter_ranges() {
encoder.encode_range(Direction::Received, range, &mut recv_zero);
}
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_zero, &mut recv_zero)?;
let encodings = Encodings {
sent: sent_zero,
recv: recv_zero,
};
Ok(encodings)
}
}
/// Adjust encodings by transcript references.
///
/// # Arguments
///
/// * `sent` - The encodings for the sent bytes.
/// * `recv` - The encodings for the received bytes.
/// * `sent_adjust` - The adjustment bytes for the encodings of the sent bytes.
/// * `recv_adjust` - The adjustment bytes for the encodings of the received
/// bytes.
fn adjust(
sent: &[u8],
recv: &[u8],
sent_adjust: &mut [u8],
recv_adjust: &mut [u8],
) -> Result<(), EncodingError> {
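// The adjustments and the locally held encodings are combined byte-wise by
// XOR: `transfer` derives the adjustment bytes from the zero encodings, while
// `receive` applies the received adjustments to its own encodings.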
assert_eq!(sent.len() % ENCODING_SIZE, 0);
assert_eq!(recv.len() % ENCODING_SIZE, 0);
if sent_adjust.len() != sent.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Sent,
expected: sent.len(),
got: sent_adjust.len(),
}
.into());
}
if recv_adjust.len() != recv.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Received,
expected: recv.len(),
got: recv_adjust.len(),
}
.into());
}
sent_adjust
.iter_mut()
.zip(sent)
.for_each(|(adjust, enc)| *adjust ^= enc);
recv_adjust
.iter_mut()
.zip(recv)
.for_each(|(adjust, enc)| *adjust ^= enc);
Ok(())
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
sent_range: RangeSet<usize>,
recv: Vec<u8>,
recv_range: RangeSet<usize>,
}
impl Provider {
fn new(
sent: Vec<u8>,
sent_range: &RangeSet<usize>,
recv: Vec<u8>,
recv_range: &RangeSet<usize>,
) -> Self {
assert_eq!(
sent.len(),
sent_range.len() * ENCODING_SIZE,
"length of sent encodings and their index length do not match"
);
assert_eq!(
recv.len(),
recv_range.len() * ENCODING_SIZE,
"length of received encodings and their index length do not match"
);
Self {
sent,
sent_range: sent_range.clone(),
recv,
recv_range: recv_range.clone(),
}
}
fn adjust(
&self,
direction: Direction,
range: &Range<usize>,
) -> Result<Range<usize>, EncodingProviderError> {
let internal_range = match direction {
Direction::Sent => &self.sent_range,
Direction::Received => &self.recv_range,
};
if !range.is_subset(internal_range) {
return Err(EncodingProviderError);
}
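// The encodings are stored back-to-back for the committed index only, so
// translate the absolute transcript range into an offset by counting how many
// committed positions precede it.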
let shift = internal_range
.iter()
.take_while(|&el| el < range.start)
.count();
let translated = Range {
start: shift,
end: shift + range.len(),
};
Ok(translated)
}
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let range = self.adjust(direction, &range)?;
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("incorrect adjustment count for {direction}: expected {expected}, got {got}")]
IncorrectAdjustCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
#[error("missing hash id")]
MissingHashId,
#[error("unsupported hash algorithm for encoding commitment: {0}")]
UnsupportedHashAlg(HashAlgId),
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::ops::Range;
use crate::commit::{
encoding::{ENCODING_SIZE, EncodingCreator, Encodings, Provider},
transcript::TranscriptRefs,
};
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_memory_core::{
FromRaw, Slice, ToRaw, Vector,
binary::{Binary, U8},
};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
hash::{HashAlgId, HashProvider},
transcript::{
Direction,
encoding::{EncoderSecret, EncodingCommitment, EncodingProvider},
},
};
#[rstest]
fn test_encoding_adjust(
index: (RangeSet<usize>, RangeSet<usize>),
transcript_refs: TranscriptRefs,
encoding_idxs: Vec<(Direction, RangeSet<usize>)>,
) {
let creator = EncodingCreator::new(Some(HashAlgId::SHA256), encoding_idxs);
let mock_memory = MockEncodingMemory;
let delta = Delta::new(Block::ONES);
let seed = [1_u8; 32];
let secret = EncoderSecret::new(seed, delta.as_block().to_bytes());
let adjustments = creator
.transfer(&mock_memory, secret, &transcript_refs)
.unwrap();
let (root, tree) = creator
.receive(&mock_memory, adjustments, &transcript_refs)
.unwrap();
// Check correctness of encoding protocol.
let mut idxs = Vec::new();
let (sent_range, recv_range) = index;
idxs.push((Direction::Sent, sent_range.clone()));
idxs.push((Direction::Received, recv_range.clone()));
let commitment = EncodingCommitment { root, secret };
let proof = tree.proof(idxs.iter()).unwrap();
// Here we set the transcript plaintext to just be zeroes, which is possible
// because it is not determined by the encodings so far.
let sent = vec![0_u8; transcript_refs.max_len(Direction::Sent)];
let recv = vec![0_u8; transcript_refs.max_len(Direction::Received)];
let (idx_sent, idx_recv) = proof
.verify_with_provider(&HashProvider::default(), &commitment, &sent, &recv)
.unwrap();
assert_eq!(idx_sent, idxs[0].1);
assert_eq!(idx_recv, idxs[1].1);
}
#[rstest]
fn test_encoding_provider(index: (RangeSet<usize>, RangeSet<usize>), encodings: Encodings) {
let (sent_range, recv_range) = index;
let Encodings { sent, recv } = encodings;
let provider = Provider::new(sent, &sent_range, recv, &recv_range);
let mut encodings_sent = Vec::new();
let mut encodings_recv = Vec::new();
provider
.provide_encoding(Direction::Sent, 16..24, &mut encodings_sent)
.unwrap();
provider
.provide_encoding(Direction::Received, 56..64, &mut encodings_recv)
.unwrap();
let expected_sent = generate_encodings((16..24).into());
let expected_recv = generate_encodings((56..64).into());
assert_eq!(expected_sent, encodings_sent);
assert_eq!(expected_recv, encodings_recv);
}
#[fixture]
fn transcript_refs(index: (RangeSet<usize>, RangeSet<usize>)) -> TranscriptRefs {
let mut transcript_refs = TranscriptRefs::new(40, 64);
let dummy = |range: Range<usize>| {
Vector::<U8>::from_raw(Slice::from_range_unchecked(8 * range.start..8 * range.end))
};
for range in index.0.iter_ranges() {
transcript_refs.add(Direction::Sent, &range, dummy(range.clone()));
}
for range in index.1.iter_ranges() {
transcript_refs.add(Direction::Received, &range, dummy(range.clone()));
}
transcript_refs
}
#[fixture]
fn encodings(index: (RangeSet<usize>, RangeSet<usize>)) -> Encodings {
let sent = generate_encodings(index.0);
let recv = generate_encodings(index.1);
Encodings { sent, recv }
}
#[fixture]
fn encoding_idxs(
index: (RangeSet<usize>, RangeSet<usize>),
) -> Vec<(Direction, RangeSet<usize>)> {
let (sent, recv) = index;
vec![(Direction::Sent, sent), (Direction::Received, recv)]
}
#[fixture]
fn index() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
sent.union_mut(&(0..8));
sent.union_mut(&(16..24));
sent.union_mut(&(32..40));
let mut recv = RangeSet::default();
recv.union_mut(&(40..48));
recv.union_mut(&(56..64));
(sent, recv)
}
#[derive(Clone, Copy)]
struct MockEncodingMemory;
impl EncodingMemory<Binary> for MockEncodingMemory {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let ranges: Vec<Range<usize>> = values
.iter()
.map(|r| {
let range = r.to_raw().to_range();
range.start / 8..range.end / 8
})
.collect();
let ranges: RangeSet<usize> = ranges.into();
generate_encodings(ranges)
}
}
fn generate_encodings(index: RangeSet<usize>) -> Vec<u8> {
let mut out = Vec::new();
for el in index.iter() {
out.extend_from_slice(&[el as u8; ENCODING_SIZE]);
}
out
}
}


@@ -9,44 +9,182 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
Direction, Idx,
Direction,
hash::{PlaintextHash, PlaintextHashSecret},
},
};
use crate::{Role, commit::transcript::TranscriptRefs};
/// Future which will resolve to the committed hash values.
/// Creates plaintext hashes.
#[derive(Debug)]
pub(crate) struct HashCommitFuture {
#[allow(clippy::type_complexity)]
futs: Vec<(
Direction,
Idx,
HashAlgId,
DecodeFutureTyped<BitVec, Vec<u8>>,
)>,
pub(crate) struct PlaintextHasher {
ranges: Vec<HashRange>,
}
impl HashCommitFuture {
impl PlaintextHasher {
/// Creates a new instance.
///
/// # Arguments
///
/// * `indices` - The hash indices.
pub(crate) fn new<'a>(
indices: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
) -> Self {
let mut ranges = Vec::new();
for (direction, index, id) in indices {
let hash_range = HashRange::new(*direction, index.clone(), *id);
ranges.push(hash_range);
}
Self { ranges }
}
/// Prove plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn prove(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<(HashFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Prover, transcript_refs)?;
let mut futures = Vec::new();
let mut secrets = Vec::new();
for ((range, hash_ref), blinder_ref) in self.ranges.iter().zip(hash_refs).zip(blinders) {
let blinder: Blinder = rand::random();
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut));
secrets.push(PlaintextHashSecret {
direction: range.direction,
idx: range.range.clone(),
blinder,
alg: range.id,
});
}
let hashes = HashFuture { futures };
Ok((hashes, secrets))
}
/// Verify plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn verify(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<HashFuture, HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Verifier, transcript_refs)?;
let mut futures = Vec::new();
for ((range, hash_ref), blinder) in self.ranges.iter().zip(hash_refs).zip(blinders) {
vm.commit(blinder)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut))
}
let hashes = HashFuture { futures };
Ok(hashes)
}
}
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn commit(
vm: &mut dyn Vm<Binary>,
ranges: &[HashRange],
role: Role,
refs: &TranscriptRefs,
) -> Result<(Vec<Array<U8, 32>>, Vec<Vector<U8>>), HashCommitError> {
let mut hashers = HashMap::new();
let mut hash_refs = Vec::new();
let mut blinders = Vec::new();
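// Each commitment is SHA-256(plaintext || blinder); the 16-byte blinder is
// private to the prover and keeps the committed hash hiding.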
for HashRange {
direction,
range,
id,
} in ranges.iter()
{
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
}
let hash = match *id {
HashAlgId::SHA256 => {
let mut hasher = if let Some(hasher) = hashers.get(id).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(id, hasher.clone());
hasher
};
for plaintext in refs.get(*direction, range) {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
id => {
return Err(HashCommitError::unsupported_alg(id));
}
};
hash_refs.push(hash);
blinders.push(blinder);
}
Ok((hash_refs, blinders))
}
/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub(crate) struct HashFuture {
futures: Vec<(HashRange, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl HashFuture {
/// Tries to receive the value, returning an error if the value is not
/// ready.
pub(crate) fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
let mut output = Vec::new();
for (direction, idx, alg, mut fut) in self.futs {
for (hash_range, mut fut) in self.futures {
let hash = fut
.try_recv()
.map_err(|_| HashCommitError::decode())?
.ok_or_else(HashCommitError::decode)?;
output.push(PlaintextHash {
direction,
idx,
direction: hash_range.direction,
idx: hash_range.range,
hash: TypedHash {
alg,
alg: hash_range.id,
value: Hash::try_from(hash).map_err(HashCommitError::convert)?,
},
});
@@ -56,98 +194,21 @@ impl HashCommitFuture {
}
}
/// Prove plaintext hash commitments.
pub(crate) fn prove_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<(HashCommitFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let mut futs = Vec::new();
let mut secrets = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Prover, refs, idxs)?
{
let blinder: Blinder = rand::random();
#[derive(Debug, Clone)]
struct HashRange {
direction: Direction,
range: RangeSet<usize>,
id: HashAlgId,
}
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx.clone(), alg, hash_fut));
secrets.push(PlaintextHashSecret {
impl HashRange {
fn new(direction: Direction, range: RangeSet<usize>, id: HashAlgId) -> Self {
Self {
direction,
idx,
blinder,
alg,
});
}
Ok((HashCommitFuture { futs }, secrets))
}
/// Verify plaintext hash commitments.
pub(crate) fn verify_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<HashCommitFuture, HashCommitError> {
let mut futs = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Verifier, refs, idxs)?
{
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx, alg, hash_fut));
}
Ok(HashCommitFuture { futs })
}
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn hash_commit_inner(
vm: &mut dyn Vm<Binary>,
role: Role,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<Vec<(Direction, Idx, HashAlgId, Array<U8, 32>, Vector<U8>)>, HashCommitError> {
let mut output = Vec::new();
let mut hashers = HashMap::new();
for (direction, idx, alg) in idxs {
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
range,
id,
}
let hash = match alg {
HashAlgId::SHA256 => {
let mut hasher = if let Some(hasher) = hashers.get(&alg).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, hasher.clone());
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
alg => {
return Err(HashCommitError::unsupported_alg(alg));
}
};
output.push((direction, idx, alg, hash, blinder));
}
Ok(output)
}
/// Error type for hash commitments.
@@ -196,3 +257,148 @@ impl From<VmError> for HashCommitError {
Self(ErrorRepr::Vm(value))
}
}
#[cfg(test)]
mod test {
use crate::{
Role,
commit::{hash::PlaintextHasher, transcript::TranscriptRefs},
};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use sha2::Digest;
use tlsn_core::{hash::HashAlgId, transcript::Direction};
#[rstest]
#[tokio::test]
async fn test_hasher() {
let mut sent1 = RangeSet::default();
sent1.union_mut(&(1..6));
sent1.union_mut(&(11..16));
let mut sent2 = RangeSet::default();
sent2.union_mut(&(22..25));
let mut recv = RangeSet::default();
recv.union_mut(&(20..25));
let hash_ranges = [
(Direction::Sent, sent1, HashAlgId::SHA256),
(Direction::Sent, sent2, HashAlgId::SHA256),
(Direction::Received, recv, HashAlgId::SHA256),
];
let mut refs_prover = TranscriptRefs::new(1000, 1000);
let mut refs_verifier = TranscriptRefs::new(1000, 1000);
let values = [
b"abcde".to_vec(),
b"vwxyz".to_vec(),
b"xxx".to_vec(),
b"12345".to_vec(),
];
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut values_iter = values.iter();
for (direction, idx, _) in hash_ranges.iter() {
for range in idx.iter_ranges() {
let value = values_iter.next().unwrap();
let ref_prover = assign(Role::Prover, &mut prover, value.clone());
refs_prover.add(*direction, &range, ref_prover);
let ref_verifier = assign(Role::Verifier, &mut verifier, value.clone());
refs_verifier.add(*direction, &range, ref_verifier);
}
}
let hasher_prover = PlaintextHasher::new(hash_ranges.iter());
let hasher_verifier = PlaintextHasher::new(hash_ranges.iter());
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let (prover_hashes, prover_secrets) =
hasher_prover.prove(&mut prover, &refs_prover).unwrap();
let verifier_hashes = hasher_verifier
.verify(&mut verifier, &refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let prover_hashes = prover_hashes.try_recv().unwrap();
let verifier_hashes = verifier_hashes.try_recv().unwrap();
assert_eq!(prover_hashes, verifier_hashes);
let values_per_commitment = [b"abcdevwxyz".to_vec(), b"xxx".to_vec(), b"12345".to_vec()];
for ((value, hash), secret) in values_per_commitment
.iter()
.zip(prover_hashes)
.zip(prover_secrets)
{
let blinder = secret.blinder.as_bytes();
let mut blinded_value = value.clone();
blinded_value.extend_from_slice(blinder);
let expected_hash = sha256(&blinded_value);
let hash: Vec<u8> = hash.hash.value.into();
assert_eq!(expected_hash, hash);
}
}
fn assign(role: Role, vm: &mut dyn Vm<Binary>, value: Vec<u8>) -> Vector<U8> {
let reference: Vector<U8> = vm.alloc_vec(value.len()).unwrap();
if let Role::Prover = role {
vm.mark_private(reference).unwrap();
vm.assign(reference, value).unwrap();
} else {
vm.mark_blind(reference).unwrap();
}
vm.commit(reference).unwrap();
reference
}
fn sha256(data: &[u8]) -> Vec<u8> {
let mut hasher = sha2::Sha256::default();
hasher.update(data);
hasher.finalize().as_slice().to_vec()
}
#[fixture]
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
}


@@ -1,204 +1,473 @@
use mpz_memory_core::{
MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError};
use rangeset::Intersection;
use tlsn_core::transcript::{Direction, Idx, PartialTranscript};
//! Transcript reference storage.
use std::ops::Range;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{Difference, Disjoint, RangeSet, Subset, UnionMut};
use tlsn_core::transcript::Direction;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub(crate) struct TranscriptRefs {
sent: Vec<Vector<U8>>,
recv: Vec<Vector<U8>>,
sent: RefStorage,
recv: RefStorage,
}
impl TranscriptRefs {
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self {
/// Creates a new instance.
///
/// # Arguments
///
/// `sent_max_len` - The maximum length of the sent transcript in bytes.
/// `recv_max_len` - The maximum length of the received transcript in bytes.
pub(crate) fn new(sent_max_len: usize, recv_max_len: usize) -> Self {
let sent = RefStorage::new(sent_max_len);
let recv = RefStorage::new(recv_max_len);
Self { sent, recv }
}
/// Returns the sent plaintext references.
pub(crate) fn sent(&self) -> &[Vector<U8>] {
&self.sent
/// Adds new references to the transcript refs.
///
/// New transcript references are only added if none of them are already
/// present.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
/// * `refs` - The new transcript refs.
pub(crate) fn add(&mut self, direction: Direction, index: &Range<usize>, refs: Vector<U8>) {
match direction {
Direction::Sent => self.sent.add(index, refs),
Direction::Received => self.recv.add(index, refs),
}
}
/// Returns the received plaintext references.
pub(crate) fn recv(&self) -> &[Vector<U8>] {
&self.recv
/// Marks references of the transcript as decoded.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn mark_decoded(&mut self, direction: Direction, index: &RangeSet<usize>) {
match direction {
Direction::Sent => self.sent.mark_decoded(index),
Direction::Received => self.recv.mark_decoded(index),
}
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub(crate) fn get(&self, direction: Direction, idx: &Idx) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
/// Returns plaintext references for some index.
///
/// Queries that cannot be satisfied, or can only be partially satisfied, will
/// return an empty vector.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn get(&self, direction: Direction, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
match direction {
Direction::Sent => self.sent.get(index),
Direction::Received => self.recv.get(index),
}
}
/// Computes the subset of `index` which is missing.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn compute_missing(
&self,
direction: Direction,
index: &RangeSet<usize>,
) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.compute_missing(index),
Direction::Received => self.recv.compute_missing(index),
}
}
/// Returns the maximum length of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn max_len(&self, direction: Direction) -> usize {
match direction {
Direction::Sent => self.sent.max_len(),
Direction::Received => self.recv.max_len(),
}
}
/// Returns the decoded ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn decoded(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.decoded(),
Direction::Received => self.recv.decoded(),
}
}
/// Returns the ranges of the transcript that have been set.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
#[cfg(test)]
pub(crate) fn index(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.index(),
Direction::Received => self.recv.index(),
}
}
}
/// Inner storage for transcript references.
///
/// Saves transcript references by maintaining an `index` and an `offset`. The
/// offset translates from `index` to a memory location and captures possibly
/// non-contiguous memory locations. The storage is bit-addressed, but the API
/// works with ranges over bytes.
#[derive(Debug, Clone)]
struct RefStorage {
index: RangeSet<usize>,
decoded: RangeSet<usize>,
offset: Vec<isize>,
max_len: usize,
}
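// Illustrative sketch (not part of this diff): how a byte range maps onto the
// bit-addressed `offset` table. For byte range 2..5 backed by memory starting
// at bit address 400 (a hypothetical value), the stored offset is 400 - 16, and
// a later lookup recovers memory bits 400..424 (3 bytes).
fn offset_mapping_sketch() {
    let byte_range = 2usize..5;
    let mem_start = 400isize; // hypothetical VM bit address of the backing memory
    let bit_index = 8 * byte_range.start..8 * byte_range.end; // 16..40
    let offset = mem_start - bit_index.start as isize; // 384
    let recovered = (bit_index.start as isize + offset) as usize
        ..(bit_index.end as isize + offset) as usize;
    assert_eq!(recovered, 400..424);
}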
impl RefStorage {
fn new(max_len: usize) -> Self {
Self {
index: RangeSet::default(),
decoded: RangeSet::default(),
offset: Vec::default(),
max_len: 8 * max_len,
}
}
fn add(&mut self, index: &Range<usize>, data: Vector<U8>) {
assert!(
index.start < index.end,
"Range should be valid for adding to reference storage"
);
assert_eq!(
index.len(),
data.len(),
"Provided index and vm references should have the same length"
);
let bit_index = 8 * index.start..8 * index.end;
assert!(
bit_index.is_disjoint(&self.index),
"Parts of the provided index have already been computed"
);
assert!(
bit_index.end <= self.max_len,
"Provided index should be smaller than max_len"
);
if bit_index.end > self.offset.len() {
self.offset.resize(bit_index.end, 0);
}
let refs = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let mem_address = data.to_raw().ptr().as_usize() as isize;
let offset = mem_address - bit_index.start as isize;
// Computes the transcript range for each reference.
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
self.index.union_mut(&bit_index);
self.offset[bit_index].fill(offset);
}
let mut slices = Vec::new();
let (mut slice, mut slice_range) = slice_iter.next()?;
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
fn mark_decoded(&mut self, index: &RangeSet<usize>) {
let bit_index = to_bit_index(index);
self.decoded.union_mut(&bit_index);
}
fn get(&self, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
let bit_index = to_bit_index(index);
if bit_index.is_empty() || !bit_index.is_subset(&self.index) {
return Vec::new();
}
// Partition the rangeset into ranges mapping to possibly disjoint memory
// locations.
//
// If the offset changes during iteration of a single range, it means that the
// backing memory is non-contiguous and we need to split that range.
let mut transcript_refs = Vec::new();
for idx in bit_index.iter_ranges() {
let mut start = idx.start;
let mut end = idx.start;
let mut offset = self.offset[start];
for k in idx {
let next_offset = self.offset[k];
if next_offset == offset {
end += 1;
continue;
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
// to the next slice.
if range.end <= slice_range.end {
break;
} else {
(slice, slice_range) = slice_iter.next()?;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
start = k;
end = k + 1;
offset = next_offset;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
}
Some(slices)
transcript_refs
}
fn compute_missing(&self, index: &RangeSet<usize>) -> RangeSet<usize> {
let byte_index = to_byte_index(&self.index);
index.difference(&byte_index)
}
fn decoded(&self) -> RangeSet<usize> {
to_byte_index(&self.decoded)
}
fn max_len(&self) -> usize {
self.max_len / 8
}
#[cfg(test)]
fn index(&self) -> RangeSet<usize> {
to_byte_index(&self.index)
}
}
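// Illustrative sketch (not part of this diff): a single transcript range backed
// by two non-adjacent memory locations is returned by `get` as two separate
// vectors. The memory byte addresses 10 and 30 are hypothetical.
fn split_sketch() {
    let refs_at = |mem: Range<usize>| -> Vector<U8> {
        Vector::from_raw(Slice::from_range_unchecked(8 * mem.start..8 * mem.end))
    };
    let mut storage = RefStorage::new(8);
    storage.add(&(0..2), refs_at(10..12)); // bytes 0..2 live at memory bytes 10..12
    storage.add(&(2..4), refs_at(30..32)); // bytes 2..4 live at memory bytes 30..32
    let mut index = RangeSet::default();
    index.union_mut(&(0..4));
    // One transcript range, two memory locations -> two plaintext references.
    assert_eq!(storage.get(&index).len(), 2);
}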
/// Decodes the transcript.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
sent: &Idx,
recv: &Idx,
refs: &TranscriptRefs,
) -> Result<(), VmError> {
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, recv)
.expect("index is in bounds");
fn to_bit_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut bit_index = RangeSet::default();
for slice in sent_refs.into_iter().chain(recv_refs) {
// Drop the future, we don't need it.
drop(vm.decode(slice)?);
for r in index.iter_ranges() {
bit_index.union_mut(&(8 * r.start..8 * r.end));
}
Ok(())
bit_index
}
/// Verifies a partial transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
transcript: &PartialTranscript,
refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
let sent_refs = refs
.get(Direction::Sent, transcript.sent_authed())
.expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, transcript.received_authed())
.expect("index is in bounds");
fn to_byte_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut byte_index = RangeSet::default();
let mut authenticated_data = Vec::new();
for data in sent_refs.into_iter().chain(recv_refs) {
let plaintext = vm
.get(data)
.expect("reference is valid")
.expect("plaintext is decoded");
authenticated_data.extend_from_slice(&plaintext);
for r in index.iter_ranges() {
let start = r.start;
let end = r.end;
assert!(
start.trailing_zeros() >= 3,
"start range should be divisible by 8"
);
assert!(
end.trailing_zeros() >= 3,
"end range should be divisible by 8"
);
let start = start >> 3;
let end = end >> 3;
byte_index.union_mut(&(start..end));
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in transcript.sent_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
}
for range in transcript.received_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(InconsistentTranscript {});
}
Ok(())
byte_index
}
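// Illustrative sketch (not part of this diff) of the divisibility check above:
// a bit index produced by `to_bit_index` is always a multiple of 8, so
// `trailing_zeros() >= 3` holds and shifting right by 3 recovers the byte index.
fn bit_byte_roundtrip_sketch() {
    let byte_start = 2usize;
    let bit_start = 8 * byte_start; // 16
    assert!(bit_start.trailing_zeros() >= 3);
    assert_eq!(bit_start >> 3, byte_start);
}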
/// Error for [`verify_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent transcript")]
pub(crate) struct InconsistentTranscript {}
#[cfg(test)]
mod tests {
use super::TranscriptRefs;
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
use rangeset::RangeSet;
use crate::commit::transcript::RefStorage;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use std::ops::Range;
use tlsn_core::transcript::{Direction, Idx};
// TRANSCRIPT_REFS:
//
// 48..96 -> 6 slots
// 112..176 -> 8 slots
// 240..288 -> 6 slots
// 352..392 -> 5 slots
// 440..480 -> 5 slots
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
#[rstest]
fn test_storage_add(
max_len: usize,
ranges: [Range<usize>; 6],
offsets: [isize; 6],
storage: RefStorage,
) {
let bit_ranges: Vec<Range<usize>> = ranges.iter().map(|r| 8 * r.start..8 * r.end).collect();
let bit_offsets: Vec<isize> = offsets.iter().map(|o| 8 * o).collect();
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28];
let mut expected_index: RangeSet<usize> = RangeSet::default();
// 1. Take slots 0..4, 4 slots -> 48..80 (4)
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4)
// 3. Take slots 14..16, 2 slots -> 240..256 (2)
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3)
//
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
const EXPECTED_REFS: &[Range<usize>] =
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
expected_index.union_mut(&bit_ranges[0]);
expected_index.union_mut(&bit_ranges[1]);
#[test]
fn test_transcript_refs_get() {
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_index.union_mut(&bit_ranges[2]);
expected_index.union_mut(&bit_ranges[3]);
let transcript_refs = TranscriptRefs {
sent: transcript_refs.clone(),
recv: transcript_refs,
};
expected_index.union_mut(&bit_ranges[4]);
expected_index.union_mut(&bit_ranges[5]);
assert_eq!(storage.index, expected_index);
let vm_refs = transcript_refs
.get(Direction::Sent, &idx_fixture())
.unwrap();
let end = expected_index.end().unwrap();
let mut expected_offset = vec![0_isize; end];
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_offset[bit_ranges[0].clone()].fill(bit_offsets[0]);
expected_offset[bit_ranges[1].clone()].fill(bit_offsets[1]);
assert_eq!(
vm_refs.len(),
expected_refs.len(),
"Length of actual and expected refs are not equal"
);
expected_offset[bit_ranges[2].clone()].fill(bit_offsets[2]);
expected_offset[bit_ranges[3].clone()].fill(bit_offsets[3]);
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
assert_eq!(expected, actual);
}
expected_offset[bit_ranges[4].clone()].fill(bit_offsets[4]);
expected_offset[bit_ranges[5].clone()].fill(bit_offsets[5]);
assert_eq!(storage.offset, expected_offset);
assert_eq!(storage.decoded, RangeSet::default());
assert_eq!(storage.max_len, 8 * max_len);
}
fn idx_fixture() -> Idx {
let set = RangeSet::from(IDXS);
Idx::builder().union(&set).build()
#[rstest]
fn test_storage_get(ranges: [Range<usize>; 6], offsets: [isize; 6], storage: RefStorage) {
let mut index = RangeSet::default();
ranges.iter().for_each(|r| index.union_mut(r));
let data = storage.get(&index);
let mut data_recovered = Vec::new();
for (r, o) in ranges.iter().zip(offsets) {
data_recovered.push(vec(r.start as isize + o..r.end as isize + o));
}
// Merge possibly adjacent vectors.
//
// Two vectors are considered adjacent if:
//
// - they are adjacent in memory, and
// - their transcript ranges are adjacent, too.
let mut range_iter = ranges.iter();
let mut vec_iter = data_recovered.iter();
let mut data_expected = Vec::new();
let mut current_vec = vec_iter.next().unwrap().to_raw().to_range();
let mut current_range = range_iter.next().unwrap();
for (r, v) in range_iter.zip(vec_iter) {
let v_range = v.to_raw().to_range();
let start = v_range.start;
let end = v_range.end;
if current_vec.end == start && current_range.end == r.start {
current_vec.end = end;
} else {
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
current_vec = start..end;
current_range = r;
}
}
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
assert_eq!(data, data_expected);
}
#[rstest]
fn test_storage_compute_missing(storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(6..12));
range.union_mut(&(18..21));
range.union_mut(&(22..25));
range.union_mut(&(50..60));
let missing = storage.compute_missing(&range);
let mut missing_expected = RangeSet::default();
missing_expected.union_mut(&(8..12));
missing_expected.union_mut(&(20..21));
missing_expected.union_mut(&(50..60));
assert_eq!(missing, missing_expected);
}
#[rstest]
fn test_mark_decoded(mut storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(14..17));
range.union_mut(&(30..37));
storage.mark_decoded(&range);
let decoded = storage.decoded();
assert_eq!(range, decoded);
}
#[fixture]
fn max_len() -> usize {
1000
}
#[fixture]
fn ranges() -> [Range<usize>; 6] {
let r1 = 0..5;
let r2 = 5..8;
let r3 = 12..20;
let r4 = 22..26;
let r5 = 30..35;
let r6 = 35..38;
[r1, r2, r3, r4, r5, r6]
}
#[fixture]
fn offsets() -> [isize; 6] {
[7, 9, 20, 18, 30, 30]
}
// expected memory ranges: 8 * ranges + 8 * offsets
// 1. 56..96 does not merge with the next one, because not adjacent in memory
// 2. 112..136
// 3. 256..320 does not merge with the next one: adjacent in memory, but the
//    ranges themselves are not
// 4. 320..352
// 5. 480..520 merges with the next one
// 6. 520..544
//
//
// 1. 56..96, length: 5
// 2. 112..136, length: 3
// 3. 256..320, length: 8
// 4. 320..352, length: 4
// 5. 480..544, length: 8
#[fixture]
fn storage(max_len: usize, ranges: [Range<usize>; 6], offsets: [isize; 6]) -> RefStorage {
let [r1, r2, r3, r4, r5, r6] = ranges;
let [o1, o2, o3, o4, o5, o6] = offsets;
let mut storage = RefStorage::new(max_len);
storage.add(&r1, vec(r1.start as isize + o1..r1.end as isize + o1));
storage.add(&r2, vec(r2.start as isize + o2..r2.end as isize + o2));
storage.add(&r3, vec(r3.start as isize + o3..r3.end as isize + o3));
storage.add(&r4, vec(r4.start as isize + o4..r4.end as isize + o4));
storage.add(&r5, vec(r5.start as isize + o5..r5.end as isize + o5));
storage.add(&r6, vec(r6.start as isize + o6..r6.end as isize + o6));
storage
}
fn vec(range: Range<isize>) -> Vector<U8> {
let range = 8 * range.start as usize..8 * range.end as usize;
Vector::from_raw(Slice::from_range_unchecked(range))
}
}

View File

@@ -5,6 +5,8 @@ use semver::Version;
use serde::{Deserialize, Serialize};
use std::error::Error;
pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore};
// The default of 32 bytes is used to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Default maximum number of TLS records to allow.
@@ -108,7 +110,7 @@ impl ProtocolConfig {
/// Protocol configuration validator used by the checker (i.e. the verifier) to
/// perform a compatibility check with the peer's (i.e. the prover's) configuration.
#[derive(derive_builder::Builder, Clone, Debug)]
#[derive(derive_builder::Builder, Clone, Debug, Serialize, Deserialize)]
pub struct ProtocolConfigValidator {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
@@ -190,22 +192,22 @@ impl ProtocolConfigValidator {
max_sent_records: Option<usize>,
max_recv_records_online: Option<usize>,
) -> Result<(), ProtocolConfigError> {
if let Some(max_sent_records) = max_sent_records {
if max_sent_records > self.max_sent_records {
return Err(ProtocolConfigError::max_record_count(format!(
"max_sent_records {} is greater than the configured limit {}",
max_sent_records, self.max_sent_records,
)));
}
if let Some(max_sent_records) = max_sent_records
&& max_sent_records > self.max_sent_records
{
return Err(ProtocolConfigError::max_record_count(format!(
"max_sent_records {} is greater than the configured limit {}",
max_sent_records, self.max_sent_records,
)));
}
if let Some(max_recv_records_online) = max_recv_records_online {
if max_recv_records_online > self.max_recv_records_online {
return Err(ProtocolConfigError::max_record_count(format!(
"max_recv_records_online {} is greater than the configured limit {}",
max_recv_records_online, self.max_recv_records_online,
)));
}
if let Some(max_recv_records_online) = max_recv_records_online
&& max_recv_records_online > self.max_recv_records_online
{
return Err(ProtocolConfigError::max_record_count(format!(
"max_recv_records_online {} is greater than the configured limit {}",
max_recv_records_online, self.max_recv_records_online,
)));
}
Ok(())
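// Illustrative usage sketch (not part of this diff), mirroring the integration
// tests: a verifier-side validator capping the prover's sent/received data.
// The 1 KiB / 4 KiB limits are arbitrary example values.
fn validator_sketch() -> ProtocolConfigValidator {
    ProtocolConfigValidator::builder()
        .max_sent_data(1 << 10)
        .max_recv_data(1 << 12)
        .build()
        .unwrap()
}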
@@ -231,15 +233,17 @@ impl ProtocolConfigValidator {
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
/// Prefers a bandwidth-heavy protocol.
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Prefers a latency-heavy protocol.
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
Latency,
}
impl Default for NetworkSetting {
fn default() -> Self {
Self::Bandwidth
Self::Latency
}
}
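// Illustrative sketch (not part of this diff): the new default favors lower
// data transfer at the cost of extra round-trips.
fn network_setting_sketch() {
    match NetworkSetting::default() {
        NetworkSetting::Latency => println!("fewer bytes on the wire, more round-trips"),
        NetworkSetting::Bandwidth => println!("fewer round-trips, more bytes on the wire"),
    }
}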

View File

@@ -1,238 +0,0 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_common::Context;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::HashAlgorithm,
transcript::{
Direction, Idx,
encoding::{
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
EncodingTree, EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Number of encoding bytes per plaintext byte.
const ENCODING_SIZE: usize = 128;
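// Illustrative sketch (not part of this diff): 128 bytes per plaintext byte,
// assuming one 16-byte label per bit (8 bits * 16 bytes = 128).
const _: () = assert!(ENCODING_SIZE == 8 * 16);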
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
sent: Vec<u8>,
recv: Vec<u8>,
}
/// Transfers the encodings using the provided seed and keys.
///
/// The keys must be consistent with the global delta used in the encodings.
pub(crate) async fn transfer<'a>(
ctx: &mut Context,
refs: &TranscriptRefs,
delta: &Delta,
f: impl Fn(Vector<U8>) -> &'a [Key],
) -> Result<EncodingCommitment, EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
let encoder = new_encoder(&secret);
let sent_keys: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
let recv_keys: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
let mut sent_encoding = Vec::with_capacity(sent_keys.len());
let mut recv_encoding = Vec::with_capacity(recv_keys.len());
encoder.encode_range(
Direction::Sent,
0..sent_keys.len() / ENCODING_SIZE,
&mut sent_encoding,
);
encoder.encode_range(
Direction::Received,
0..recv_keys.len() / ENCODING_SIZE,
&mut recv_encoding,
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
ctx.io_mut()
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
let root = ctx.io_mut().expect_next().await?;
ctx.io_mut().send(secret.clone()).await?;
Ok(EncodingCommitment {
root,
secret: secret.clone(),
})
}
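// Toy sketch (not part of this diff) of why XOR-masking with keys works,
// assuming correlated MACs of the form mac = key ^ (bit * delta): the receiver
// unmasks with its MAC and ends up with the encoding of its actual bit.
fn masking_sketch() {
    let (delta, key, enc0) = (0b1010_1010u8, 0b0011_0101u8, 0b1100_0011u8);
    for bit in [0u8, 1u8] {
        let mac = key ^ (bit * delta); // correlated OT relation (assumed)
        let masked = enc0 ^ key; // what the sender transmits
        let unmasked = masked ^ mac; // what the receiver recovers
        assert_eq!(unmasked, enc0 ^ bit * delta); // encoding of the actual bit
    }
}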
/// Receives the encodings using the provided MACs.
///
/// The MACs must be consistent with the global delta used in the encodings.
pub(crate) async fn receive<'a>(
ctx: &mut Context,
hasher: &(dyn HashAlgorithm + Send + Sync),
refs: &TranscriptRefs,
f: impl Fn(Vector<U8>) -> &'a [Mac],
idxs: impl IntoIterator<Item = &(Direction, Idx)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
let Encodings { mut sent, mut recv } = ctx.io_mut().expect_next().await?;
let sent_macs: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_macs.len(),
got: sent.len(),
}
.into());
}
if recv.len() != recv_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_macs.len(),
got: recv.len(),
}
.into());
}
sent.iter_mut()
.zip(sent_macs)
.for_each(|(enc, mac)| *enc ^= mac);
recv.iter_mut()
.zip(recv_macs)
.for_each(|(enc, mac)| *enc ^= mac);
let provider = Provider { sent, recv };
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
ctx.io_mut().send(root.clone()).await?;
let secret = ctx.io_mut().expect_next().await?;
let commitment = EncodingCommitment { root, secret };
Ok((commitment, tree))
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
recv: Vec<u8>,
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
if end > encodings.len() {
return Err(EncodingProviderError);
}
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
IncorrectMacCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
}
impl From<std::io::Error> for EncodingError {
fn from(value: std::io::Error) -> Self {
Self(ErrorRepr::Io(value))
}
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}

View File

@@ -7,7 +7,6 @@
pub(crate) mod commit;
pub mod config;
pub(crate) mod context;
pub(crate) mod encoding;
pub(crate) mod ghash;
pub(crate) mod msg;
pub(crate) mod mux;

View File

@@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize};
use tlsn_core::connection::{ServerCertData, ServerName};
use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
#[derive(Debug, Serialize, Deserialize)]
@@ -10,5 +10,5 @@ pub(crate) struct ServerIdentityProof {
/// Server name.
pub name: ServerName,
/// Server identity data.
pub data: ServerCertData,
pub data: HandshakeData,
}

View File

@@ -10,48 +10,39 @@ pub use error::ProverError;
pub use future::ProverFuture;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use rand::Rng;
use rustls_pki_types::CertificateDer;
use serio::SinkExt;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
commit::{
commit_records,
hash::prove_hash,
transcript::{TranscriptRefs, decode_transcript},
},
commit::{ProvingState, TranscriptRefs},
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_attestation::{
Attestation, CryptoProvider, Secrets,
request::{Request, RequestConfig},
};
use tlsn_core::{
ProvePayload,
connection::ServerCertData,
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
@@ -174,8 +165,8 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
..
@@ -183,21 +174,40 @@ impl Prover<state::Setup> {
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = self.config.server_name();
let server_name =
TlsServerName::try_from(self.config.server_name().as_str()).map_err(|_| {
ProverError::config(format!(
"invalid server name: {}",
self.config.server_name()
))
})?;
TlsServerName::try_from(server_name.as_ref()).expect("name was validated");
let root_store = if let Some(root_store) = self.config.tls_config().root_store() {
let roots = root_store
.roots
.iter()
.map(|cert| {
let der = CertificateDer::from_slice(&cert.0);
anchor_from_trusted_cert(&der)
.map(|anchor| anchor.to_owned())
.map_err(ProverError::config)
})
.collect::<Result<Vec<_>, _>>()?;
tls_client::RootCertStore { roots }
} else {
tls_client::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
};
let config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(self.config.tls_config().root_store().clone());
.with_root_certificates(root_store);
let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
config
.with_single_cert(cert.clone(), key.clone())
.with_single_cert(
cert.iter()
.map(|cert| tls_client::Certificate(cert.0.clone()))
.collect(),
tls_client::PrivateKey(key.0.clone()),
)
.map_err(ProverError::config)?
} else {
config.with_no_client_auth()
@@ -263,28 +273,6 @@ impl Prover<state::Setup> {
)
.map_err(ProverError::zk)?;
// Prove received plaintext. Prover drops the proof output, as
// they trust themselves.
let (sent_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
let (recv_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
@@ -292,7 +280,9 @@ impl Prover<state::Setup> {
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let (sent_len, recv_len) = transcript.len();
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Prover {
config: self.config,
@@ -305,6 +295,10 @@ impl Prover<state::Setup> {
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred: false,
},
})
}
@@ -338,236 +332,57 @@ impl Prover<state::Committed> {
///
/// * `config` - The disclosure configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
pub async fn prove(&mut self, config: ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed {
mux_fut,
ctx,
vm,
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred,
..
} = &mut self.state;
let mut output = ProverOutput {
transcript_commitments: Vec::new(),
transcript_secrets: Vec::new(),
// Create and send prove payload.
let server_name = self.config.server_name();
let handshake = config
.server_identity()
.then(|| (server_name.clone(), HandshakeData::new(tls_transcript)));
let partial = if let Some((reveal_sent, reveal_recv)) = config.reveal() {
Some(transcript.to_partial(reveal_sent.clone(), reveal_recv.clone()))
} else {
None
};
let payload = ProvePayload {
server_identity: config.server_identity().then(|| {
(
self.config.server_name().clone(),
ServerCertData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
handshake: tls_transcript.handshake_data().clone(),
},
)
}),
transcript: config.transcript().cloned(),
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
let payload = ProveRequest::new(&config, partial, handshake);
// Send payload.
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
.await?;
if let Some(partial_transcript) = config.transcript() {
decode_transcript(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(ProverError::zk)?;
}
let mut hash_commitments = None;
if let Some(commit_config) = config.transcript_commit() {
if commit_config.has_encoding() {
let hasher: &(dyn HashAlgorithm + Send + Sync) =
match *commit_config.encoding_hash_alg() {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ProverError::config(format!(
"unsupported hash algorithm for encoding commitment: {alg}"
)));
}
};
let (commitment, tree) = mux_fut
.poll_with(
encoding::receive(
ctx,
hasher,
transcript_refs,
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
commit_config.iter_encoding(),
)
.map_err(ProverError::commit),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if commit_config.has_hash() {
hash_commitments = Some(
prove_hash(
vm,
transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
.await?;
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
Ok(output)
}
/// Requests an attestation from the verifier.
///
/// # Arguments
///
/// * `config` - The attestation request configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize(
&mut self,
config: &RequestConfig,
) -> Result<(Attestation, Secrets), ProverError> {
#[allow(deprecated)]
self.notarize_with_provider(config, &CryptoProvider::default())
.await
}
/// Requests an attestation from the verifier.
///
/// # Arguments
///
/// * `config` - The attestation request configuration.
/// * `provider` - Cryptography provider.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize_with_provider(
&mut self,
config: &RequestConfig,
provider: &CryptoProvider,
) -> Result<(Attestation, Secrets), ProverError> {
let mut builder = ProveConfig::builder(self.transcript());
if let Some(config) = config.transcript_commit() {
// Temporarily, we reject attestation requests which contain hash commitments to
// subsets of the transcript. We do this because we want to preserve the
// obliviousness of the reference notary, and hash commitments currently leak
// the ranges which are being committed.
for ((direction, idx), _) in config.iter_hash() {
let len = match direction {
Direction::Sent => self.transcript().sent().len(),
Direction::Received => self.transcript().received().len(),
};
if idx.start() > 0 || idx.end() < len || idx.count() != 1 {
return Err(ProverError::attestation(
"hash commitments to subsets of the transcript are currently not supported in attestation requests",
));
}
}
builder.transcript_commit(config.clone());
}
let disclosure_config = builder.build().map_err(ProverError::attestation)?;
let ProverOutput {
transcript_commitments,
transcript_secrets,
..
} = self.prove(&disclosure_config).await?;
let state::Committed {
mux_fut,
ctx,
let proving_state = ProvingState::for_prover(
config,
tls_transcript,
transcript,
..
} = &mut self.state;
transcript_refs,
*encodings_transferred,
);
let mut builder = Request::builder(config);
builder
.server_name(self.config.server_name().clone())
.server_cert_data(ServerCertData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
handshake: tls_transcript.handshake_data().clone(),
})
.transcript(transcript.clone())
.transcript_commitments(transcript_secrets, transcript_commitments);
let (request, secrets) = builder.build(provider).map_err(ProverError::attestation)?;
let attestation = mux_fut
.poll_with(async {
debug!("sending attestation request");
ctx.io_mut().send(request.clone()).await?;
let attestation: Attestation = ctx.io_mut().expect_next().await?;
Ok::<_, ProverError>(attestation)
})
let (output, encodings_executed) = mux_fut
.poll_with(
proving_state
.prove(vm, ctx, zk_aes_ctr_sent, zk_aes_ctr_recv, keys.clone())
.map_err(ProverError::from),
)
.await?;
// Check the attestation is consistent with the Prover's view.
request
.validate(&attestation)
.map_err(ProverError::attestation)?;
Ok((attestation, secrets))
*encodings_transferred = encodings_executed;
Ok(output)
}
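// Illustrative usage sketch (not part of this diff): `prove` now takes the
// config by value and may be called more than once on a committed prover.
// `prover` is a hypothetical `Prover<state::Committed>`.
async fn prove_sketch(prover: &mut Prover<state::Committed>) -> Result<ProverOutput, ProverError> {
    // An empty config discloses nothing; disclosure options can be added on the
    // builder before `build`.
    let config = ProveConfig::builder(prover.transcript())
        .build()
        .expect("default prove config is valid");
    prover.prove(config).await
}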
/// Closes the connection with the verifier.
@@ -618,7 +433,7 @@ fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk
delta,
);
let zk = Zk::new(rcot_recv.clone());
let zk = Zk::new(ZkProverConfig::default(), rcot_recv.clone());
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));

View File

@@ -1,14 +1,14 @@
use crate::config::{NetworkSetting, ProtocolConfig};
use mpc_tls::Config;
use rustls_pki_types::{CertificateDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, pem::PemObject};
use tls_core::{
anchors::{OwnedTrustAnchor, RootCertStore},
key,
use serde::{Deserialize, Serialize};
use tlsn_core::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tlsn_core::connection::ServerName;
use crate::config::{NetworkSetting, ProtocolConfig};
/// Configuration for the prover.
#[derive(Debug, Clone, derive_builder::Builder)]
#[derive(Debug, Clone, derive_builder::Builder, Serialize, Deserialize)]
pub struct ProverConfig {
/// The server DNS name.
#[builder(setter(into))]
@@ -67,31 +67,13 @@ impl ProverConfig {
}
/// Configuration for the prover's TLS connection.
#[derive(Debug, Clone)]
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
/// Root certificates.
root_store: RootCertStore,
root_store: Option<RootCertStore>,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
}
impl Default for TlsConfig {
fn default() -> Self {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
Self {
root_store,
client_auth: None,
}
}
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfig {
@@ -100,13 +82,13 @@ impl TlsConfig {
TlsConfigBuilder::default()
}
pub(crate) fn root_store(&self) -> &RootCertStore {
&self.root_store
pub(crate) fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> &Option<(Vec<key::Certificate>, key::PrivateKey)> {
pub fn client_auth(&self) -> &Option<(Vec<CertificateDer>, PrivateKeyDer)> {
&self.client_auth
}
}
@@ -115,7 +97,7 @@ impl TlsConfig {
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<key::Certificate>, key::PrivateKey)>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
@@ -138,74 +120,16 @@ impl TlsConfigBuilder {
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth(&mut self, cert_key: (Vec<Vec<u8>>, Vec<u8>)) -> &mut Self {
let certs = cert_key
.0
.into_iter()
.map(key::Certificate)
.collect::<Vec<_>>();
self.client_auth = Some((certs, key::PrivateKey(cert_key.1)));
pub fn client_auth(&mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> &mut Self {
self.client_auth = Some(cert_key);
self
}
/// Sets a PEM-encoded certificate chain and a matching private key for
/// client authentication.
///
/// Often the chain will consist of a single end-entity certificate.
///
/// # Arguments
///
/// * `cert_key` - A tuple containing the certificate chain and the private
/// key.
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth_pem(
&mut self,
cert_key: (Vec<Vec<u8>>, Vec<u8>),
) -> Result<&mut Self, TlsConfigError> {
let key = match PrivatePkcs8KeyDer::from_pem_slice(&cert_key.1) {
// Try to parse as PEM PKCS#8.
Ok(key) => (*key.secret_pkcs8_der()).to_vec(),
// Otherwise, try to parse as PEM PKCS#1.
Err(_) => match PrivatePkcs1KeyDer::from_pem_slice(&cert_key.1) {
Ok(key) => (*key.secret_pkcs1_der()).to_vec(),
Err(_) => return Err(ErrorRepr::InvalidKey.into()),
},
};
let certs = cert_key
.0
.iter()
.map(|c| {
let c =
CertificateDer::from_pem_slice(c).map_err(|_| ErrorRepr::InvalidCertificate)?;
Ok::<key::Certificate, TlsConfigError>(key::Certificate(c.as_ref().to_vec()))
})
.collect::<Result<Vec<_>, _>>()?;
self.client_auth = Some((certs, key::PrivateKey(key)));
Ok(self)
}
/// Builds the TLS configuration.
pub fn build(&self) -> Result<TlsConfig, TlsConfigError> {
pub fn build(self) -> Result<TlsConfig, TlsConfigError> {
Ok(TlsConfig {
root_store: self.root_store.clone().unwrap_or_else(|| {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
},
));
root_store
}),
client_auth: self.client_auth.clone(),
root_store: self.root_store,
client_auth: self.client_auth,
})
}
}
@@ -216,10 +140,5 @@ impl TlsConfigBuilder {
pub struct TlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("tls config error: {0}")]
enum ErrorRepr {
#[error("the certificate for client authentication is invalid")]
InvalidCertificate,
#[error("the private key for client authentication is invalid")]
InvalidKey,
}
#[error("tls config error")]
enum ErrorRepr {}
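// Illustrative usage sketch (not part of this diff): root certificates are now
// optional and supplied as DER-encoded certificates. `ca_cert_der` is a
// hypothetical DER-encoded CA certificate; omitting `root_store` falls back to
// the webpki roots at connection time.
fn tls_config_sketch(ca_cert_der: Vec<u8>) -> TlsConfig {
    let mut builder = TlsConfig::builder();
    builder.root_store(RootCertStore {
        roots: vec![CertificateDer(ca_cert_der)],
    });
    builder.build().unwrap()
}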

View File

@@ -1,8 +1,6 @@
use std::{error::Error, fmt};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use std::{error::Error, fmt};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -42,20 +40,6 @@ impl ProverError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn commit<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn attestation<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Attestation, source)
}
}
#[derive(Debug)]
@@ -65,7 +49,6 @@ enum ErrorKind {
Zk,
Config,
Commit,
Attestation,
}
impl fmt::Display for ProverError {
@@ -78,7 +61,6 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Attestation => f.write_str("attestation error")?,
}
if let Some(source) = &self.source {
@@ -125,8 +107,8 @@ impl From<ZkAesCtrError> for ProverError {
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for ProverError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -9,7 +9,7 @@ use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
zk_aes_ctr::ZkAesCtr,
@@ -42,6 +42,10 @@ pub struct Committed {
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -1,48 +1,45 @@
//! Verifier.
pub(crate) mod config;
mod config;
mod error;
pub mod state;
use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError};
use crate::{
Role,
commit::{
commit_records,
hash::verify_hash,
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
},
config::ProtocolConfig,
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
pub use tlsn_core::{
VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError,
webpki::ServerCertVerifier,
};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use serio::{SinkExt, stream::IoStreamExt};
use tls_core::{msgs::enums::ContentType, verify::WebPkiVerifier};
use tlsn_attestation::{Attestation, AttestationConfig, CryptoProvider, request::Request};
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use tlsn_core::{
ProvePayload,
connection::{ConnectionInfo, ServerName, TranscriptLength},
transcript::{TlsTranscript, TranscriptCommitment},
ProveRequest,
connection::{ConnectionInfo, ServerName},
transcript::{ContentType, TlsTranscript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument};
use crate::{
Role,
commit::{ProvingState, TranscriptRefs},
config::ProtocolConfig,
context::build_mt_context,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
@@ -152,59 +149,6 @@ impl Verifier<state::Initialized> {
})
}
/// Runs the verifier to completion and attests to the TLS session.
///
/// This is a convenience method which runs all the steps needed for
/// notarization.
///
/// # Arguments
///
/// * `socket` - The socket to the prover.
/// * `config` - The attestation configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
config: &AttestationConfig,
) -> Result<Attestation, VerifierError> {
#[allow(deprecated)]
self.notarize_with_provider(socket, config, &CryptoProvider::default())
.await
}
/// Runs the verifier to completion and attests to the TLS session.
///
/// This is a convenience method which runs all the steps needed for
/// notarization.
///
/// # Arguments
///
/// * `socket` - The socket to the prover.
/// * `config` - The attestation configuration.
/// * `provider` - Cryptography provider.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize_with_provider<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
config: &AttestationConfig,
provider: &CryptoProvider,
) -> Result<Attestation, VerifierError> {
let mut verifier = self.setup(socket).await?.run().await?;
#[allow(deprecated)]
let attestation = verifier.notarize_with_provider(config, provider).await?;
verifier.close().await?;
Ok(attestation)
}
/// Runs the TLS verifier to completion, verifying the TLS session.
///
/// This is a convenience method which runs all the steps needed for
@@ -238,8 +182,8 @@ impl Verifier<state::Setup> {
mut mux_fut,
delta,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
vm,
keys,
} = self.state;
@@ -280,27 +224,6 @@ impl Verifier<state::Setup> {
)
.map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let (sent_refs, sent_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
@@ -310,11 +233,30 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
sent_proof.verify().map_err(VerifierError::zk)?;
recv_proof.verify().map_err(VerifierError::zk)?;
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Verifier {
config: self.config,
@@ -327,6 +269,11 @@ impl Verifier<state::Setup> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name: None,
encodings_transferred: false,
},
})
}
@@ -355,238 +302,42 @@ impl Verifier<state::Committed> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name,
encodings_transferred,
..
} = &mut self.state;
let ProvePayload {
server_identity,
transcript,
transcript_commit,
} = mux_fut
let payload: ProveRequest = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = WebPkiVerifier::new(self.config.root_store().clone(), None);
let server_name = if let Some((name, cert_data)) = server_identity {
cert_data
.verify(
&verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
if let Some(partial_transcript) = &transcript {
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Check ranges.
if partial_transcript.len_sent() != sent_len
|| partial_transcript.len_received() != recv_len
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
decode_transcript(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(VerifierError::zk)?;
}
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit {
if commit_config.encoding() {
let commitment = mux_fut
.poll_with(encoding::transfer(
ctx,
transcript_refs,
delta,
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
))
.await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if commit_config.has_hash() {
hash_commitments = Some(
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
.await?;
// Verify revealed data.
if let Some(partial_transcript) = &transcript {
verify_transcript(vm, partial_transcript, transcript_refs)
.map_err(VerifierError::verify)?;
}
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript,
transcript_commitments,
})
}
/// Attests to the TLS session.
///
/// # Arguments
///
/// * `config` - Attestation configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize(
&mut self,
config: &AttestationConfig,
) -> Result<Attestation, VerifierError> {
#[allow(deprecated)]
self.notarize_with_provider(config, &CryptoProvider::default())
.await
}
/// Attests to the TLS session.
///
/// # Arguments
///
/// * `config` - Attestation configuration.
/// * `provider` - Cryptography provider.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
#[deprecated(
note = "attestation functionality will be removed from this API in future releases."
)]
pub async fn notarize_with_provider(
&mut self,
config: &AttestationConfig,
provider: &CryptoProvider,
) -> Result<Attestation, VerifierError> {
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
} = self.verify(&VerifyConfig::default()).await?;
if server_name.is_some() {
return Err(VerifierError::attestation(
"server name can not be revealed to a verifier",
));
} else if transcript.is_some() {
return Err(VerifierError::attestation(
"transcript data can not be revealed to a verifier",
));
}
let state::Committed {
mux_fut,
ctx,
let proving_state = ProvingState::for_verifier(
payload,
tls_transcript,
..
} = &mut self.state;
transcript_refs,
verified_server_name.clone(),
*encodings_transferred,
);
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let request: Request = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
let (output, encodings_executed) = mux_fut
.poll_with(proving_state.verify(
vm,
ctx,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys.clone(),
*delta,
self.config.root_store(),
))
.await?;
let mut builder = Attestation::builder(config)
.accept_request(request)
.map_err(VerifierError::attestation)?;
*verified_server_name = output.server_name.clone();
*encodings_transferred = encodings_executed;
builder
.connection_info(ConnectionInfo {
time: tls_transcript.time(),
version: (*tls_transcript.version()),
transcript_length: TranscriptLength {
sent: sent_len as u32,
received: recv_len as u32,
},
})
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments);
let attestation = builder
.build(provider)
.map_err(VerifierError::attestation)?;
mux_fut
.poll_with(
ctx.io_mut()
.send(attestation.clone())
.map_err(VerifierError::from),
)
.await?;
info!("Sent attestation");
Ok(attestation)
Ok(output)
}
/// Closes the connection with the prover.
@@ -637,7 +388,7 @@ fn build_mpc_tls(
let mpc = Mpc::new(mpz_ot::cot::DerandCOTReceiver::new(rcot_recv.clone()));
let zk = Zk::new(delta, rcot_send.clone());
let zk = Zk::new(ZkVerifierConfig::default(), delta, rcot_send.clone());
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));

View File

@@ -1,17 +1,19 @@
use std::fmt::{Debug, Formatter, Result};
use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
use mpc_tls::Config;
use tls_core::anchors::{OwnedTrustAnchor, RootCertStore};
use serde::{Deserialize, Serialize};
use tlsn_core::webpki::RootCertStore;
use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
/// Configuration for the [`Verifier`](crate::tls::Verifier).
#[allow(missing_docs)]
#[derive(derive_builder::Builder)]
#[derive(derive_builder::Builder, Serialize, Deserialize)]
#[builder(pattern = "owned")]
pub struct VerifierConfig {
protocol_config_validator: ProtocolConfigValidator,
#[builder(default = "default_root_store()")]
root_store: RootCertStore,
#[builder(setter(strip_option))]
root_store: Option<RootCertStore>,
}
impl Debug for VerifierConfig {
@@ -34,8 +36,8 @@ impl VerifierConfig {
}
/// Returns the root certificate store.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
pub fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
pub(crate) fn build_mpc_tls_config(&self, protocol_config: &ProtocolConfig) -> Config {
@@ -61,16 +63,3 @@ impl VerifierConfig {
builder.build().unwrap()
}
}
fn default_root_store() -> RootCertStore {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject.as_ref(),
ta.subject_public_key_info.as_ref(),
ta.name_constraints.as_ref().map(|nc| nc.as_ref()),
)
}));
root_store
}
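// Illustrative usage sketch (not part of this diff), mirroring the integration
// tests: supply an explicit root store (the builder is owned, so calls chain by
// value). `ca_cert_der` is a hypothetical DER-encoded CA certificate.
fn verifier_config_sketch(
    ca_cert_der: Vec<u8>,
    validator: ProtocolConfigValidator,
) -> VerifierConfig {
    VerifierConfig::builder()
        .root_store(RootCertStore {
            roots: vec![tlsn_core::webpki::CertificateDer(ca_cert_der)],
        })
        .protocol_config_validator(validator)
        .build()
        .unwrap()
}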

View File

@@ -1,4 +1,4 @@
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
@@ -33,20 +33,6 @@ impl VerifierError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn attestation<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Attestation, source)
}
pub(crate) fn verify<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Verify, source)
}
}
#[derive(Debug)]
@@ -56,8 +42,6 @@ enum ErrorKind {
Mpc,
Zk,
Commit,
Attestation,
Verify,
}
impl fmt::Display for VerifierError {
@@ -70,8 +54,6 @@ impl fmt::Display for VerifierError {
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Attestation => f.write_str("attestation error")?,
ErrorKind::Verify => f.write_str("verification error")?,
}
if let Some(source) = &self.source {
@@ -118,8 +100,8 @@ impl From<ZkAesCtrError> for VerifierError {
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for VerifierError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -3,14 +3,14 @@
use std::sync::Arc;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
zk_aes_ctr::ZkAesCtr,
};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_memory_core::correlated::Delta;
use tlsn_core::transcript::TlsTranscript;
use tlsn_core::{connection::ServerName, transcript::TlsTranscript};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -45,6 +45,11 @@ pub struct Committed {
pub(crate) vm: Zk,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) verified_server_name: Option<ServerName>,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -37,8 +37,8 @@ impl ZkAesCtr {
}
/// Returns the role.
pub(crate) fn role(&self) -> &Role {
&self.role
pub(crate) fn role(&self) -> Role {
self.role
}
/// Allocates `len` bytes for encryption.

View File

@@ -1,6 +1,7 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use tlsn::{
config::{ProtocolConfig, ProtocolConfigValidator},
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{TranscriptCommitConfig, TranscriptCommitment},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
@@ -37,19 +38,17 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let server_task = tokio::spawn(bind(server_socket.compat()));
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(root_store);
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(SERVER_DOMAIN)
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
@@ -104,17 +103,12 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(socket: T) {
let mut root_store = tls_core::anchors::RootCertStore::empty();
root_store
.add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
.unwrap();
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
@@ -123,7 +117,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(root_store)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap(),
@@ -140,7 +136,9 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(soc
let transcript = transcript.unwrap();
assert_eq!(server_name.unwrap().as_str(), SERVER_DOMAIN);
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!transcript.is_complete());
assert_eq!(
transcript.sent_authed().iter_ranges().next().unwrap(),

Some files were not shown because too many files have changed in this diff.