Compare commits

...

22 Commits

Author SHA1 Message Date
Hendrik Eeckhaut
b76775fc7c correction + legend placement 2025-12-23 15:11:35 +01:00
Hendrik Eeckhaut
72041d1f07 export dark svg 2025-12-23 14:47:19 +01:00
Hendrik Eeckhaut
ac1df8fc75 Allow plotting multiple data runs 2025-12-23 14:31:54 +01:00
Hendrik Eeckhaut
3cb7c5c0b4 Working on benchmark plots 2025-12-23 14:07:39 +01:00
Hendrik Eeckhaut
b41d678829 build: update Rust to version 1.92.0 2025-12-16 09:36:11 +01:00
sinu.eth
1ebefa27d8 perf(core): fold instead of flatten (#1064) 2025-12-11 06:41:26 -08:00
dan
4fe5c1defd feat(harness): add reveal_all config (#1063) 2025-12-09 12:01:39 +00:00
dan
0e8e547300 chore: adapt for rangeset 0.4 (#1058) 2025-12-09 11:36:13 +00:00
dan
22cc88907a chore: bump mpz (#1057) 2025-12-04 10:27:43 +00:00
Hendrik Eeckhaut
cec4756e0e ci: set GITHUB_TOKEN env 2025-11-28 14:33:19 +01:00
Hendrik Eeckhaut
0919e1f2b3 clippy: allow deprecated aead::generic_array 2025-11-28 14:33:19 +01:00
Hendrik Eeckhaut
43b9f57e1f build: update Rust to version 1.91.1 2025-11-26 16:53:08 +01:00
dan
c51331d63d test: use ideal vm for testing (#1049) 2025-11-07 12:56:09 +00:00
dan
3905d9351c chore: clean up deps (#1048) 2025-11-07 10:36:41 +00:00
dan
f8a67bc8e7 feat(core): support proving keccak256 commitments (#1046) 2025-11-07 09:18:44 +00:00
dan
952a7011bf feat(cipher): use AES pre/post key schedule circuits (#1042) 2025-11-07 09:08:08 +00:00
Ram
0673818e4e chore: fix links to key exchange doc page (#1045) 2025-11-04 23:07:58 +01:00
dan
a5749d81f1 fix(attestation): verify sig during validation (#1037) 2025-10-30 07:59:57 +00:00
sinu.eth
f2e119bb66 refactor: move and rewrite configuration (#1034)
* refactor: move and rewrite configuration

* fix wasm
2025-10-27 11:47:42 -07:00
Hendrik Eeckhaut
271ac3771e Fix example (#1033)
* fix: provide encoder secret to attestation
* Add missing entry in example's README file
2025-10-24 10:33:32 +02:00
Benjamin Martinez Picech
f69dd7a239 refactor(tlsn-core): redeclaration of content type into core (#1026)
* redeclaration of content type into core

* fix compilation error

* comment removal

* Lint and format fixes

* fix wasm build

* Unknown content type

* format fix
2025-10-23 15:47:53 +02:00
sinu.eth
79f5160cae feat(tlsn): insecure mode (#1031) 2025-10-22 10:18:11 -07:00
99 changed files with 5138 additions and 2241 deletions

View File

@@ -21,7 +21,8 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178 # - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future # 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32 RAYON_NUM_THREADS: 32
RUST_VERSION: 1.90.0 RUST_VERSION: 1.92.0
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs: jobs:
clippy: clippy:

2561
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,26 +66,27 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" } tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" } tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" } mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
rangeset = { version = "0.2" } rangeset = { version = "0.4" }
serio = { version = "0.2" } serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
uid-mux = { version = "0.2" } uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
aead = { version = "0.4" } aead = { version = "0.4" }
aes = { version = "0.8" } aes = { version = "0.8" }

View File

@@ -9,7 +9,7 @@ fixtures = ["tlsn-core/fixtures", "dep:tlsn-data-fixtures"]
[dependencies] [dependencies]
tlsn-tls-core = { workspace = true } tlsn-tls-core = { workspace = true }
tlsn-core = { workspace = true } tlsn-core = { workspace = true, features = ["mozilla-certs"] }
tlsn-data-fixtures = { workspace = true, optional = true } tlsn-data-fixtures = { workspace = true, optional = true }
bcs = { workspace = true } bcs = { workspace = true }

View File

@@ -1,5 +1,4 @@
//! Attestation fixtures. //! Attestation fixtures.
use tlsn_core::{ use tlsn_core::{
connection::{CertBinding, CertBindingV1_2}, connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture, fixtures::ConnectionFixture,
@@ -13,7 +12,10 @@ use tlsn_core::{
use crate::{ use crate::{
Attestation, AttestationConfig, CryptoProvider, Extension, Attestation, AttestationConfig, CryptoProvider, Extension,
request::{Request, RequestConfig}, request::{Request, RequestConfig},
signing::SignatureAlgId, signing::{
KeyAlgId, SignatureAlgId, SignatureVerifier, SignatureVerifierProvider, Signer,
SignerProvider,
},
}; };
/// A Request fixture used for testing. /// A Request fixture used for testing.
@@ -102,7 +104,8 @@ pub fn attestation_fixture(
let mut provider = CryptoProvider::default(); let mut provider = CryptoProvider::default();
match signature_alg { match signature_alg {
SignatureAlgId::SECP256K1 => provider.signer.set_secp256k1(&[42u8; 32]).unwrap(), SignatureAlgId::SECP256K1 => provider.signer.set_secp256k1(&[42u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[42u8; 32]).unwrap(), SignatureAlgId::SECP256K1ETH => provider.signer.set_secp256k1eth(&[43u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[44u8; 32]).unwrap(),
_ => unimplemented!(), _ => unimplemented!(),
}; };
@@ -122,3 +125,68 @@ pub fn attestation_fixture(
attestation_builder.build(&provider).unwrap() attestation_builder.build(&provider).unwrap()
} }
/// Returns a crypto provider which supports only a custom signature alg.
pub fn custom_provider_fixture() -> CryptoProvider {
const CUSTOM_SIG_ALG_ID: SignatureAlgId = SignatureAlgId::new(128);
// A dummy signer.
struct DummySigner {}
impl Signer for DummySigner {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn sign(
&self,
msg: &[u8],
) -> Result<crate::signing::Signature, crate::signing::SignatureError> {
Ok(crate::signing::Signature {
alg: CUSTOM_SIG_ALG_ID,
data: msg.to_vec(),
})
}
fn verifying_key(&self) -> crate::signing::VerifyingKey {
crate::signing::VerifyingKey {
alg: KeyAlgId::new(128),
data: vec![1, 2, 3, 4],
}
}
}
// A dummy verifier.
struct DummyVerifier {}
impl SignatureVerifier for DummyVerifier {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn verify(
&self,
_key: &crate::signing::VerifyingKey,
msg: &[u8],
sig: &[u8],
) -> Result<(), crate::signing::SignatureError> {
if msg == sig {
Ok(())
} else {
Err(crate::signing::SignatureError::from_str(
"invalid signature",
))
}
}
}
let mut provider = CryptoProvider::default();
let mut signer_provider = SignerProvider::default();
signer_provider.set_signer(Box::new(DummySigner {}));
provider.signer = signer_provider;
let mut verifier_provider = SignatureVerifierProvider::empty();
verifier_provider.set_verifier(Box::new(DummyVerifier {}));
provider.signature = verifier_provider;
provider
}

View File

@@ -20,7 +20,10 @@ use serde::{Deserialize, Serialize};
use tlsn_core::hash::HashAlgId; use tlsn_core::hash::HashAlgId;
use crate::{Attestation, Extension, connection::ServerCertCommitment, signing::SignatureAlgId}; use crate::{
Attestation, CryptoProvider, Extension, connection::ServerCertCommitment,
serialize::CanonicalSerialize, signing::SignatureAlgId,
};
pub use builder::{RequestBuilder, RequestBuilderError}; pub use builder::{RequestBuilder, RequestBuilderError};
pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError}; pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError};
@@ -41,44 +44,102 @@ impl Request {
} }
/// Validates the content of the attestation against this request. /// Validates the content of the attestation against this request.
pub fn validate(&self, attestation: &Attestation) -> Result<(), InconsistentAttestation> { pub fn validate(
&self,
attestation: &Attestation,
provider: &CryptoProvider,
) -> Result<(), AttestationValidationError> {
if attestation.signature.alg != self.signature_alg { if attestation.signature.alg != self.signature_alg {
return Err(InconsistentAttestation(format!( return Err(AttestationValidationError::inconsistent(format!(
"signature algorithm: expected {:?}, got {:?}", "signature algorithm: expected {:?}, got {:?}",
self.signature_alg, attestation.signature.alg self.signature_alg, attestation.signature.alg
))); )));
} }
if attestation.header.root.alg != self.hash_alg { if attestation.header.root.alg != self.hash_alg {
return Err(InconsistentAttestation(format!( return Err(AttestationValidationError::inconsistent(format!(
"hash algorithm: expected {:?}, got {:?}", "hash algorithm: expected {:?}, got {:?}",
self.hash_alg, attestation.header.root.alg self.hash_alg, attestation.header.root.alg
))); )));
} }
if attestation.body.cert_commitment() != &self.server_cert_commitment { if attestation.body.cert_commitment() != &self.server_cert_commitment {
return Err(InconsistentAttestation( return Err(AttestationValidationError::inconsistent(
"server certificate commitment does not match".to_string(), "server certificate commitment does not match",
)); ));
} }
// TODO: improve the O(M*N) complexity of this check. // TODO: improve the O(M*N) complexity of this check.
for extension in &self.extensions { for extension in &self.extensions {
if !attestation.body.extensions().any(|e| e == extension) { if !attestation.body.extensions().any(|e| e == extension) {
return Err(InconsistentAttestation( return Err(AttestationValidationError::inconsistent(
"extension is missing from the attestation".to_string(), "extension is missing from the attestation",
)); ));
} }
} }
let verifier = provider
.signature
.get(&attestation.signature.alg)
.map_err(|_| {
AttestationValidationError::provider(format!(
"provider not configured for signature algorithm id {:?}",
attestation.signature.alg,
))
})?;
verifier
.verify(
&attestation.body.verifying_key.data,
&CanonicalSerialize::serialize(&attestation.header),
&attestation.signature.data,
)
.map_err(|_| {
AttestationValidationError::inconsistent("failed to verify the signature")
})?;
Ok(()) Ok(())
} }
} }
/// Error for [`Request::validate`]. /// Error for [`Request::validate`].
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("inconsistent attestation: {0}")] #[error("attestation validation error: {kind}: {message}")]
pub struct InconsistentAttestation(String); pub struct AttestationValidationError {
kind: ErrorKind,
message: String,
}
impl AttestationValidationError {
fn inconsistent(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Inconsistent,
message: msg.into(),
}
}
fn provider(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Provider,
message: msg.into(),
}
}
}
#[derive(Debug)]
enum ErrorKind {
Inconsistent,
Provider,
}
impl std::fmt::Display for ErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorKind::Inconsistent => write!(f, "inconsistent"),
ErrorKind::Provider => write!(f, "provider"),
}
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
@@ -93,7 +154,8 @@ mod test {
use crate::{ use crate::{
CryptoProvider, CryptoProvider,
connection::ServerCertOpening, connection::ServerCertOpening,
fixtures::{RequestFixture, attestation_fixture, request_fixture}, fixtures::{RequestFixture, attestation_fixture, custom_provider_fixture, request_fixture},
request::{AttestationValidationError, ErrorKind},
signing::SignatureAlgId, signing::SignatureAlgId,
}; };
@@ -113,7 +175,9 @@ mod test {
let attestation = let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]); attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
assert!(request.validate(&attestation).is_ok()) let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_ok())
} }
#[test] #[test]
@@ -134,7 +198,9 @@ mod test {
request.signature_alg = SignatureAlgId::SECP256R1; request.signature_alg = SignatureAlgId::SECP256R1;
let res = request.validate(&attestation); let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
assert!(res.is_err()); assert!(res.is_err());
} }
@@ -156,7 +222,9 @@ mod test {
request.hash_alg = HashAlgId::SHA256; request.hash_alg = HashAlgId::SHA256;
let res = request.validate(&attestation); let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
assert!(res.is_err()) assert!(res.is_err())
} }
@@ -184,11 +252,62 @@ mod test {
}); });
let opening = ServerCertOpening::new(server_cert_data); let opening = ServerCertOpening::new(server_cert_data);
let crypto_provider = CryptoProvider::default(); let provider = CryptoProvider::default();
request.server_cert_commitment = request.server_cert_commitment =
opening.commit(crypto_provider.hash.get(&HashAlgId::BLAKE3).unwrap()); opening.commit(provider.hash.get(&HashAlgId::BLAKE3).unwrap());
let res = request.validate(&attestation); let res = request.validate(&attestation, &provider);
assert!(res.is_err()) assert!(res.is_err())
} }
#[test]
fn test_wrong_sig() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
// Corrupt the signature.
attestation.signature.data[1] = attestation.signature.data[1].wrapping_add(1);
let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_err())
}
#[test]
fn test_wrong_provider() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
let provider = custom_provider_fixture();
assert!(matches!(
request.validate(&attestation, &provider),
Err(AttestationValidationError {
kind: ErrorKind::Provider,
..
})
))
}
} }

View File

@@ -202,6 +202,14 @@ impl SignatureVerifierProvider {
.map(|s| &**s) .map(|s| &**s)
.ok_or(UnknownSignatureAlgId(*alg)) .ok_or(UnknownSignatureAlgId(*alg))
} }
/// Returns am empty provider.
#[cfg(any(test, feature = "fixtures"))]
pub fn empty() -> Self {
Self {
verifiers: HashMap::default(),
}
}
} }
/// Signature verifier. /// Signature verifier.
@@ -229,6 +237,14 @@ impl_domain_separator!(VerifyingKey);
#[error("signature verification failed: {0}")] #[error("signature verification failed: {0}")]
pub struct SignatureError(String); pub struct SignatureError(String);
impl SignatureError {
/// Creates a new error with the given message.
#[allow(clippy::should_implement_trait)]
pub fn from_str(msg: &str) -> Self {
Self(msg.to_string())
}
}
/// A signature. /// A signature.
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Signature { pub struct Signature {

View File

@@ -101,7 +101,7 @@ fn test_api() {
let attestation = attestation_builder.build(&provider).unwrap(); let attestation = attestation_builder.build(&provider).unwrap();
// Prover validates the attestation is consistent with its request. // Prover validates the attestation is consistent with its request.
request.validate(&attestation).unwrap(); request.validate(&attestation, &provider).unwrap();
let mut transcript_proof_builder = secrets.transcript_proof_builder(); let mut transcript_proof_builder = secrets.transcript_proof_builder();

View File

@@ -15,7 +15,7 @@ workspace = true
name = "cipher" name = "cipher"
[dependencies] [dependencies]
mpz-circuits = { workspace = true } mpz-circuits = { workspace = true, features = ["aes"] }
mpz-vm-core = { workspace = true } mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true } mpz-memory-core = { workspace = true }
@@ -24,11 +24,9 @@ thiserror = { workspace = true }
aes = { workspace = true } aes = { workspace = true }
[dev-dependencies] [dev-dependencies]
mpz-garble = { workspace = true } mpz-common = { workspace = true, features = ["test-utils"] }
mpz-common = { workspace = true } mpz-ideal-vm = { workspace = true }
mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true } ctr = { workspace = true }
cipher = { workspace = true } cipher = { workspace = true }

View File

@@ -2,7 +2,7 @@
use crate::{Cipher, CtrBlock, Keystream}; use crate::{Cipher, CtrBlock, Keystream};
use async_trait::async_trait; use async_trait::async_trait;
use mpz_circuits::circuits::AES128; use mpz_circuits::{AES128_KS, AES128_POST_KS};
use mpz_memory_core::binary::{Binary, U8}; use mpz_memory_core::binary::{Binary, U8};
use mpz_vm_core::{prelude::*, Call, Vm}; use mpz_vm_core::{prelude::*, Call, Vm};
use std::fmt::Debug; use std::fmt::Debug;
@@ -12,13 +12,35 @@ mod error;
pub use error::AesError; pub use error::AesError;
use error::ErrorKind; use error::ErrorKind;
/// AES key schedule: 11 round keys, 16 bytes each.
type KeySchedule = Array<U8, 176>;
/// Computes AES-128. /// Computes AES-128.
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct Aes128 { pub struct Aes128 {
key: Option<Array<U8, 16>>, key: Option<Array<U8, 16>>,
key_schedule: Option<KeySchedule>,
iv: Option<Array<U8, 4>>, iv: Option<Array<U8, 4>>,
} }
impl Aes128 {
// Allocates key schedule.
//
// Expects the key to be already set.
fn alloc_key_schedule(&self, vm: &mut dyn Vm<Binary>) -> Result<KeySchedule, AesError> {
let ks: KeySchedule = vm
.call(
Call::builder(AES128_KS.clone())
.arg(self.key.expect("key is set"))
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(ks)
}
}
#[async_trait] #[async_trait]
impl Cipher for Aes128 { impl Cipher for Aes128 {
type Error = AesError; type Error = AesError;
@@ -45,18 +67,22 @@ impl Cipher for Aes128 {
} }
fn alloc_block( fn alloc_block(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
input: Array<U8, 16>, input: Array<U8, 16>,
) -> Result<Self::Block, Self::Error> { ) -> Result<Self::Block, Self::Error> {
let key = self self.key
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?; .ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm let output = vm
.call( .call(
Call::builder(AES128.clone()) Call::builder(AES128_POST_KS.clone())
.arg(key) .arg(ks)
.arg(input) .arg(input)
.build() .build()
.expect("call should be valid"), .expect("call should be valid"),
@@ -67,11 +93,10 @@ impl Cipher for Aes128 {
} }
fn alloc_ctr_block( fn alloc_ctr_block(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> { ) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self self.key
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?; .ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self let iv = self
.iv .iv
@@ -89,10 +114,15 @@ impl Cipher for Aes128 {
vm.mark_public(counter) vm.mark_public(counter)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?; .map_err(|err| AesError::new(ErrorKind::Vm, err))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm let output = vm
.call( .call(
Call::builder(AES128.clone()) Call::builder(AES128_POST_KS.clone())
.arg(key) .arg(ks)
.arg(iv) .arg(iv)
.arg(explicit_nonce) .arg(explicit_nonce)
.arg(counter) .arg(counter)
@@ -109,12 +139,11 @@ impl Cipher for Aes128 {
} }
fn alloc_keystream( fn alloc_keystream(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
len: usize, len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> { ) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self self.key
.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?; .ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self let iv = self
.iv .iv
@@ -143,10 +172,15 @@ impl Cipher for Aes128 {
let blocks = inputs let blocks = inputs
.into_iter() .into_iter()
.map(|(explicit_nonce, counter)| { .map(|(explicit_nonce, counter)| {
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm let output = vm
.call( .call(
Call::builder(AES128.clone()) Call::builder(AES128_POST_KS.clone())
.arg(key) .arg(ks)
.arg(iv) .arg(iv)
.arg(explicit_nonce) .arg(explicit_nonce)
.arg(counter) .arg(counter)
@@ -172,15 +206,12 @@ mod tests {
use super::*; use super::*;
use crate::Cipher; use crate::Cipher;
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_memory_core::{ use mpz_memory_core::{
binary::{Binary, U8}, binary::{Binary, U8},
correlated::Delta,
Array, MemoryExt, Vector, ViewExt, Array, MemoryExt, Vector, ViewExt,
}; };
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{Execute, Vm}; use mpz_vm_core::{Execute, Vm};
use rand::{rngs::StdRng, SeedableRng};
#[tokio::test] #[tokio::test]
async fn test_aes_ctr() { async fn test_aes_ctr() {
@@ -190,10 +221,11 @@ mod tests {
let start_counter = 3u32; let start_counter = 3u32;
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm(); let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let aes_gen = setup_ctr(key, iv, &mut gen); let mut aes_gen = setup_ctr(key, iv, &mut gen);
let aes_ev = setup_ctr(key, iv, &mut ev); let mut aes_ev = setup_ctr(key, iv, &mut ev);
let msg = vec![42u8; 128]; let msg = vec![42u8; 128];
@@ -252,10 +284,11 @@ mod tests {
let input = [5_u8; 16]; let input = [5_u8; 16];
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm(); let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let aes_gen = setup_block(key, &mut gen); let mut aes_gen = setup_block(key, &mut gen);
let aes_ev = setup_block(key, &mut ev); let mut aes_ev = setup_block(key, &mut ev);
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap(); let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
gen.mark_public(block_ref_gen).unwrap(); gen.mark_public(block_ref_gen).unwrap();
@@ -294,18 +327,6 @@ mod tests {
assert_eq!(ciphertext_gen, expected); assert_eq!(ciphertext_gen, expected);
} }
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 { fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
let key_ref: Array<U8, 16> = vm.alloc().unwrap(); let key_ref: Array<U8, 16> = vm.alloc().unwrap();
vm.mark_public(key_ref).unwrap(); vm.mark_public(key_ref).unwrap();

View File

@@ -55,7 +55,7 @@ pub trait Cipher {
/// Allocates a single block in ECB mode. /// Allocates a single block in ECB mode.
fn alloc_block( fn alloc_block(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
input: Self::Block, input: Self::Block,
) -> Result<Self::Block, Self::Error>; ) -> Result<Self::Block, Self::Error>;
@@ -63,7 +63,7 @@ pub trait Cipher {
/// Allocates a single block in counter mode. /// Allocates a single block in counter mode.
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn alloc_ctr_block( fn alloc_ctr_block(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>; ) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
@@ -75,7 +75,7 @@ pub trait Cipher {
/// * `len` - Length of the stream in bytes. /// * `len` - Length of the stream in bytes.
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn alloc_keystream( fn alloc_keystream(
&self, &mut self,
vm: &mut dyn Vm<Binary>, vm: &mut dyn Vm<Binary>,
len: usize, len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>; ) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;

View File

@@ -19,11 +19,8 @@ futures = { workspace = true }
tokio = { workspace = true, features = ["sync"] } tokio = { workspace = true, features = ["sync"] }
[dev-dependencies] [dev-dependencies]
mpz-circuits = { workspace = true } mpz-circuits = { workspace = true, features = ["aes"] }
mpz-garble = { workspace = true } mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true } mpz-ideal-vm = { workspace = true }
mpz-zk = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View}, memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError, Call, Callable, Execute, Vm, VmError,
}; };
use rangeset::{Difference, RangeSet, UnionMut}; use rangeset::{ops::Set, set::RangeSet};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard}; use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError; type Error = DeapError;
@@ -210,10 +210,12 @@ where
} }
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> { fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization. // Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges); let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap(); let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() { for input in input_minus_follower {
zk.commit_raw( zk.commit_raw(
self.memory_map self.memory_map
.try_get(Slice::from_range_unchecked(input))?, .try_get(Slice::from_range_unchecked(input))?,
@@ -266,7 +268,7 @@ where
mpc.mark_private_raw(slice)?; mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization. // Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?; zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range()); self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice); self.follower_inputs.push(slice);
} }
} }
@@ -282,7 +284,7 @@ where
mpc.mark_blind_raw(slice)?; mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization. // Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?; zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range()); self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice); self.follower_inputs.push(slice);
} }
Role::Follower => { Role::Follower => {
@@ -382,37 +384,27 @@ enum ErrorRepr {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mpz_circuits::circuits::AES128; use mpz_circuits::AES128;
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_core::Block; use mpz_ideal_vm::IdealVm;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
use mpz_vm_core::{ use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array}, memory::{binary::U8, Array},
prelude::*, prelude::*,
}; };
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{rngs::StdRng, SeedableRng};
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn test_deap() { async fn test_deap() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc); let leader_mpc = IdealVm::new();
let ev = Evaluator::new(cot_recv); let leader_zk = IdealVm::new();
let prover = Prover::new(ProverConfig::default(), rcot_recv); let follower_mpc = IdealVm::new();
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send); let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover); let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, ev, verifier); let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let (ct_leader, ct_follower) = futures::join!( let (ct_leader, ct_follower) = futures::join!(
async { async {
@@ -478,21 +470,15 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_deap_desync_memory() { async fn test_deap_desync_memory() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc); let leader_mpc = IdealVm::new();
let ev = Evaluator::new(cot_recv); let leader_zk = IdealVm::new();
let prover = Prover::new(ProverConfig::default(), rcot_recv); let follower_mpc = IdealVm::new();
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send); let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover); let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, ev, verifier); let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
// Desynchronize the memories. // Desynchronize the memories.
let _ = leader.zk().alloc_raw(1).unwrap(); let _ = leader.zk().alloc_raw(1).unwrap();
@@ -564,21 +550,15 @@ mod tests {
// detection by the follower. // detection by the follower.
#[tokio::test] #[tokio::test]
async fn test_malicious() { async fn test_malicious() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc); let leader_mpc = IdealVm::new();
let ev = Evaluator::new(cot_recv); let leader_zk = IdealVm::new();
let prover = Prover::new(ProverConfig::default(), rcot_recv); let follower_mpc = IdealVm::new();
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send); let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover); let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, ev, verifier); let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let (_, follower_res) = futures::join!( let (_, follower_res) = futures::join!(
async { async {

View File

@@ -1,7 +1,7 @@
use std::ops::Range; use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError}; use mpz_vm_core::{memory::Slice, VmError};
use rangeset::Subset; use rangeset::ops::Set;
/// A mapping between the memories of the MPC and ZK VMs. /// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)] #[derive(Debug, Default)]

View File

@@ -20,14 +20,13 @@ mpz-core = { workspace = true }
mpz-circuits = { workspace = true } mpz-circuits = { workspace = true }
mpz-hash = { workspace = true } mpz-hash = { workspace = true }
sha2 = { workspace = true, features = ["compress"] }
thiserror = { workspace = true } thiserror = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
sha2 = { workspace = true }
[dev-dependencies] [dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] } mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
criterion = { workspace = true, features = ["async_tokio"] } criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -4,14 +4,12 @@ use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{Mode, MpcPrf}; use hmac_sha256::{Mode, MpcPrf};
use mpz_common::context::test_mt_context; use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{ use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array}, memory::{binary::U8, Array},
prelude::*, prelude::*,
Execute, Execute,
}; };
use rand::{rngs::StdRng, SeedableRng};
#[allow(clippy::unit_arg)] #[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) { fn criterion_benchmark(c: &mut Criterion) {
@@ -29,8 +27,6 @@ criterion_group!(benches, criterion_benchmark);
criterion_main!(benches); criterion_main!(benches);
async fn prf(mode: Mode) { async fn prf(mode: Mode) {
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32]; let pms = [42u8; 32];
let client_random = [69u8; 32]; let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32]; let server_random: [u8; 32] = [96u8; 32];
@@ -39,11 +35,8 @@ async fn prf(mode: Mode) {
let mut leader_ctx = leader_exec.new_context().await.unwrap(); let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap(); let mut follower_ctx = follower_exec.new_context().await.unwrap();
let delta = Delta::random(&mut rng); let mut leader_vm = IdealVm::new();
let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); let mut follower_vm = IdealVm::new();
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap(); let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap(); leader_vm.mark_public(leader_pms).unwrap();

View File

@@ -54,10 +54,11 @@ mod tests {
use crate::{ use crate::{
hmac::hmac_sha256, hmac::hmac_sha256,
sha256, state_to_bytes, sha256, state_to_bytes,
test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, test_utils::{compute_inner_local, compute_outer_partial},
}; };
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_hash::sha256::Sha256; use mpz_hash::sha256::Sha256;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{ use mpz_vm_core::{
memory::{ memory::{
binary::{U32, U8}, binary::{U32, U8},
@@ -83,7 +84,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_hmac_circuit() { async fn test_hmac_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm(); let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let (inputs, references) = test_fixtures(); let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) { for (input, &reference) in inputs.iter().zip(references.iter()) {

View File

@@ -72,10 +72,11 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, test_utils::{prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
Mode, MpcPrf, SessionKeys, Mode, MpcPrf, SessionKeys,
}; };
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{ use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt}, memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute, Execute,
@@ -123,7 +124,8 @@ mod tests {
// Set up vm and prf // Set up vm and prf
let (mut ctx_a, mut ctx_b) = test_st_context(128); let (mut ctx_a, mut ctx_b) = test_st_context(128);
let (mut leader, mut follower) = mock_vm(); let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let leader_pms: Array<U8, 32> = leader.alloc().unwrap(); let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(leader_pms).unwrap(); leader.mark_public(leader_pms).unwrap();

View File

@@ -339,8 +339,9 @@ fn gen_merge_circ(size: usize) -> Arc<Circuit> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{prf::merge_outputs, test_utils::mock_vm}; use crate::prf::merge_outputs;
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{ use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt}, memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute, Execute,
@@ -349,7 +350,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_merge_outputs() { async fn test_merge_outputs() {
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm(); let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let input1: [u8; 32] = std::array::from_fn(|i| i as u8); let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32); let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);

View File

@@ -137,10 +137,11 @@ impl Prf {
mod tests { mod tests {
use crate::{ use crate::{
prf::{compute_partial, function::Prf}, prf::{compute_partial, function::Prf},
test_utils::{mock_vm, phash}, test_utils::phash,
Mode, Mode,
}; };
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{ use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt}, memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute, Execute,
@@ -166,7 +167,8 @@ mod tests {
let mut rng = ThreadRng::default(); let mut rng = ThreadRng::default();
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm(); let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let key: [u8; 32] = rng.random(); let key: [u8; 32] = rng.random();
let start_seed: Vec<u8> = vec![42; 64]; let start_seed: Vec<u8> = vec![42; 64];

View File

@@ -1,25 +1,10 @@
use crate::{sha256, state_to_bytes}; use crate::{sha256, state_to_bytes};
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_vm_core::memory::correlated::Delta;
use rand::{rngs::StdRng, Rng, SeedableRng}; use rand::{rngs::StdRng, Rng, SeedableRng};
pub(crate) const SHA256_IV: [u32; 8] = [ pub(crate) const SHA256_IV: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]; ];
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] { pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
let mut label_start_seed = b"master secret".to_vec(); let mut label_start_seed = b"master secret".to_vec();
label_start_seed.extend_from_slice(&client_random); label_start_seed.extend_from_slice(&client_random);

View File

@@ -40,6 +40,7 @@ tokio = { workspace = true, features = ["sync"] }
[dev-dependencies] [dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] } mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true } mpz-garble = { workspace = true }
mpz-ideal-vm = { workspace = true }
rand_core = { workspace = true } rand_core = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -459,9 +459,7 @@ mod tests {
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_core::Block; use mpz_core::Block;
use mpz_fields::UniformRand; use mpz_fields::UniformRand;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_share_conversion::ideal::{ use mpz_share_conversion::ideal::{
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender, ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
}; };
@@ -484,7 +482,8 @@ mod tests {
async fn test_key_exchange() { async fn test_key_exchange() {
let mut rng = StdRng::seed_from_u64(0).compat(); let mut rng = StdRng::seed_from_u64(0).compat();
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm(); let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let leader_private_key = SecretKey::random(&mut rng); let leader_private_key = SecretKey::random(&mut rng);
let follower_private_key = SecretKey::random(&mut rng); let follower_private_key = SecretKey::random(&mut rng);
@@ -625,7 +624,8 @@ mod tests {
async fn test_malicious_key_exchange(#[case] malicious: Malicious) { async fn test_malicious_key_exchange(#[case] malicious: Malicious) {
let mut rng = StdRng::seed_from_u64(0); let mut rng = StdRng::seed_from_u64(0);
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm(); let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let leader_private_key = SecretKey::random(&mut rng.compat_by_ref()); let leader_private_key = SecretKey::random(&mut rng.compat_by_ref());
let follower_private_key = SecretKey::random(&mut rng.compat_by_ref()); let follower_private_key = SecretKey::random(&mut rng.compat_by_ref());
@@ -704,7 +704,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_circuit() { async fn test_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (gen, ev) = mock_vm(); let gen = IdealVm::new();
let ev = IdealVm::new();
let share_a0_bytes = [5_u8; 32]; let share_a0_bytes = [5_u8; 32];
let share_a1_bytes = [2_u8; 32]; let share_a1_bytes = [2_u8; 32];
@@ -834,16 +835,4 @@ mod tests {
(leader, follower) (leader, follower)
} }
fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
} }

View File

@@ -8,7 +8,7 @@
//! with the server alone and forward all messages from and to the follower. //! with the server alone and forward all messages from and to the follower.
//! //!
//! A detailed description of this protocol can be found in our documentation //! A detailed description of this protocol can be found in our documentation
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>. //! <https://tlsnotary.org/docs/mpc/key_exchange>.
#![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)] #![deny(clippy::all)]

View File

@@ -26,8 +26,7 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
use super::*; use super::*;
use crate::KeyExchange; use crate::KeyExchange;
@@ -40,12 +39,12 @@ mod tests {
is_key_exchange::< is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>, MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Garbler<IdealCOTSender>, IdealVm,
>(leader); >(leader);
is_key_exchange::< is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>, MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Evaluator<IdealCOTReceiver>, IdealVm,
>(follower); >(follower);
} }
} }

View File

@@ -4,7 +4,7 @@
//! protocol has semi-honest security. //! protocol has semi-honest security.
//! //!
//! The protocol is described in //! The protocol is described in
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html> //! <https://tlsnotary.org/docs/mpc/key_exchange>
use crate::{KeyExchangeError, Role}; use crate::{KeyExchangeError, Role};
use mpz_common::{Context, Flush}; use mpz_common::{Context, Flush};

View File

@@ -13,6 +13,7 @@ workspace = true
[features] [features]
default = [] default = []
mozilla-certs = ["dep:webpki-root-certs", "dep:webpki-roots"]
fixtures = [ fixtures = [
"dep:hex", "dep:hex",
"dep:tlsn-data-fixtures", "dep:tlsn-data-fixtures",
@@ -44,7 +45,8 @@ sha2 = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] } tiny-keccak = { workspace = true, features = ["keccak"] }
web-time = { workspace = true } web-time = { workspace = true }
webpki-roots = { workspace = true } webpki-roots = { workspace = true, optional = true }
webpki-root-certs = { workspace = true, optional = true }
rustls-webpki = { workspace = true, features = ["ring"] } rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true } rustls-pki-types = { workspace = true }
itybity = { workspace = true } itybity = { workspace = true }
@@ -57,5 +59,7 @@ generic-array = { workspace = true }
bincode = { workspace = true } bincode = { workspace = true }
hex = { workspace = true } hex = { workspace = true }
rstest = { workspace = true } rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true } tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true } webpki-root-certs = { workspace = true }

View File

@@ -0,0 +1,7 @@
//! Configuration types.

/// Proving configuration.
pub mod prove;
/// Prover configuration.
pub mod prover;
/// TLS client configuration.
pub mod tls;
/// TLS commitment configuration.
pub mod tls_commit;
/// Verifier configuration.
pub mod verifier;

View File

@@ -0,0 +1,189 @@
//! Proving configuration.
use rangeset::set::{RangeSet, ToRangeSet};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
    /// Whether the server identity is to be proven.
    server_identity: bool,
    /// Sent and received transcript ranges to reveal, respectively.
    reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
    /// Optional transcript commitment configuration.
    transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
    /// Creates a new builder over `transcript`.
    pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
        ProveConfigBuilder::new(transcript)
    }

    /// Returns `true` if the server identity is to be proven.
    pub fn server_identity(&self) -> bool {
        self.server_identity
    }

    /// Returns the sent and received transcript ranges to be revealed,
    /// respectively.
    pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
        self.reveal.as_ref()
    }

    /// Returns the transcript commitment configuration, if any.
    pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
        self.transcript_commit.as_ref()
    }

    /// Converts this configuration into a request that can be sent to the
    /// verifier.
    pub fn to_request(&self) -> ProveRequest {
        let transcript_commit = self
            .transcript_commit
            .clone()
            .map(|config| config.to_request());

        ProveRequest {
            server_identity: self.server_identity,
            reveal: self.reveal.clone(),
            transcript_commit,
        }
    }
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
    /// Transcript used to bounds-check reveal ranges.
    transcript: &'a Transcript,
    /// Whether the server identity is to be proven.
    server_identity: bool,
    /// Sent and received reveal ranges, respectively.
    reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
    /// Optional transcript commitment configuration.
    transcript_commit: Option<TranscriptCommitConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
    /// Creates a new builder over `transcript`.
    pub fn new(transcript: &'a Transcript) -> Self {
        Self {
            transcript,
            server_identity: false,
            reveal: None,
            transcript_commit: None,
        }
    }

    /// Proves the server identity.
    pub fn server_identity(&mut self) -> &mut Self {
        self.server_identity = true;
        self
    }

    /// Configures transcript commitments.
    pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
        self.transcript_commit = Some(transcript_commit);
        self
    }

    /// Reveals the given ranges of the transcript.
    ///
    /// # Errors
    ///
    /// Returns an error if the ranges extend past the end of the transcript
    /// in the given direction.
    pub fn reveal(
        &mut self,
        direction: Direction,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigError> {
        let idx = ranges.to_range_set();
        let len = self.transcript.len_of_direction(direction);
        let end = idx.end().unwrap_or(0);

        if end > len {
            return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
                direction,
                actual: end,
                len,
            }));
        }

        let (sent, recv) = self.reveal.get_or_insert_default();
        let set = match direction {
            Direction::Sent => sent,
            Direction::Received => recv,
        };
        set.union_mut(&idx);

        Ok(self)
    }

    /// Reveals the given ranges of the sent data transcript.
    pub fn reveal_sent(
        &mut self,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigError> {
        self.reveal(Direction::Sent, ranges)
    }

    /// Reveals all of the sent data transcript.
    pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigError> {
        // The full range is always in bounds, so no check is needed.
        let all = 0..self.transcript.len_of_direction(Direction::Sent);
        let (sent, _) = self.reveal.get_or_insert_default();
        sent.union_mut(&all);
        Ok(self)
    }

    /// Reveals the given ranges of the received data transcript.
    pub fn reveal_recv(
        &mut self,
        ranges: &dyn ToRangeSet<usize>,
    ) -> Result<&mut Self, ProveConfigError> {
        self.reveal(Direction::Received, ranges)
    }

    /// Reveals all of the received data transcript.
    pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigError> {
        // The full range is always in bounds, so no check is needed.
        let all = 0..self.transcript.len_of_direction(Direction::Received);
        let (_, recv) = self.reveal.get_or_insert_default();
        recv.union_mut(&all);
        Ok(self)
    }

    /// Builds the configuration.
    pub fn build(self) -> Result<ProveConfig, ProveConfigError> {
        let Self {
            server_identity,
            reveal,
            transcript_commit,
            ..
        } = self;

        Ok(ProveConfig {
            server_identity,
            reveal,
            transcript_commit,
        })
    }
}
/// Request to prove statements about the connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveRequest {
    /// Whether the server identity is to be proven.
    server_identity: bool,
    /// Sent and received transcript ranges to be revealed, respectively.
    reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
    /// Optional transcript commitment request.
    transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
    /// Returns `true` if the server identity is to be proven.
    pub fn server_identity(&self) -> bool {
        self.server_identity
    }

    /// Returns the sent and received transcript ranges to be revealed,
    /// respectively.
    pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
        self.reveal.as_ref()
    }

    /// Returns the transcript commitment request, if any.
    pub fn transcript_commit(&self) -> Option<&TranscriptCommitRequest> {
        self.transcript_commit.as_ref()
    }
}
/// Error for [`ProveConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigError(#[from] ErrorRepr);

/// Internal error representation for [`ProveConfigError`].
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    /// A requested reveal range extends past the end of the transcript.
    #[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
    IndexOutOfBounds {
        /// Direction of the transcript the range applied to.
        direction: Direction,
        /// End of the offending range.
        actual: usize,
        /// Length of the transcript in that direction.
        len: usize,
    },
}

View File

@@ -0,0 +1,33 @@
//! Prover configuration.
use serde::{Deserialize, Serialize};
/// Prover configuration.
///
/// Currently holds no options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProverConfig {}

impl ProverConfig {
    /// Creates a new builder.
    pub fn builder() -> ProverConfigBuilder {
        ProverConfigBuilder::default()
    }
}

/// Builder for [`ProverConfig`].
#[derive(Debug, Default)]
pub struct ProverConfigBuilder {}

impl ProverConfigBuilder {
    /// Builds the configuration.
    ///
    /// Currently always succeeds.
    pub fn build(self) -> Result<ProverConfig, ProverConfigError> {
        Ok(ProverConfig {})
    }
}

/// Error for [`ProverConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProverConfigError(#[from] ErrorRepr);

// No error variants exist yet.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {}

View File

@@ -0,0 +1,111 @@
//! TLS client configuration.
use serde::{Deserialize, Serialize};
use crate::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
/// TLS client configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsClientConfig {
    /// Name of the server to connect to.
    server_name: ServerName,
    /// Root certificates.
    root_store: RootCertStore,
    /// Certificate chain and a matching private key for client
    /// authentication.
    client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsClientConfig {
    /// Creates a new builder.
    pub fn builder() -> TlsConfigBuilder {
        TlsConfigBuilder::default()
    }

    /// Returns the name of the server to connect to.
    pub fn server_name(&self) -> &ServerName {
        &self.server_name
    }

    /// Returns the root certificate store used for verifying the server's
    /// certificate.
    pub fn root_store(&self) -> &RootCertStore {
        &self.root_store
    }

    /// Returns the client-authentication certificate chain and matching
    /// private key, if configured.
    pub fn client_auth(&self) -> Option<&(Vec<CertificateDer>, PrivateKeyDer)> {
        self.client_auth.as_ref()
    }
}
/// Builder for [`TlsClientConfig`].
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
    /// Name of the server to connect to (required).
    server_name: Option<ServerName>,
    /// Root certificates for verifying the server (required).
    root_store: Option<RootCertStore>,
    /// Optional client-authentication certificate chain and key.
    client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
    /// Sets the name of the server to connect to.
    pub fn server_name(mut self, server_name: ServerName) -> Self {
        self.server_name = Some(server_name);
        self
    }

    /// Sets the root certificate store used to verify the server's
    /// certificate.
    pub fn root_store(mut self, store: RootCertStore) -> Self {
        self.root_store = Some(store);
        self
    }

    /// Sets a DER-encoded certificate chain and a matching private key for
    /// client authentication.
    ///
    /// The chain often consists of a single end-entity certificate.
    ///
    /// # Arguments
    ///
    /// * `cert_key` - The certificate chain paired with the private key.
    ///
    /// - Each certificate in the chain must be in the X.509 format.
    /// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
    pub fn client_auth(mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> Self {
        self.client_auth = Some(cert_key);
        self
    }

    /// Builds the TLS configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if `server_name` or `root_store` was not set.
    pub fn build(self) -> Result<TlsClientConfig, TlsConfigError> {
        let Self {
            server_name,
            root_store,
            client_auth,
        } = self;

        let server_name = server_name.ok_or(ErrorRepr::MissingField {
            field: "server_name",
        })?;
        let root_store = root_store.ok_or(ErrorRepr::MissingField {
            field: "root_store",
        })?;

        Ok(TlsClientConfig {
            server_name,
            root_store,
            client_auth,
        })
    }
}
/// TLS configuration error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsConfigError(#[from] ErrorRepr);

/// Internal error representation for [`TlsConfigError`].
///
/// The display message is defined per variant; an enum-level `#[error(...)]`
/// attribute would be redundant (and is not how the other `ErrorRepr` enums
/// in this crate are written), so none is used here.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    /// A required builder field was not set.
    #[error("missing required field: {field}")]
    MissingField { field: &'static str },
}

View File

@@ -0,0 +1,94 @@
//! TLS commitment configuration.
pub mod mpc;
use serde::{Deserialize, Serialize};
/// TLS commitment configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfig {
    /// Protocol configuration.
    protocol: TlsCommitProtocolConfig,
}
impl TlsCommitConfig {
    /// Creates a new builder.
    pub fn builder() -> TlsCommitConfigBuilder {
        TlsCommitConfigBuilder::default()
    }

    /// Returns the protocol configuration.
    pub fn protocol(&self) -> &TlsCommitProtocolConfig {
        &self.protocol
    }

    /// Builds a TLS commitment request from this configuration.
    pub fn to_request(&self) -> TlsCommitRequest {
        let config = self.protocol.clone();
        TlsCommitRequest { config }
    }
}
/// Builder for [`TlsCommitConfig`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfigBuilder {
    /// Protocol configuration (required).
    protocol: Option<TlsCommitProtocolConfig>,
}
impl TlsCommitConfigBuilder {
/// Sets the protocol configuration.
pub fn protocol<C>(mut self, protocol: C) -> Self
where
C: Into<TlsCommitProtocolConfig>,
{
self.protocol = Some(protocol.into());
self
}
/// Builds the configuration.
pub fn build(self) -> Result<TlsCommitConfig, TlsCommitConfigError> {
let protocol = self
.protocol
.ok_or(ErrorRepr::MissingField { name: "protocol" })?;
Ok(TlsCommitConfig { protocol })
}
}
/// TLS commitment protocol configuration.
///
/// Marked non-exhaustive: additional protocol variants may be added in the
/// future.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TlsCommitProtocolConfig {
    /// MPC-TLS configuration.
    Mpc(mpc::MpcTlsConfig),
}
impl From<mpc::MpcTlsConfig> for TlsCommitProtocolConfig {
fn from(config: mpc::MpcTlsConfig) -> Self {
Self::Mpc(config)
}
}
/// TLS commitment request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitRequest {
    /// Protocol configuration; exposed via [`Self::protocol`].
    config: TlsCommitProtocolConfig,
}

impl TlsCommitRequest {
    /// Returns the protocol configuration.
    pub fn protocol(&self) -> &TlsCommitProtocolConfig {
        &self.config
    }
}
/// Error for [`TlsCommitConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsCommitConfigError(#[from] ErrorRepr);

/// Internal error representation for [`TlsCommitConfigError`].
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    /// A required builder field was not set.
    #[error("missing field: {name}")]
    MissingField { name: &'static str },
}

View File

@@ -0,0 +1,241 @@
//! MPC-TLS commitment protocol configuration.
use serde::{Deserialize, Serialize};
// Default number of bytes that can be decrypted online: 32, enough for the
// TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
/// MPC-TLS commitment protocol configuration.
///
/// Deserialization is routed through `unchecked::MpcTlsConfigUnchecked` via
/// the `serde(try_from = ...)` attribute, so cross-field invariants (see
/// `validate`) are enforced on construction.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(try_from = "unchecked::MpcTlsConfigUnchecked")]
pub struct MpcTlsConfig {
    /// Maximum number of bytes that can be sent.
    max_sent_data: usize,
    /// Maximum number of application data records that can be sent.
    max_sent_records: Option<usize>,
    /// Maximum number of bytes that can be decrypted online, i.e. while the
    /// MPC-TLS connection is active.
    max_recv_data_online: usize,
    /// Maximum number of bytes that can be received.
    max_recv_data: usize,
    /// Maximum number of received application data records that can be
    /// decrypted online, i.e. while the MPC-TLS connection is active.
    max_recv_records_online: Option<usize>,
    /// Whether the `deferred decryption` feature is toggled on from the start
    /// of the MPC-TLS connection.
    defer_decryption_from_start: bool,
    /// Network settings.
    network: NetworkSetting,
}
impl MpcTlsConfig {
    /// Creates a new builder.
    pub fn builder() -> MpcTlsConfigBuilder {
        MpcTlsConfigBuilder::default()
    }

    /// Returns the maximum number of bytes that can be sent.
    pub fn max_sent_data(&self) -> usize {
        self.max_sent_data
    }

    /// Returns the maximum number of application data records that can be
    /// sent, if limited.
    pub fn max_sent_records(&self) -> Option<usize> {
        self.max_sent_records
    }

    /// Returns the maximum number of bytes that can be decrypted while the
    /// MPC-TLS connection is active.
    pub fn max_recv_data_online(&self) -> usize {
        self.max_recv_data_online
    }

    /// Returns the maximum number of bytes that can be received.
    pub fn max_recv_data(&self) -> usize {
        self.max_recv_data
    }

    /// Returns the maximum number of received application data records that
    /// can be decrypted while the connection is active, if limited.
    pub fn max_recv_records_online(&self) -> Option<usize> {
        self.max_recv_records_online
    }

    /// Returns whether the `deferred decryption` feature is enabled from the
    /// start of the MPC-TLS connection.
    pub fn defer_decryption_from_start(&self) -> bool {
        self.defer_decryption_from_start
    }

    /// Returns the network settings.
    pub fn network(&self) -> NetworkSetting {
        self.network
    }
}
/// Checks the invariants of an [`MpcTlsConfig`], returning it unchanged on
/// success.
///
/// Currently the only invariant is that the online receive limit does not
/// exceed the total receive limit.
fn validate(config: MpcTlsConfig) -> Result<MpcTlsConfig, MpcTlsConfigError> {
    let online = config.max_recv_data_online;
    let total = config.max_recv_data;

    if online > total {
        Err(ErrorRepr::InvalidValue {
            name: "max_recv_data_online",
            reason: format!("must be <= max_recv_data ({} > {})", online, total),
        }
        .into())
    } else {
        Ok(config)
    }
}
/// Builder for [`MpcTlsConfig`].
///
/// `max_sent_data` and `max_recv_data` are required; all other fields have
/// defaults applied in [`MpcTlsConfigBuilder::build`].
#[derive(Debug, Default)]
pub struct MpcTlsConfigBuilder {
    // Required; `build` fails with a `MissingField` error if unset.
    max_sent_data: Option<usize>,
    max_sent_records: Option<usize>,
    // Defaults to `DEFAULT_MAX_RECV_ONLINE` when unset.
    max_recv_data_online: Option<usize>,
    // Required; `build` fails with a `MissingField` error if unset.
    max_recv_data: Option<usize>,
    max_recv_records_online: Option<usize>,
    // Defaults to `true` when unset.
    defer_decryption_from_start: Option<bool>,
    // Defaults to `NetworkSetting::default()` (i.e. `Latency`) when unset.
    network: Option<NetworkSetting>,
}
impl MpcTlsConfigBuilder {
    /// Sets the maximum number of bytes that can be sent. Required.
    pub fn max_sent_data(mut self, max_sent_data: usize) -> Self {
        self.max_sent_data = Some(max_sent_data);
        self
    }

    /// Sets the maximum number of application data records that can be sent.
    pub fn max_sent_records(mut self, max_sent_records: usize) -> Self {
        self.max_sent_records = Some(max_sent_records);
        self
    }

    /// Sets the maximum number of bytes that can be decrypted online.
    pub fn max_recv_data_online(mut self, max_recv_data_online: usize) -> Self {
        self.max_recv_data_online = Some(max_recv_data_online);
        self
    }

    /// Sets the maximum number of bytes that can be received. Required.
    pub fn max_recv_data(mut self, max_recv_data: usize) -> Self {
        self.max_recv_data = Some(max_recv_data);
        self
    }

    /// Sets the maximum number of received application data records that can
    /// be decrypted online.
    pub fn max_recv_records_online(mut self, max_recv_records_online: usize) -> Self {
        self.max_recv_records_online = Some(max_recv_records_online);
        self
    }

    /// Sets whether the `deferred decryption` feature is toggled on from the
    /// start of the MPC-TLS connection.
    pub fn defer_decryption_from_start(mut self, defer_decryption_from_start: bool) -> Self {
        self.defer_decryption_from_start = Some(defer_decryption_from_start);
        self
    }

    /// Sets the network settings.
    pub fn network(mut self, network: NetworkSetting) -> Self {
        self.network = Some(network);
        self
    }

    /// Consumes the builder, producing a validated [`MpcTlsConfig`].
    ///
    /// # Errors
    ///
    /// Fails if `max_sent_data` or `max_recv_data` was never set, or if the
    /// resulting configuration violates an invariant (e.g.
    /// `max_recv_data_online` exceeding `max_recv_data`).
    pub fn build(self) -> Result<MpcTlsConfig, MpcTlsConfigError> {
        // Required fields.
        let max_sent_data = self.max_sent_data.ok_or(ErrorRepr::MissingField {
            name: "max_sent_data",
        })?;
        let max_recv_data = self.max_recv_data.ok_or(ErrorRepr::MissingField {
            name: "max_recv_data",
        })?;

        validate(MpcTlsConfig {
            max_sent_data,
            max_sent_records: self.max_sent_records,
            // Default online limit when the caller did not set one.
            max_recv_data_online: self
                .max_recv_data_online
                .unwrap_or(DEFAULT_MAX_RECV_ONLINE),
            max_recv_data,
            max_recv_records_online: self.max_recv_records_online,
            // Deferred decryption is on from the start unless overridden.
            defer_decryption_from_start: self.defer_decryption_from_start.unwrap_or(true),
            network: self.network.unwrap_or_default(),
        })
    }
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub enum NetworkSetting {
    /// Reduces network round-trips at the expense of consuming more network
    /// bandwidth.
    Bandwidth,
    /// Reduces network bandwidth utilization at the expense of more network
    /// round-trips.
    ///
    /// This is the default setting.
    #[default]
    Latency,
}
/// Error for [`MpcTlsConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct MpcTlsConfigError(#[from] ErrorRepr);

// Internal representation, kept private so variants can evolve without
// breaking the public `MpcTlsConfigError` type.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    /// A required builder field was never set.
    #[error("missing field: {name}")]
    MissingField { name: &'static str },
    /// A field value violates a configuration invariant.
    #[error("invalid value for field({name}): {reason}")]
    InvalidValue { name: &'static str, reason: String },
}
mod unchecked {
    //! Unvalidated mirror of [`MpcTlsConfig`] used as a deserialization
    //! staging type, ensuring every deserialized config passes `validate`.

    use super::*;

    /// Raw deserialization target for [`MpcTlsConfig`].
    #[derive(Deserialize)]
    pub(super) struct MpcTlsConfigUnchecked {
        max_sent_data: usize,
        max_sent_records: Option<usize>,
        max_recv_data_online: usize,
        max_recv_data: usize,
        max_recv_records_online: Option<usize>,
        defer_decryption_from_start: bool,
        network: NetworkSetting,
    }

    impl TryFrom<MpcTlsConfigUnchecked> for MpcTlsConfig {
        type Error = MpcTlsConfigError;

        /// Runs the shared validation before admitting the raw values.
        fn try_from(value: MpcTlsConfigUnchecked) -> Result<Self, Self::Error> {
            let MpcTlsConfigUnchecked {
                max_sent_data,
                max_sent_records,
                max_recv_data_online,
                max_recv_data,
                max_recv_records_online,
                defer_decryption_from_start,
                network,
            } = value;

            validate(MpcTlsConfig {
                max_sent_data,
                max_sent_records,
                max_recv_data_online,
                max_recv_data,
                max_recv_records_online,
                defer_decryption_from_start,
                network,
            })
        }
    }
}

View File

@@ -0,0 +1,56 @@
//! Verifier configuration.
use serde::{Deserialize, Serialize};
use crate::webpki::RootCertStore;
/// Verifier configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifierConfig {
    /// Trusted root certificate store.
    // NOTE(review): presumably used to verify the server's certificate chain
    // — confirm at the use site.
    root_store: RootCertStore,
}
impl VerifierConfig {
    /// Returns a builder for constructing a [`VerifierConfig`].
    pub fn builder() -> VerifierConfigBuilder {
        VerifierConfigBuilder::default()
    }

    /// A reference to the configured root certificate store.
    pub fn root_store(&self) -> &RootCertStore {
        &self.root_store
    }
}
/// Builder for [`VerifierConfig`].
#[derive(Debug, Default)]
pub struct VerifierConfigBuilder {
    // Required; `build` fails with a `MissingField` error if unset.
    root_store: Option<RootCertStore>,
}
impl VerifierConfigBuilder {
    /// Supplies the root certificate store. Required.
    pub fn root_store(mut self, root_store: RootCertStore) -> Self {
        self.root_store = Some(root_store);
        self
    }

    /// Consumes the builder, producing the configuration.
    ///
    /// # Errors
    ///
    /// Fails if no root certificate store was provided.
    pub fn build(self) -> Result<VerifierConfig, VerifierConfigError> {
        match self.root_store {
            Some(root_store) => Ok(VerifierConfig { root_store }),
            None => Err(ErrorRepr::MissingField { name: "root_store" }.into()),
        }
    }
}
/// Error for [`VerifierConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct VerifierConfigError(#[from] ErrorRepr);

// Internal representation, kept private so variants can evolve without
// breaking the public `VerifierConfigError` type.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    /// A required builder field was never set.
    #[error("missing field: {name}")]
    MissingField { name: &'static str },
}

View File

@@ -1,11 +1,11 @@
use rangeset::RangeSet; use rangeset::set::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>); pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> { impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?; f.write_str("{")?;
for range in self.0.iter_ranges() { for range in self.0.iter() {
write!(f, "{}..{}", range.start, range.end)?; write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) { if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?; f.write_str(", ")?;

View File

@@ -2,12 +2,13 @@
use aead::Payload as AeadPayload; use aead::Payload as AeadPayload;
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead}; use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
#[allow(deprecated)]
use generic_array::GenericArray; use generic_array::GenericArray;
use rand::{rngs::StdRng, Rng, SeedableRng}; use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_core::msgs::{ use tls_core::msgs::{
base::Payload, base::Payload,
codec::Codec, codec::Codec,
enums::{ContentType, HandshakeType, ProtocolVersion}, enums::{HandshakeType, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload}, handshake::{HandshakeMessagePayload, HandshakePayload},
message::{OpaqueMessage, PlainMessage}, message::{OpaqueMessage, PlainMessage},
}; };
@@ -15,7 +16,7 @@ use tls_core::msgs::{
use crate::{ use crate::{
connection::{TranscriptLength, VerifyData}, connection::{TranscriptLength, VerifyData},
fixtures::ConnectionFixture, fixtures::ConnectionFixture,
transcript::{Record, TlsTranscript}, transcript::{ContentType, Record, TlsTranscript},
}; };
/// The key used for encryption of the sent and received transcript. /// The key used for encryption of the sent and received transcript.
@@ -103,7 +104,7 @@ impl TranscriptGenerator {
let explicit_nonce: [u8; 8] = seq.to_be_bytes(); let explicit_nonce: [u8; 8] = seq.to_be_bytes();
let msg = PlainMessage { let msg = PlainMessage {
typ: ContentType::ApplicationData, typ: ContentType::ApplicationData.into(),
version: ProtocolVersion::TLSv1_2, version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext), payload: Payload::new(plaintext),
}; };
@@ -138,7 +139,7 @@ impl TranscriptGenerator {
handshake_message.encode(&mut plaintext); handshake_message.encode(&mut plaintext);
let msg = PlainMessage { let msg = PlainMessage {
typ: ContentType::Handshake, typ: ContentType::Handshake.into(),
version: ProtocolVersion::TLSv1_2, version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext.clone()), payload: Payload::new(plaintext.clone()),
}; };
@@ -180,6 +181,7 @@ fn aes_gcm_encrypt(
let mut nonce = [0u8; 12]; let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv); nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce); nonce[4..].copy_from_slice(&explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce); let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&key).unwrap(); let cipher = Aes128Gcm::new_from_slice(&key).unwrap();

View File

@@ -296,14 +296,14 @@ mod sha2 {
fn hash(&self, data: &[u8]) -> super::Hash { fn hash(&self, data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default(); let mut hasher = ::sha2::Sha256::default();
hasher.update(data); hasher.update(data);
super::Hash::new(hasher.finalize().as_slice()) super::Hash::new(hasher.finalize().as_ref())
} }
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash { fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default(); let mut hasher = ::sha2::Sha256::default();
hasher.update(prefix); hasher.update(prefix);
hasher.update(data); hasher.update(data);
super::Hash::new(hasher.finalize().as_slice()) super::Hash::new(hasher.finalize().as_ref())
} }
} }
} }

View File

@@ -12,176 +12,18 @@ pub mod merkle;
pub mod transcript; pub mod transcript;
pub mod webpki; pub mod webpki;
pub use rangeset; pub use rangeset;
pub mod config;
pub(crate) mod display; pub(crate) mod display;
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
connection::{HandshakeData, ServerName}, connection::ServerName,
transcript::{ transcript::{
encoding::EncoderSecret, Direction, PartialTranscript, Transcript, TranscriptCommitConfig, encoding::EncoderSecret, PartialTranscript, TranscriptCommitment, TranscriptSecret,
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
}, },
}; };
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
/// Creates a new builder.
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
ProveConfigBuilder::new(transcript)
}
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the ranges of the transcript to be revealed.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
self.transcript_commit.as_ref()
}
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
transcript: &'a Transcript,
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
/// Creates a new builder.
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
server_identity: false,
reveal: None,
transcript_commit: None,
}
}
/// Proves the server identity.
pub fn server_identity(&mut self) -> &mut Self {
self.server_identity = true;
self
}
/// Configures transcript commitments.
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
self.transcript_commit = Some(transcript_commit);
self
}
/// Reveals the given ranges of the transcript.
pub fn reveal(
&mut self,
direction: Direction,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
let idx = ranges.to_range_set();
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(ProveConfigBuilderError(
ProveConfigBuilderErrorRepr::IndexOutOfBounds {
direction,
actual: idx.end().unwrap_or(0),
len: self.transcript.len_of_direction(direction),
},
));
}
let (sent, recv) = self.reveal.get_or_insert_default();
match direction {
Direction::Sent => sent.union_mut(&idx),
Direction::Received => recv.union_mut(&idx),
}
Ok(self)
}
/// Reveals the given ranges of the sent data transcript.
pub fn reveal_sent(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
self.reveal(Direction::Sent, ranges)
}
/// Reveals all of the sent data transcript.
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Sent);
let (sent, _) = self.reveal.get_or_insert_default();
sent.union_mut(&(0..len));
Ok(self)
}
/// Reveals the given ranges of the received data transcript.
pub fn reveal_recv(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
self.reveal(Direction::Received, ranges)
}
/// Reveals all of the received data transcript.
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Received);
let (_, recv) = self.reveal.get_or_insert_default();
recv.union_mut(&(0..len));
Ok(self)
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
Ok(ProveConfig {
server_identity: self.server_identity,
reveal: self.reveal,
transcript_commit: self.transcript_commit,
})
}
}
/// Error for [`ProveConfigBuilder`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigBuilderError(#[from] ProveConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ProveConfigBuilderErrorRepr {
#[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
IndexOutOfBounds {
direction: Direction,
actual: usize,
len: usize,
},
}
/// Request to prove statements about the connection.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Transcript commitment configuration.
pub transcript_commit: Option<TranscriptCommitRequest>,
}
/// Prover output. /// Prover output.
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ProverOutput { pub struct ProverOutput {

View File

@@ -26,7 +26,11 @@ mod tls;
use std::{fmt, ops::Range}; use std::{fmt, ops::Range};
use rangeset::{Difference, IndexRanges, RangeSet, Union}; use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength; use crate::connection::TranscriptLength;
@@ -38,8 +42,7 @@ pub use commit::{
pub use proof::{ pub use proof::{
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError, TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
}; };
pub use tls::{Record, TlsTranscript}; pub use tls::{ContentType, Record, TlsTranscript};
pub use tls_core::msgs::enums::ContentType;
/// A transcript contains the plaintext of all application data communicated /// A transcript contains the plaintext of all application data communicated
/// between the Prover and the Server. /// between the Prover and the Server.
@@ -107,8 +110,14 @@ impl Transcript {
} }
Some( Some(
Subsequence::new(idx.clone(), data.index_ranges(idx)) Subsequence::new(
.expect("data is same length as index"), idx.clone(),
data.index(idx).fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
)
.expect("data is same length as index"),
) )
} }
@@ -130,11 +139,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()]; let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()]; let mut received = vec![0; self.received.len()];
for range in sent_idx.iter_ranges() { for range in sent_idx.iter() {
sent[range.clone()].copy_from_slice(&self.sent[range]); sent[range.clone()].copy_from_slice(&self.sent[range]);
} }
for range in recv_idx.iter_ranges() { for range in recv_idx.iter() {
received[range.clone()].copy_from_slice(&self.received[range]); received[range.clone()].copy_from_slice(&self.received[range]);
} }
@@ -187,12 +196,20 @@ pub struct CompressedPartialTranscript {
impl From<PartialTranscript> for CompressedPartialTranscript { impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self { fn from(uncompressed: PartialTranscript) -> Self {
Self { Self {
sent_authed: uncompressed sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
.sent Vec::new(),
.index_ranges(&uncompressed.sent_authed_idx), |mut acc, s| {
acc.extend_from_slice(s);
acc
},
),
received_authed: uncompressed received_authed: uncompressed
.received .received
.index_ranges(&uncompressed.received_authed_idx), .index(&uncompressed.received_authed_idx)
.fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
sent_idx: uncompressed.sent_authed_idx, sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx, recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(), sent_total: uncompressed.sent.len(),
@@ -208,7 +225,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0; let mut offset = 0;
for range in compressed.sent_idx.iter_ranges() { for range in compressed.sent_idx.iter() {
sent[range.clone()] sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]); .copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len(); offset += range.len();
@@ -216,7 +233,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0; let mut offset = 0;
for range in compressed.recv_idx.iter_ranges() { for range in compressed.recv_idx.iter() {
received[range.clone()] received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]); .copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len(); offset += range.len();
@@ -305,12 +322,16 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated. /// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> { pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len()).difference(&self.sent_authed_idx) (0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
} }
/// Returns the index of received data which haven't been authenticated. /// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> { pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len()).difference(&self.received_authed_idx) (0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
} }
/// Returns an iterator over the authenticated data in the transcript. /// Returns an iterator over the authenticated data in the transcript.
@@ -320,7 +341,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx), Direction::Received => (&self.received, &self.received_authed_idx),
}; };
authed.iter().map(|i| data[i]) authed.iter_values().map(move |i| data[i])
} }
/// Unions the authenticated data of this transcript with another. /// Unions the authenticated data of this transcript with another.
@@ -340,24 +361,20 @@ impl PartialTranscript {
"received data are not the same length" "received data are not the same length"
); );
for range in other for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]); self.sent[range.clone()].copy_from_slice(&other.sent[range]);
} }
for range in other for range in other
.received_authed_idx .received_authed_idx
.difference(&self.received_authed_idx) .difference(&self.received_authed_idx)
.iter_ranges()
{ {
self.received[range.clone()].copy_from_slice(&other.received[range]); self.received[range.clone()].copy_from_slice(&other.received[range]);
} }
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx); self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx); self.received_authed_idx
.union_mut(&other.received_authed_idx);
} }
/// Unions an authenticated subsequence into this transcript. /// Unions an authenticated subsequence into this transcript.
@@ -369,11 +386,11 @@ impl PartialTranscript {
match direction { match direction {
Direction::Sent => { Direction::Sent => {
seq.copy_to(&mut self.sent); seq.copy_to(&mut self.sent);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx); self.sent_authed_idx.union_mut(&seq.idx);
} }
Direction::Received => { Direction::Received => {
seq.copy_to(&mut self.received); seq.copy_to(&mut self.received);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx); self.received_authed_idx.union_mut(&seq.idx);
} }
} }
} }
@@ -384,10 +401,10 @@ impl PartialTranscript {
/// ///
/// * `value` - The value to set the unauthenticated bytes to /// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) { pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter_ranges() { for range in self.sent_unauthed().iter() {
self.sent[range].fill(value); self.sent[range].fill(value);
} }
for range in self.received_unauthed().iter_ranges() { for range in self.received_unauthed().iter() {
self.received[range].fill(value); self.received[range].fill(value);
} }
} }
@@ -402,13 +419,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) { pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction { match direction {
Direction::Sent => { Direction::Sent => {
for range in range.difference(&self.sent_authed_idx).iter_ranges() { for r in range.difference(&self.sent_authed_idx) {
self.sent[range].fill(value); self.sent[r].fill(value);
} }
} }
Direction::Received => { Direction::Received => {
for range in range.difference(&self.received_authed_idx).iter_ranges() { for r in range.difference(&self.received_authed_idx) {
self.received[range].fill(value); self.received[r].fill(value);
} }
} }
} }
@@ -486,7 +503,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds. /// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) { pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0; let mut offset = 0;
for range in self.idx.iter_ranges() { for range in self.idx.iter() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]); dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len(); offset += range.len();
} }
@@ -611,12 +628,7 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked, mut partial_transcript: CompressedPartialTranscriptUnchecked,
) { ) {
// Change the total to be less than the last range's end bound. // Change the total to be less than the last range's end bound.
let end = partial_transcript let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
partial_transcript.sent_total = end - 1; partial_transcript.sent_total = end - 1;

View File

@@ -2,7 +2,7 @@
use std::{collections::HashSet, fmt}; use std::{collections::HashSet, fmt};
use rangeset::{ToRangeSet, UnionMut}; use rangeset::set::ToRangeSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{

View File

@@ -1,6 +1,6 @@
use std::{collections::HashMap, fmt}; use std::{collections::HashMap, fmt};
use rangeset::{RangeSet, UnionMut}; use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
@@ -103,7 +103,7 @@ impl EncodingProof {
} }
expected_leaf.clear(); expected_leaf.clear();
for range in idx.iter_ranges() { for range in idx.iter() {
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf); encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
} }
expected_leaf.extend_from_slice(blinder.as_bytes()); expected_leaf.extend_from_slice(blinder.as_bytes());

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use bimap::BiMap; use bimap::BiMap;
use rangeset::{RangeSet, UnionMut}; use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
@@ -99,7 +99,7 @@ impl EncodingTree {
let blinder: Blinder = rand::random(); let blinder: Blinder = rand::random();
encoding.clear(); encoding.clear();
for range in idx.iter_ranges() { for range in idx.iter() {
provider provider
.provide_encoding(direction, range, &mut encoding) .provide_encoding(direction, range, &mut encoding)
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?; .map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;

View File

@@ -1,6 +1,10 @@
//! Transcript proofs. //! Transcript proofs.
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut}; use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt}; use std::{collections::HashSet, fmt};
@@ -25,6 +29,9 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash { TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3, alg: HashAlgId::BLAKE3,
}, },
TranscriptCommitmentKind::Hash {
alg: HashAlgId::KECCAK256,
},
TranscriptCommitmentKind::Encoding, TranscriptCommitmentKind::Encoding,
]; ];
@@ -141,7 +148,7 @@ impl TranscriptProof {
} }
buffer.clear(); buffer.clear();
for range in idx.iter_ranges() { for range in idx.iter() {
buffer.extend_from_slice(&plaintext[range]); buffer.extend_from_slice(&plaintext[range]);
} }
@@ -363,7 +370,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) { if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx); self.query_idx.union(&direction, &idx);
} else { } else {
let missing = idx.difference(committed); let missing = idx.difference(committed).into_set();
return Err(TranscriptProofBuilderError::new( return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment, BuilderErrorKind::MissingCommitment,
format!( format!(
@@ -579,7 +586,7 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use rangeset::RangeSet; use rangeset::prelude::*;
use rstest::rstest; use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON}; use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -656,6 +663,7 @@ mod tests {
#[rstest] #[rstest]
#[case::sha256(HashAlgId::SHA256)] #[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)] #[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) { fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
let mut rng = rand::rngs::StdRng::seed_from_u64(0); let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default(); let provider = HashProvider::default();
@@ -704,6 +712,7 @@ mod tests {
#[rstest] #[rstest]
#[case::sha256(HashAlgId::SHA256)] #[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)] #[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) { fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
let mut rng = rand::rngs::StdRng::seed_from_u64(0); let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default(); let provider = HashProvider::default();

View File

@@ -10,10 +10,53 @@ use crate::{
use tls_core::msgs::{ use tls_core::msgs::{
alert::AlertMessagePayload, alert::AlertMessagePayload,
codec::{Codec, Reader}, codec::{Codec, Reader},
enums::{AlertDescription, ContentType, ProtocolVersion}, enums::{AlertDescription, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload}, handshake::{HandshakeMessagePayload, HandshakePayload},
}; };
/// TLS record content type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ContentType {
/// Change cipher spec protocol.
ChangeCipherSpec,
/// Alert protocol.
Alert,
/// Handshake protocol.
Handshake,
/// Application data protocol.
ApplicationData,
/// Heartbeat protocol.
Heartbeat,
/// Unknown protocol.
Unknown(u8),
}
impl From<ContentType> for tls_core::msgs::enums::ContentType {
fn from(content_type: ContentType) -> Self {
match content_type {
ContentType::ChangeCipherSpec => tls_core::msgs::enums::ContentType::ChangeCipherSpec,
ContentType::Alert => tls_core::msgs::enums::ContentType::Alert,
ContentType::Handshake => tls_core::msgs::enums::ContentType::Handshake,
ContentType::ApplicationData => tls_core::msgs::enums::ContentType::ApplicationData,
ContentType::Heartbeat => tls_core::msgs::enums::ContentType::Heartbeat,
ContentType::Unknown(id) => tls_core::msgs::enums::ContentType::Unknown(id),
}
}
}
impl From<tls_core::msgs::enums::ContentType> for ContentType {
fn from(content_type: tls_core::msgs::enums::ContentType) -> Self {
match content_type {
tls_core::msgs::enums::ContentType::ChangeCipherSpec => ContentType::ChangeCipherSpec,
tls_core::msgs::enums::ContentType::Alert => ContentType::Alert,
tls_core::msgs::enums::ContentType::Handshake => ContentType::Handshake,
tls_core::msgs::enums::ContentType::ApplicationData => ContentType::ApplicationData,
tls_core::msgs::enums::ContentType::Heartbeat => ContentType::Heartbeat,
tls_core::msgs::enums::ContentType::Unknown(id) => ContentType::Unknown(id),
}
}
}
/// A transcript of TLS records sent and received by the prover. /// A transcript of TLS records sent and received by the prover.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct TlsTranscript { pub struct TlsTranscript {

View File

@@ -53,6 +53,21 @@ impl RootCertStore {
pub fn empty() -> Self { pub fn empty() -> Self {
Self { roots: Vec::new() } Self { roots: Vec::new() }
} }
/// Creates a root certificate store with Mozilla root certificates.
///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
pub fn mozilla() -> Self {
Self {
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
.iter()
.map(|cert| CertificateDer(cert.to_vec()))
.collect(),
}
}
} }
/// Server certificate verifier. /// Server certificate verifier.
@@ -82,8 +97,12 @@ impl ServerCertVerifier {
Ok(Self { roots }) Ok(Self { roots })
} }
/// Creates a new server certificate verifier with Mozilla root /// Creates a server certificate verifier with Mozilla root certificates.
/// certificates. ///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
pub fn mozilla() -> Self { pub fn mozilla() -> Self {
Self { Self {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),

View File

@@ -4,5 +4,7 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
* [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary. * [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
* [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary. * [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
* [Interactive_zk](./interactive_zk/README.md): Interactive Prover and Verifier session demonstrating zero-knowledge age verification using Noir.
Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples. Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.

View File

@@ -4,6 +4,7 @@
use std::env; use std::env;
use anyhow::{anyhow, Result};
use clap::Parser; use clap::Parser;
use http_body_util::Empty; use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode}; use hyper::{body::Bytes, Request, StatusCode};
@@ -22,11 +23,18 @@ use tlsn::{
signing::Secp256k1Signer, signing::Secp256k1Signer,
Attestation, AttestationConfig, CryptoProvider, Secrets, Attestation, AttestationConfig, CryptoProvider, Secrets,
}, },
config::{CertificateDer, PrivateKeyDer, ProtocolConfig, RootCertStore}, config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
verifier::VerifierConfig,
},
connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength}, connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength},
prover::{state::Committed, ProveConfig, Prover, ProverConfig, ProverOutput, TlsConfig}, prover::{state::Committed, Prover, ProverOutput},
transcript::{ContentType, TranscriptCommitConfig}, transcript::{ContentType, TranscriptCommitConfig},
verifier::{Verifier, VerifierConfig, VerifierOutput}, verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
}; };
use tlsn_examples::ExampleType; use tlsn_examples::ExampleType;
use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript}; use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
@@ -45,7 +53,7 @@ struct Args {
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let args = Args::parse(); let args = Args::parse();
@@ -85,64 +93,63 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
uri: &str, uri: &str,
extra_headers: Vec<(&str, &str)>, extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType, example_type: &ExampleType,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<()> {
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into()); let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT") let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer")) .map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT); .unwrap_or(DEFAULT_FIXTURE_PORT);
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
));
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
// We must configure the amount of data we expect to exchange beforehand, which will
// be preprocessed prior to the connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create a new prover and perform necessary setup. // Create a new prover and perform necessary setup.
let prover = Prover::new(prover_config).setup(socket.compat()).await?; let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
.build()?,
socket.compat(),
)
.await?;
// Open a TCP connection to the server. // Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?; let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
// Bind the prover to the server connection. // Bind the prover to the server connection.
// The returned `mpc_tls_connection` is an MPC TLS connection to the server: all let (tls_connection, prover_fut) = prover
// data written to/read from it will be encrypted/decrypted using MPC with .connect(
// the notary. TlsClientConfig::builder()
let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?; .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); // Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
))
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the prover task to be run concurrently in the background. // Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover_fut);
// Attach the hyper HTTP client to the connection. // Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) = let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?; hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the HTTP task to be run concurrently in the background. // Spawn the HTTP task to be run concurrently in the background.
tokio::spawn(connection); tokio::spawn(connection);
@@ -163,7 +170,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
} }
let request = request_builder.body(Empty::<Bytes>::new())?; let request = request_builder.body(Empty::<Bytes>::new())?;
info!("Starting an MPC TLS connection with the server"); info!("Starting connection with the server");
// Send the request to the server and wait for the response. // Send the request to the server and wait for the response.
let response = request_sender.send_request(request).await?; let response = request_sender.send_request(request).await?;
@@ -240,7 +247,7 @@ async fn notarize(
config: &RequestConfig, config: &RequestConfig,
request_tx: Sender<AttestationRequest>, request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>, attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets), Box<dyn std::error::Error>> { ) -> Result<(Attestation, Secrets)> {
let mut builder = ProveConfig::builder(prover.transcript()); let mut builder = ProveConfig::builder(prover.transcript());
if let Some(config) = config.transcript_commit() { if let Some(config) = config.transcript_commit() {
@@ -283,15 +290,18 @@ async fn notarize(
// Send attestation request to notary. // Send attestation request to notary.
request_tx request_tx
.send(request.clone()) .send(request.clone())
.map_err(|_| "notary is not receiving attestation request".to_string())?; .map_err(|_| anyhow!("notary is not receiving attestation request"))?;
// Receive attestation from notary. // Receive attestation from notary.
let attestation = attestation_rx let attestation = attestation_rx
.await .await
.map_err(|err| format!("notary did not respond with attestation: {err}"))?; .map_err(|err| anyhow!("notary did not respond with attestation: {err}"))?;
// Signature verifier for the signature algorithm in the request.
let provider = CryptoProvider::default();
// Check the attestation is consistent with the Prover's view. // Check the attestation is consistent with the Prover's view.
request.validate(&attestation)?; request.validate(&attestation, &provider)?;
Ok((attestation, secrets)) Ok((attestation, secrets))
} }
@@ -300,7 +310,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S, socket: S,
request_rx: Receiver<AttestationRequest>, request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>, attestation_tx: Sender<Attestation>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<()> {
// Create a root certificate store with the server-fixture's self-signed // Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the // certificate. This is only required for offline testing with the
// server-fixture. // server-fixture.
@@ -312,7 +322,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap(); .unwrap();
let verifier = Verifier::new(verifier_config) let verifier = Verifier::new(verifier_config)
.setup(socket.compat()) .commit(socket.compat())
.await? .await?
.accept() .accept()
.await? .await?
@@ -322,6 +332,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let ( let (
VerifierOutput { VerifierOutput {
transcript_commitments, transcript_commitments,
encoder_secret,
.. ..
}, },
verifier, verifier,
@@ -382,12 +393,16 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone()) .server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments); .transcript_commitments(transcript_commitments);
if let Some(encoder_secret) = encoder_secret {
builder.encoder_secret(encoder_secret);
}
let attestation = builder.build(&provider)?; let attestation = builder.build(&provider)?;
// Send attestation to prover. // Send attestation to prover.
attestation_tx attestation_tx
.send(attestation) .send(attestation)
.map_err(|_| "prover is not receiving attestation".to_string())?; .map_err(|_| anyhow!("prover is not receiving attestation"))?;
Ok(()) Ok(())
} }

View File

@@ -12,8 +12,8 @@ use tlsn::{
signing::VerifyingKey, signing::VerifyingKey,
CryptoProvider, CryptoProvider,
}, },
config::{CertificateDer, RootCertStore},
verifier::ServerCertVerifier, verifier::ServerCertVerifier,
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_examples::ExampleType; use tlsn_examples::ExampleType;
use tlsn_server_fixture_certs::CA_CERT_DER; use tlsn_server_fixture_certs::CA_CERT_DER;

View File

@@ -12,11 +12,18 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument; use tracing::instrument;
use tlsn::{ use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore}, config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
verifier::VerifierConfig,
},
connection::ServerName, connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, prover::Prover,
transcript::PartialTranscript, transcript::PartialTranscript,
verifier::{Verifier, VerifierConfig, VerifierOutput}, verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT; use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -70,52 +77,52 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
assert_eq!(uri.scheme().unwrap().as_str(), "https"); assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host(); let server_domain = uri.authority().unwrap().host();
// Create a root certificate store with the server-fixture's self-signed // Create a new prover and perform necessary setup.
// certificate. This is only required for offline testing with the let prover = Prover::new(ProverConfig::builder().build()?)
// server-fixture. .commit(
let mut tls_config_builder = TlsConfig::builder(); TlsCommitConfig::builder()
tls_config_builder.root_store(RootCertStore { // Select the TLS commitment protocol.
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], .protocol(
}); MpcTlsConfig::builder()
let tls_config = tls_config_builder.build().unwrap(); // We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// Set up protocol configuration for prover. // connection. Reducing these limits will improve
let mut prover_config_builder = ProverConfig::builder(); // performance.
prover_config_builder .max_sent_data(tlsn_examples::MAX_SENT_DATA)
.server_name(ServerName::Dns(server_domain.try_into().unwrap())) .max_recv_data(tlsn_examples::MAX_RECV_DATA)
.tls_config(tls_config) .build()?,
.protocol_config( )
ProtocolConfig::builder() .build()?,
.max_sent_data(MAX_SENT_DATA) verifier_socket.compat(),
.max_recv_data(MAX_RECV_DATA) )
.build()
.unwrap(),
);
let prover_config = prover_config_builder.build().unwrap();
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await?; .await?;
// Connect to TLS Server. // Open a TCP connection to the server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?; let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Pass server connection into the prover. // Bind the prover to the server connection.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?; let (tls_connection, prover_fut) = prover
.connect(
// Wrap the connection in a TokioIo compatibility layer to use it with hyper. TlsClientConfig::builder()
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background. // Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake. // MPC-TLS Handshake.
let (mut request_sender, connection) = let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?; hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the connection to run in the background. // Spawn the connection to run in the background.
tokio::spawn(connection); tokio::spawn(connection);
@@ -187,16 +194,21 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let verifier = Verifier::new(verifier_config); let verifier = Verifier::new(verifier_config);
// Validate the proposed configuration and then run the TLS commitment protocol. // Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.setup(socket.compat()).await?; let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the // This is the opportunity to ensure the prover does not attempt to overload the
// verifier. // verifier.
let reject = if verifier.config().max_sent_data() > MAX_SENT_DATA { let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
Some("max_sent_data is too large") {
} else if verifier.config().max_recv_data() > MAX_RECV_DATA { if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_recv_data is too large") Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else { } else {
None Some("expecting to use MPC-TLS")
}; };
if reject.is_some() { if reject.is_some() {
@@ -210,7 +222,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
// Validate the proving request and then verify. // Validate the proving request and then verify.
let verifier = verifier.verify().await?; let verifier = verifier.verify().await?;
if verifier.request().handshake.is_none() { if !verifier.request().server_identity() {
let verifier = verifier let verifier = verifier
.reject(Some("expecting to verify the server name")) .reject(Some("expecting to verify the server name"))
.await?; .await?;

View File

@@ -22,24 +22,27 @@ use spansy::{
http::{BodyContent, Requests, Responses}, http::{BodyContent, Requests, Responses},
Spanned, Spanned,
}; };
use tls_server_fixture::CA_CERT_DER; use tls_server_fixture::{CA_CERT_DER, SERVER_DOMAIN};
use tlsn::{ use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore}, config::{
prove::{ProveConfig, ProveConfigBuilder},
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
connection::ServerName, connection::ServerName,
hash::HashAlgId, hash::HashAlgId,
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig}, prover::Prover,
transcript::{ transcript::{
hash::{PlaintextHash, PlaintextHashSecret}, hash::{PlaintextHash, PlaintextHashSecret},
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind, Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
TranscriptSecret, TranscriptSecret,
}, },
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_examples::MAX_RECV_DATA; use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tokio::io::AsyncWriteExt; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tlsn_examples::MAX_SENT_DATA;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument; use tracing::instrument;
@@ -61,51 +64,52 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.ok_or_else(|| anyhow::anyhow!("URI must have authority"))? .ok_or_else(|| anyhow::anyhow!("URI must have authority"))?
.host(); .host();
// Create a root certificate store with the server-fixture's self-signed // Create a new prover and perform necessary setup.
// certificate. This is only required for offline testing with the let prover = Prover::new(ProverConfig::builder().build()?)
// server-fixture. .commit(
let mut tls_config_builder = TlsConfig::builder(); TlsCommitConfig::builder()
tls_config_builder.root_store(RootCertStore { // Select the TLS commitment protocol.
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], .protocol(
}); MpcTlsConfig::builder()
let tls_config = tls_config_builder.build()?; // We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// Set up protocol configuration for prover. // connection. Reducing these limits will improve
let mut prover_config_builder = ProverConfig::builder(); // performance.
prover_config_builder .max_sent_data(MAX_SENT_DATA)
.server_name(ServerName::Dns(server_domain.try_into()?)) .max_recv_data(MAX_RECV_DATA)
.tls_config(tls_config) .build()?,
.protocol_config( )
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?, .build()?,
); verifier_socket.compat(),
)
let prover_config = prover_config_builder.build()?;
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await?; .await?;
// Connect to TLS Server. // Open a TCP connection to the server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?; let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Pass server connection into the prover. // Bind the prover to the server connection.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?; let (tls_connection, prover_fut) = prover
.connect(
// Wrap the connection in a TokioIo compatibility layer to use it with hyper. TlsClientConfig::builder()
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat()); .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background. // Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake. // MPC-TLS Handshake.
let (mut request_sender, connection) = let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?; hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the connection to run in the background. // Spawn the connection to run in the background.
tokio::spawn(connection); tokio::spawn(connection);
@@ -320,7 +324,7 @@ fn prepare_zk_proof_input(
hasher.update(&blinder); hasher.update(&blinder);
let computed_hash = hasher.finalize(); let computed_hash = hasher.finalize();
if committed_hash != computed_hash.as_slice() { if committed_hash != computed_hash.as_ref() as &[u8] {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"Computed hash does not match committed hash" "Computed hash does not match committed hash"
)); ));

View File

@@ -7,11 +7,12 @@ use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_h
use serde_json::Value; use serde_json::Value;
use tls_server_fixture::CA_CERT_DER; use tls_server_fixture::CA_CERT_DER;
use tlsn::{ use tlsn::{
config::{CertificateDer, RootCertStore}, config::{tls_commit::TlsCommitProtocolConfig, verifier::VerifierConfig},
connection::ServerName, connection::ServerName,
hash::HashAlgId, hash::HashAlgId,
transcript::{Direction, PartialTranscript}, transcript::{Direction, PartialTranscript},
verifier::{Verifier, VerifierConfig, VerifierOutput}, verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA}; use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tlsn_server_fixture_certs::SERVER_DOMAIN; use tlsn_server_fixture_certs::SERVER_DOMAIN;
@@ -24,28 +25,33 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
socket: T, socket: T,
mut extra_socket: T, mut extra_socket: T,
) -> Result<PartialTranscript> { ) -> Result<PartialTranscript> {
// Create a root certificate store with the server-fixture's self-signed let verifier = Verifier::new(
// certificate. This is only required for offline testing with the VerifierConfig::builder()
// server-fixture. // Create a root certificate store with the server-fixture's self-signed
let verifier_config = VerifierConfig::builder() // certificate. This is only required for offline testing with the
.root_store(RootCertStore { // server-fixture.
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], .root_store(RootCertStore {
}) roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
.build()?; })
.build()?,
let verifier = Verifier::new(verifier_config); );
// Validate the proposed configuration and then run the TLS commitment protocol. // Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.setup(socket.compat()).await?; let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the // This is the opportunity to ensure the prover does not attempt to overload the
// verifier. // verifier.
let reject = if verifier.config().max_sent_data() > MAX_SENT_DATA { let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
Some("max_sent_data is too large") {
} else if verifier.config().max_recv_data() > MAX_RECV_DATA { if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_recv_data is too large") Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else { } else {
None Some("expecting to use MPC-TLS")
}; };
if reject.is_some() { if reject.is_some() {
@@ -60,7 +66,7 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
let verifier = verifier.verify().await?; let verifier = verifier.verify().await?;
let request = verifier.request(); let request = verifier.request();
if request.handshake.is_none() || request.transcript.is_none() { if !request.server_identity() || request.reveal().is_none() {
let verifier = verifier let verifier = verifier
.reject(Some( .reject(Some(
"expecting to verify the server name and transcript data", "expecting to verify the server name and transcript data",

View File

@@ -9,6 +9,7 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096; pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true; pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false; pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench { pub const WARM_UP_BENCH: Bench = Bench {
group: None, group: None,
@@ -20,6 +21,7 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096, download_size: 4096,
defer_decryption: true, defer_decryption: true,
memory_profile: false, memory_profile: false,
reveal_all: true,
}; };
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -79,6 +81,8 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>, pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")] #[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>, pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -97,6 +101,8 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>, pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")] #[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>, pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
} }
impl BenchItem { impl BenchItem {
@@ -132,6 +138,10 @@ impl BenchItem {
if self.memory_profile.is_none() { if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile; self.memory_profile = group.memory_profile;
} }
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
} }
pub fn into_bench(&self) -> Bench { pub fn into_bench(&self) -> Bench {
@@ -145,6 +155,7 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE), download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION), defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE), memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
} }
} }
} }
@@ -164,6 +175,8 @@ pub struct Bench {
pub defer_decryption: bool, pub defer_decryption: bool,
#[serde(rename = "memory-profile")] #[serde(rename = "memory-profile")]
pub memory_profile: bool, pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -22,7 +22,10 @@ pub enum CmdOutput {
GetTests(Vec<String>), GetTests(Vec<String>),
Test(TestOutput), Test(TestOutput),
Bench(BenchOutput), Bench(BenchOutput),
Fail { reason: Option<String> }, #[cfg(target_arch = "wasm32")]
Fail {
reason: Option<String>,
},
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -5,9 +5,15 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
use harness_core::bench::{Bench, ProverMetrics}; use harness_core::bench::{Bench, ProverMetrics};
use tlsn::{ use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore}, config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
},
connection::ServerName, connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, prover::Prover,
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -22,41 +28,47 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let sent = verifier_io.sent(); let sent = verifier_io.sent();
let recv = verifier_io.recv(); let recv = verifier_io.recv();
let mut builder = ProtocolConfig::builder(); let prover = Prover::new(ProverConfig::builder().build()?);
builder.max_sent_data(config.upload_size);
builder.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder.max_recv_data(config.download_size + RECV_PADDING);
let protocol_config = builder.build()?;
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
let prover = Prover::new(
ProverConfig::builder()
.tls_config(tls_config)
.protocol_config(protocol_config)
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.build()?,
);
let time_start = web_time::Instant::now(); let time_start = web_time::Instant::now();
let prover = prover.setup(verifier_io).await?; let prover = prover
.commit(
TlsCommitConfig::builder()
.protocol({
let mut builder = MpcTlsConfig::builder()
.max_sent_data(config.upload_size)
.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder = builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder
.max_recv_data(config.download_size + RECV_PADDING)
.build()
}?)
.build()?,
verifier_io,
)
.await?;
let time_preprocess = time_start.elapsed().as_millis(); let time_preprocess = time_start.elapsed().as_millis();
let time_start_online = web_time::Instant::now(); let time_start_online = web_time::Instant::now();
let uploaded_preprocess = sent.load(Ordering::Relaxed); let uploaded_preprocess = sent.load(Ordering::Relaxed);
let downloaded_preprocess = recv.load(Ordering::Relaxed); let downloaded_preprocess = recv.load(Ordering::Relaxed);
let (mut conn, prover_fut) = prover.connect(provider.provide_server_io().await?).await?; let (mut conn, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
provider.provide_server_io().await?,
)
.await?;
let (_, mut prover) = futures::try_join!( let (_, mut prover) = futures::try_join!(
async { async {
@@ -86,14 +98,27 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript()); let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder builder
.server_identity() .server_identity()
.reveal_sent(&(0..sent_len))? .reveal_sent(&reveal_sent_range)?
.reveal_recv(&(0..recv_len))?; .reveal_recv(&reveal_recv_range)?;
let config = builder.build()?; let prove_config = builder.build()?;
prover.prove(&config).await?; prover.prove(&prove_config).await?;
prover.close().await?; prover.close().await?;
let time_total = time_start.elapsed().as_millis(); let time_total = time_start.elapsed().as_millis();

View File

@@ -2,8 +2,9 @@ use anyhow::Result;
use harness_core::bench::Bench; use harness_core::bench::Bench;
use tlsn::{ use tlsn::{
config::{CertificateDer, RootCertStore}, config::verifier::VerifierConfig,
verifier::{Verifier, VerifierConfig}, verifier::Verifier,
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_server_fixture_certs::CA_CERT_DER; use tlsn_server_fixture_certs::CA_CERT_DER;
@@ -19,7 +20,7 @@ pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()
); );
let verifier = verifier let verifier = verifier
.setup(provider.provide_proto_io().await?) .commit(provider.provide_proto_io().await?)
.await? .await?
.accept() .accept()
.await? .await?

View File

@@ -1,10 +1,17 @@
use tlsn::{ use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore}, config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
verifier::VerifierConfig,
},
connection::ServerName, connection::ServerName,
hash::HashAlgId, hash::HashAlgId,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, prover::Prover,
transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind}, transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind},
verifier::{Verifier, VerifierConfig, VerifierOutput}, verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN}; use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -21,35 +28,35 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier); crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) { async fn prover(provider: &IoProvider) {
let mut tls_config_builder = TlsConfig::builder(); let prover = Prover::new(ProverConfig::builder().build().unwrap())
tls_config_builder.root_store(RootCertStore { .commit(
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], TlsCommitConfig::builder()
}); .protocol(
MpcTlsConfig::builder()
let tls_config = tls_config_builder.build().unwrap(); .max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()); .defer_decryption_from_start(true)
let prover = Prover::new( .build()
ProverConfig::builder() .unwrap(),
.server_name(server_name) )
.tls_config(tls_config) .build()
.protocol_config( .unwrap(),
ProtocolConfig::builder() provider.provide_proto_io().await.unwrap(),
.max_sent_data(MAX_SENT_DATA) )
.max_recv_data(MAX_RECV_DATA) .await
.defer_decryption_from_start(true) .unwrap();
.build()
.unwrap(),
)
.build()
.unwrap(),
)
.setup(provider.provide_proto_io().await.unwrap())
.await
.unwrap();
let (tls_connection, prover_fut) = prover let (tls_connection, prover_fut) = prover
.connect(provider.provide_server_io().await.unwrap()) .connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap(),
provider.provide_server_io().await.unwrap(),
)
.await .await
.unwrap(); .unwrap();
@@ -120,7 +127,7 @@ async fn verifier(provider: &IoProvider) {
.unwrap(); .unwrap();
let verifier = Verifier::new(config) let verifier = Verifier::new(config)
.setup(provider.provide_proto_io().await.unwrap()) .commit(provider.provide_proto_io().await.unwrap())
.await .await
.unwrap() .unwrap()
.accept() .accept()

View File

@@ -7,10 +7,9 @@ publish = false
[dependencies] [dependencies]
tlsn-harness-core = { workspace = true } tlsn-harness-core = { workspace = true }
# tlsn-server-fixture = { workspace = true } # tlsn-server-fixture = { workspace = true }
charming = { version = "0.5.1", features = ["ssr"] } charming = { version = "0.6.0", features = ["ssr"] }
csv = "1.3.0"
clap = { workspace = true, features = ["derive", "env"] } clap = { workspace = true, features = ["derive", "env"] }
itertools = "0.14.0" polars = { version = "0.44", features = ["csv", "lazy"] }
toml = { workspace = true } toml = { workspace = true }

View File

@@ -0,0 +1,111 @@
# TLSNotary Benchmark Plot Tool
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
## Usage
```bash
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
```
### Arguments
- `<TOML>` - Path to Bench.toml file defining benchmark structure
- `<CSV>...` - One or more CSV files with benchmark results
### Options
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
- Number of labels must match number of CSV files
- `--min-max-band` - Add min/max bands to plots showing variance
- `-h, --help` - Print help information
## Examples
### Single Dataset
```bash
tlsn-harness-plot bench.toml results.csv
```
Generates plots from a single benchmark run.
### Compare Two Runs
```bash
tlsn-harness-plot bench.toml before.csv after.csv \
--labels "Before Optimization" "After Optimization"
```
Overlays two datasets to compare performance improvements.
### Multiple Datasets
```bash
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
--labels "Native" "Browser" "WASM"
```
Compare three different runtime environments.
### With Min/Max Bands
```bash
tlsn-harness-plot bench.toml run1.csv run2.csv \
--labels "Config A" "Config B" \
--min-max-band
```
Shows variance ranges for each dataset.
## Output Files
The tool generates two files per benchmark group:
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
- `<output>.svg` - Static SVG image for documentation
Default output filenames:
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
## Plot Format
Each dataset displays:
- **Solid line** - Total runtime (preprocessing + online phase)
- **Dashed line** - Online phase only
- **Shaded area** (optional) - Min/max variance bands
Different datasets automatically use distinct colors for easy comparison.
## CSV Format
Expected columns in each CSV file:
- `group` - Benchmark group name (must match TOML)
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
- `latency` - Network latency in ms (for latency plots)
- `time_preprocess` - Preprocessing time in ms
- `time_online` - Online phase time in ms
- `time_total` - Total runtime in ms
## TOML Format
The benchmark TOML file defines groups with either:
```toml
[[group]]
name = "my_benchmark"
protocol_latency = 50 # Fixed latency for bandwidth plots
# OR
bandwidth = 10000 # Fixed bandwidth for latency plots
```
All datasets must use the same TOML file to ensure consistent benchmark structure.
## Tips
- Use descriptive labels to make plots self-documenting
- Keep CSV files from the same benchmark configuration for valid comparisons
- Min/max bands are useful for showing stability but can clutter plots with many datasets
- Interactive HTML plots support zooming and hovering for detailed values

View File

@@ -1,17 +1,18 @@
use std::f32; use std::f32;
use charming::{ use charming::{
Chart, HtmlRenderer, Chart, HtmlRenderer, ImageRenderer,
component::{Axis, Legend, Title}, component::{Axis, Legend, Title},
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger}, element::{
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
Trigger,
},
series::Line, series::Line,
theme::Theme, theme::Theme,
}; };
use clap::Parser; use clap::Parser;
use harness_core::bench::{BenchItems, Measurement}; use harness_core::bench::BenchItems;
use itertools::Itertools; use polars::prelude::*;
const THEME: Theme = Theme::Default;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about)] #[command(author, version, about)]
@@ -19,72 +20,131 @@ struct Cli {
/// Path to the Bench.toml file with benchmark spec /// Path to the Bench.toml file with benchmark spec
toml: String, toml: String,
/// Path to the CSV file with benchmark results /// Paths to CSV files with benchmark results (one or more)
csv: String, csv: Vec<String>,
/// Prover kind: native or browser /// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
#[arg(short, long, value_enum, default_value = "native")] #[arg(short, long, num_args = 0..)]
prover_kind: ProverKind, labels: Vec<String>,
/// Add min/max bands to plots /// Add min/max bands to plots
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
min_max_band: bool, min_max_band: bool,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ProverKind {
Native,
Browser,
}
impl std::fmt::Display for ProverKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProverKind::Native => write!(f, "Native"),
ProverKind::Browser => write!(f, "Browser"),
}
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse(); let cli = Cli::parse();
let mut rdr = csv::Reader::from_path(&cli.csv)?; if cli.csv.is_empty() {
return Err("At least one CSV file must be provided".into());
}
// Generate labels if not provided
let labels: Vec<String> = if cli.labels.is_empty() {
cli.csv
.iter()
.enumerate()
.map(|(i, _)| format!("Dataset {}", i + 1))
.collect()
} else if cli.labels.len() != cli.csv.len() {
return Err(format!(
"Number of labels ({}) must match number of CSV files ({})",
cli.labels.len(),
cli.csv.len()
)
.into());
} else {
cli.labels.clone()
};
// Load all CSVs and add dataset label
let mut dfs = Vec::new();
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
let mut df = CsvReadOptions::default()
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
.finish()?;
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
df.with_column(label_series)?;
dfs.push(df);
}
// Combine all dataframes
let df = dfs
.into_iter()
.reduce(|acc, df| acc.vstack(&df).unwrap())
.unwrap();
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?; let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
let groups = items.group; let groups = items.group;
// Prepare data for plotting.
let all_data: Vec<Measurement> = rdr
.deserialize::<Measurement>()
.collect::<Result<Vec<_>, _>>()?;
for group in groups { for group in groups {
if group.protocol_latency.is_some() { // Determine which field varies in benches for this group
let latency = group.protocol_latency.unwrap(); let benches_in_group: Vec<_> = items
plot_runtime_vs( .bench
&all_data, .iter()
cli.min_max_band, .filter(|b| b.group.as_deref() == Some(&group.name))
&group.name, .collect();
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth", if benches_in_group.is_empty() {
format!("{} ms Latency, {} mode", latency, cli.prover_kind), continue;
"runtime_vs_bandwidth.html",
"Bandwidth (Mbps)",
)?;
} }
if group.bandwidth.is_some() { // Check which field has varying values
let bandwidth = group.bandwidth.unwrap(); let bandwidth_varies = benches_in_group
.windows(2)
.any(|w| w[0].bandwidth != w[1].bandwidth);
let latency_varies = benches_in_group
.windows(2)
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
let download_size_varies = benches_in_group
.windows(2)
.any(|w| w[0].download_size != w[1].download_size);
if download_size_varies {
let upload_size = group.upload_size.unwrap_or(1024);
plot_runtime_vs( plot_runtime_vs(
&all_data, &df,
&labels,
cli.min_max_band, cli.min_max_band,
&group.name, &group.name,
|r| r.latency as f32, "download_size",
1.0 / 1024.0, // bytes to KB
"Runtime vs Response Size",
format!("{} bytes upload size", upload_size),
"runtime_vs_download_size",
"Response Size (KB)",
true, // legend on left
)?;
} else if bandwidth_varies {
let latency = group.protocol_latency.unwrap_or(50);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"bandwidth",
1.0 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency", latency),
"runtime_vs_bandwidth",
"Bandwidth (Mbps)",
false, // legend on right
)?;
} else if latency_varies {
let bandwidth = group.bandwidth.unwrap_or(1000);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"latency",
1.0,
"Runtime vs Latency", "Runtime vs Latency",
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind), format!("{} bps bandwidth", bandwidth),
"runtime_vs_latency.html", "runtime_vs_latency",
"Latency (ms)", "Latency (ms)",
true, // legend on left
)?; )?;
} }
} }
@@ -92,83 +152,51 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
struct DataPoint {
min: f32,
mean: f32,
max: f32,
}
struct Points {
preprocess: DataPoint,
online: DataPoint,
total: DataPoint,
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn plot_runtime_vs<Fx>( fn plot_runtime_vs(
all_data: &[Measurement], df: &DataFrame,
labels: &[String],
show_min_max: bool, show_min_max: bool,
group: &str, group: &str,
x_value: Fx, x_col: &str,
x_scale: f32,
title: &str, title: &str,
subtitle: String, subtitle: String,
output_file: &str, output_file: &str,
x_axis_label: &str, x_axis_label: &str,
) -> Result<Chart, Box<dyn std::error::Error>> legend_left: bool,
where ) -> Result<Chart, Box<dyn std::error::Error>> {
Fx: Fn(&Measurement) -> f32, let stats_df = df
{ .clone()
fn data_point(values: &[f32]) -> DataPoint { .lazy()
let mean = values.iter().copied().sum::<f32>() / values.len() as f32; .filter(col("group").eq(lit(group)))
let max = values.iter().copied().reduce(f32::max).unwrap_or_default(); .with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
let min = values.iter().copied().reduce(f32::min).unwrap_or_default(); .with_columns([
DataPoint { min, mean, max } (col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
} (col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
])
.group_by([col("x"), col("dataset_label")])
.agg([
col("preprocess").min().alias("preprocess_min"),
col("preprocess").mean().alias("preprocess_mean"),
col("preprocess").max().alias("preprocess_max"),
col("online").min().alias("online_min"),
col("online").mean().alias("online_mean"),
col("online").max().alias("online_max"),
col("total").min().alias("total_min"),
col("total").mean().alias("total_mean"),
col("total").max().alias("total_max"),
])
.sort(["dataset_label", "x"], Default::default())
.collect()?;
let stats: Vec<(f32, Points)> = all_data // Build legend entries
.iter() let mut legend_data = Vec::new();
.filter(|r| r.group.as_deref() == Some(group)) for label in labels {
.map(|r| { legend_data.push(format!("Total Mean ({})", label));
( legend_data.push(format!("Online Mean ({})", label));
x_value(r), }
r.time_preprocess as f32 / 1000.0, // ms to s
r.time_online as f32 / 1000.0,
r.time_total as f32 / 1000.0,
)
})
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.chunk_by(|entry| entry.0)
.into_iter()
.map(|(x, group)| {
let group_vec: Vec<_> = group.collect();
let preprocess = data_point(
&group_vec
.iter()
.map(|(_, t, _, _)| *t)
.collect::<Vec<f32>>(),
);
let online = data_point(
&group_vec
.iter()
.map(|(_, _, t, _)| *t)
.collect::<Vec<f32>>(),
);
let total = data_point(
&group_vec
.iter()
.map(|(_, _, _, t)| *t)
.collect::<Vec<f32>>(),
);
(
x,
Points {
preprocess,
online,
total,
},
)
})
.collect();
let mut chart = Chart::new() let mut chart = Chart::new()
.title( .title(
@@ -179,14 +207,6 @@ where
.subtext_style(TextStyle::new().font_size(16)), .subtext_style(TextStyle::new().font_size(16)),
) )
.tooltip(Tooltip::new().trigger(Trigger::Axis)) .tooltip(Tooltip::new().trigger(Trigger::Axis))
.legend(
Legend::new()
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
.top("80")
.right("110")
.orient(Orient::Vertical)
.item_gap(10),
)
.x_axis( .x_axis(
Axis::new() Axis::new()
.name(x_axis_label) .name(x_axis_label)
@@ -205,73 +225,156 @@ where
.name_text_style(TextStyle::new().font_size(21)), .name_text_style(TextStyle::new().font_size(21)),
); );
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean); // Add legend with conditional positioning
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean); let legend = Legend::new()
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean); .data(legend_data)
.top("80")
.orient(Orient::Vertical)
.item_gap(10);
if show_min_max { let legend = if legend_left {
chart = add_min_max_band( legend.left("110")
chart, } else {
&stats, legend.right("110")
"Preprocess Min/Max", };
|p| &p.preprocess,
"#ccc", chart = chart.legend(legend);
);
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc"); // Define colors for each dataset
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc"); let colors = vec![
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
];
for (idx, label) in labels.iter().enumerate() {
let color = colors.get(idx % colors.len()).unwrap();
// Total time - solid line
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Total Mean ({})", label),
"total_mean",
false,
color,
)?;
// Online time - dashed line (same color as total)
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Online Mean ({})", label),
"online_mean",
true,
color,
)?;
if show_min_max {
chart = add_dataset_min_max_band(
&chart,
&stats_df,
label,
&format!("Total Min/Max ({})", label),
"total",
color,
)?;
}
} }
// Save the chart as HTML file. // Save the chart as HTML file (no theme)
HtmlRenderer::new(title, 1000, 800) HtmlRenderer::new(title, 1000, 800)
.theme(THEME) .save(&chart, &format!("{}.html", output_file))
.save(&chart, output_file) .unwrap();
// Save SVG with default theme
ImageRenderer::new(1000, 800)
.theme(Theme::Default)
.save(&chart, &format!("{}.svg", output_file))
.unwrap();
// Save SVG with dark theme
ImageRenderer::new(1000, 800)
.theme(Theme::Dark)
.save(&chart, &format!("{}_dark.svg", output_file))
.unwrap(); .unwrap();
Ok(chart) Ok(chart)
} }
fn add_mean_series( fn add_dataset_series(
chart: Chart, chart: &Chart,
stats: &[(f32, Points)], df: &DataFrame,
name: &str, dataset_label: &str,
extract: impl Fn(&Points) -> f32, series_name: &str,
) -> Chart { col_name: &str,
chart.series( dashed: bool,
Line::new() color: &str,
.name(name) ) -> Result<Chart, Box<dyn std::error::Error>> {
.data( // Filter for specific dataset
stats let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
.iter() let filtered = df.filter(&mask)?;
.map(|(x, points)| vec![*x, extract(points)])
.collect(), let x = filtered.column("x")?.f32()?;
) let y = filtered.column(col_name)?.f32()?;
.symbol_size(6),
) let data: Vec<Vec<f32>> = x
.into_iter()
.zip(y.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let mut line = Line::new()
.name(series_name)
.data(data)
.symbol_size(6)
.item_style(ItemStyle::new().color(color));
let mut line_style = LineStyle::new();
if dashed {
line_style = line_style.type_(LineStyleType::Dashed);
}
line = line.line_style(line_style.color(color));
Ok(chart.clone().series(line))
} }
fn add_min_max_band( fn add_dataset_min_max_band(
chart: Chart, chart: &Chart,
stats: &[(f32, Points)], df: &DataFrame,
dataset_label: &str,
name: &str, name: &str,
extract: impl Fn(&Points) -> &DataPoint, col_prefix: &str,
color: &str, color: &str,
) -> Chart { ) -> Result<Chart, Box<dyn std::error::Error>> {
chart.series( // Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
let max_data: Vec<Vec<f32>> = x
.into_iter()
.zip(max_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let min_data: Vec<Vec<f32>> = x
.into_iter()
.zip(min_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.rev()
.collect();
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
Ok(chart.clone().series(
Line::new() Line::new()
.name(name) .name(name)
.data( .data(data)
stats
.iter()
.map(|(x, points)| vec![*x, extract(points).max])
.chain(
stats
.iter()
.rev()
.map(|(x, points)| vec![*x, extract(points).min]),
)
.collect(),
)
.show_symbol(false) .show_symbol(false)
.line_style(LineStyle::new().opacity(0.0)) .line_style(LineStyle::new().opacity(0.0))
.area_style(AreaStyle::new().opacity(0.3).color(color)), .area_style(AreaStyle::new().opacity(0.3).color(color)),
) ))
} }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -32,18 +32,13 @@ use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy}; use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] #[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
pub enum Target { pub enum Target {
#[default]
Native, Native,
Browser, Browser,
} }
impl Default for Target {
fn default() -> Self {
Self::Native
}
}
struct Runner { struct Runner {
network: Network, network: Network,
server_fixture: ServerFixture, server_fixture: ServerFixture,

View File

@@ -0,0 +1,25 @@
#### Bandwidth ####
[[group]]
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
[[bench]]
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -0,0 +1,37 @@
[[group]]
name = "download_size"
protocol_latency = 10
bandwidth = 200
upload-size = 2048
[[bench]]
group = "download_size"
download-size = 1024
[[bench]]
group = "download_size"
download-size = 2048
[[bench]]
group = "download_size"
download-size = 4096
[[bench]]
group = "download_size"
download-size = 8192
[[bench]]
group = "download_size"
download-size = 16384
[[bench]]
group = "download_size"
download-size = 32768
[[bench]]
group = "download_size"
download-size = 65536
[[bench]]
group = "download_size"
download-size = 131072

View File

@@ -0,0 +1,25 @@
#### Latency ####
[[group]]
name = "latency"
bandwidth = 1000
[[bench]]
group = "latency"
protocol_latency = 10
[[bench]]
group = "latency"
protocol_latency = 25
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200

View File

@@ -33,7 +33,6 @@ mpz-ole = { workspace = true }
mpz-share-conversion = { workspace = true } mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true } mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true } mpz-memory-core = { workspace = true }
mpz-circuits = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false } ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true } serio = { workspace = true }
@@ -57,9 +56,9 @@ pin-project-lite = { workspace = true }
web-time = { workspace = true } web-time = { workspace = true }
[dev-dependencies] [dev-dependencies]
mpz-ole = { workspace = true, features = ["test-utils"] } mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true } mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true } mpz-ideal-vm = { workspace = true }
cipher-crate = { package = "cipher", version = "0.4" } cipher-crate = { package = "cipher", version = "0.4" }
generic-array = { workspace = true } generic-array = { workspace = true }

View File

@@ -487,7 +487,7 @@ impl RecordLayer {
sent_records.push(Record { sent_records.push(Record {
seq: op.seq, seq: op.seq,
typ: op.typ, typ: op.typ.into(),
plaintext: op.plaintext, plaintext: op.plaintext,
explicit_nonce: op.explicit_nonce, explicit_nonce: op.explicit_nonce,
ciphertext, ciphertext,
@@ -505,7 +505,7 @@ impl RecordLayer {
recv_records.push(Record { recv_records.push(Record {
seq: op.seq, seq: op.seq,
typ: op.typ, typ: op.typ.into(),
plaintext, plaintext,
explicit_nonce: op.explicit_nonce, explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext, ciphertext: op.ciphertext,
@@ -578,7 +578,7 @@ impl RecordLayer {
recv_records.push(Record { recv_records.push(Record {
seq: op.seq, seq: op.seq,
typ: op.typ, typ: op.typ.into(),
plaintext, plaintext,
explicit_nonce: op.explicit_nonce, explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext, ciphertext: op.ciphertext,

View File

@@ -456,9 +456,8 @@ mod tests {
}; };
use mpz_common::context::test_st_context; use mpz_common::context::test_st_context;
use mpz_core::Block; use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_memory_core::{binary::U8, correlated::Delta}; use mpz_memory_core::binary::U8;
use mpz_ot::ideal::cot::ideal_cot;
use mpz_share_conversion::ideal::ideal_share_convert; use mpz_share_conversion::ideal::ideal_share_convert;
use rand::{rngs::StdRng, SeedableRng}; use rand::{rngs::StdRng, SeedableRng};
use rstest::*; use rstest::*;
@@ -574,13 +573,8 @@ mod tests {
} }
fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) { fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) {
let mut rng = StdRng::seed_from_u64(0); let mut vm_0 = IdealVm::new();
let block = Block::random(&mut rng); let mut vm_1 = IdealVm::new();
let (sender, receiver) = ideal_cot(block);
let delta = Delta::new(block);
let mut vm_0 = Garbler::new(sender, [0u8; 16], delta);
let mut vm_1 = Evaluator::new(receiver);
let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap(); let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap();
vm_0.mark_public(key_ref_0).unwrap(); vm_0.mark_public(key_ref_0).unwrap();

View File

@@ -4,14 +4,13 @@ use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader}; use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context; use mpz_common::context::test_mt_context;
use mpz_core::Block; use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta; use mpz_memory_core::correlated::Delta;
use mpz_ot::{ use mpz_ot::{
cot::{DerandCOTReceiver, DerandCOTSender},
ideal::rcot::ideal_rcot, ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender}, rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
}; };
use rand::{rngs::StdRng, Rng, SeedableRng}; use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer; use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore; use tls_client::RootCertStore;
use tls_client_async::bind_client; use tls_client_async::bind_client;
@@ -23,7 +22,6 @@ use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER); const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "expensive"]
async fn mpc_tls_test() { async fn mpc_tls_test() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
@@ -139,14 +137,8 @@ fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a); let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b); let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(Garbler::new( let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
DerandCOTSender::new(rcot_send_a.clone()), let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
rand::rng().random(),
delta_a,
)));
let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new(
rcot_recv_b.clone(),
))));
let leader = MpcTlsLeader::new( let leader = MpcTlsLeader::new(
config.clone(), config.clone(),

View File

@@ -24,7 +24,7 @@ use std::{
}; };
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
use tracing::{debug, debug_span, error, trace, warn, Instrument}; use tracing::{debug, debug_span, trace, warn, Instrument};
use tls_client::ClientConnection; use tls_client::ClientConnection;

View File

@@ -1,5 +1,6 @@
use super::{Backend, BackendError}; use super::{Backend, BackendError};
use crate::{DecryptMode, EncryptMode, Error}; use crate::{DecryptMode, EncryptMode, Error};
#[allow(deprecated)]
use aes_gcm::{ use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead, Payload}, aead::{generic_array::GenericArray, Aead, NewAead, Payload},
Aes128Gcm, Aes128Gcm,
@@ -507,6 +508,7 @@ impl Encrypter {
let mut nonce = [0u8; 12]; let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv); nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(explicit_nonce); nonce[4..].copy_from_slice(explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce); let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap(); let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap();
// ciphertext will have the MAC appended // ciphertext will have the MAC appended
@@ -568,6 +570,7 @@ impl Decrypter {
let mut nonce = [0u8; 12]; let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv); nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(&m.payload.0[0..8]); nonce[4..].copy_from_slice(&m.payload.0[0..8]);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce); let nonce = GenericArray::from_slice(&nonce);
let plaintext = cipher let plaintext = cipher
.decrypt(nonce, aes_payload) .decrypt(nonce, aes_payload)

View File

@@ -12,6 +12,7 @@ workspace = true
[features] [features]
default = ["rayon"] default = ["rayon"]
mozilla-certs = ["tlsn-core/mozilla-certs"]
rayon = ["mpz-zk/rayon", "mpz-garble/rayon"] rayon = ["mpz-zk/rayon", "mpz-garble/rayon"]
web = ["dep:web-spawn"] web = ["dep:web-spawn"]
@@ -29,9 +30,9 @@ serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] } uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true } web-spawn = { workspace = true, optional = true }
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-common = { workspace = true } mpz-common = { workspace = true }
mpz-core = { workspace = true } mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true } mpz-garble = { workspace = true }
mpz-garble-core = { workspace = true } mpz-garble-core = { workspace = true }
mpz-hash = { workspace = true } mpz-hash = { workspace = true }
@@ -40,10 +41,10 @@ mpz-ole = { workspace = true }
mpz-ot = { workspace = true } mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true } mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true } mpz-zk = { workspace = true }
mpz-ideal-vm = { workspace = true }
aes = { workspace = true } aes = { workspace = true }
ctr = { workspace = true } ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
opaque-debug = { workspace = true } opaque-debug = { workspace = true }
rand = { workspace = true } rand = { workspace = true }

3
crates/tlsn/build.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
println!("cargo:rustc-check-cfg=cfg(tlsn_insecure)");
}

View File

@@ -1,119 +0,0 @@
//! TLSNotary protocol config and config utilities.
use once_cell::sync::Lazy;
use semver::Version;
use serde::{Deserialize, Serialize};
pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore};
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Current version that is running.
pub(crate) static VERSION: Lazy<Version> = Lazy::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
});
/// Protocol configuration to be set up initially by prover and verifier.
#[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct ProtocolConfig {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
/// Maximum number of application data records that can be sent.
#[builder(setter(strip_option), default)]
max_sent_records: Option<usize>,
/// Maximum number of bytes that can be decrypted online, i.e. while the
/// MPC-TLS connection is active.
#[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of received application data records that can be
/// decrypted online, i.e. while the MPC-TLS connection is active.
#[builder(setter(strip_option), default)]
max_recv_records_online: Option<usize>,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
#[builder(default = "true")]
defer_decryption_from_start: bool,
/// Network settings.
#[builder(default)]
network: NetworkSetting,
}
impl ProtocolConfigBuilder {
fn validate(&self) -> Result<(), String> {
if self.max_recv_data_online > self.max_recv_data {
return Err(
"max_recv_data_online must be smaller or equal to max_recv_data".to_string(),
);
}
Ok(())
}
}
impl ProtocolConfig {
/// Creates a new builder for `ProtocolConfig`.
pub fn builder() -> ProtocolConfigBuilder {
ProtocolConfigBuilder::default()
}
/// Returns the maximum number of bytes that can be sent.
pub fn max_sent_data(&self) -> usize {
self.max_sent_data
}
/// Returns the maximum number of application data records that can
/// be sent.
pub fn max_sent_records(&self) -> Option<usize> {
self.max_sent_records
}
/// Returns the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(&self) -> usize {
self.max_recv_data_online
}
/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns the maximum number of received application data records that
/// can be decrypted online.
pub fn max_recv_records_online(&self) -> Option<usize> {
self.max_recv_records_online
}
/// Returns whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(&self) -> bool {
self.defer_decryption_from_start
}
/// Returns the network settings.
pub fn network(&self) -> NetworkSetting {
self.network
}
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
Latency,
}
impl Default for NetworkSetting {
fn default() -> Self {
Self::Latency
}
}

View File

@@ -4,10 +4,10 @@
#![deny(clippy::all)] #![deny(clippy::all)]
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
pub mod config;
pub(crate) mod context; pub(crate) mod context;
pub(crate) mod ghash; pub(crate) mod ghash;
pub(crate) mod map; pub(crate) mod map;
pub(crate) mod mpz;
pub(crate) mod msg; pub(crate) mod msg;
pub(crate) mod mux; pub(crate) mod mux;
pub mod prover; pub mod prover;
@@ -16,7 +16,16 @@ pub(crate) mod transcript_internal;
pub mod verifier; pub mod verifier;
pub use tlsn_attestation as attestation; pub use tlsn_attestation as attestation;
pub use tlsn_core::{connection, hash, transcript}; pub use tlsn_core::{config, connection, hash, transcript, webpki};
use std::sync::LazyLock;
use semver::Version;
// Package version.
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
});
/// The party's role in the TLSN protocol. /// The party's role in the TLSN protocol.
/// ///

View File

@@ -1,7 +1,7 @@
use std::ops::Range; use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8}; use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet; use rangeset::set::RangeSet;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> { pub(crate) struct RangeMap<T> {
@@ -77,7 +77,7 @@ where
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> { pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new(); let mut map = Vec::new();
for idx in idx.iter_ranges() { for idx in idx.iter() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) { let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i, Ok(i) => i,
Err(0) => return None, Err(0) => return None,

233
crates/tlsn/src/mpz.rs Normal file
View File

@@ -0,0 +1,233 @@
use std::sync::Arc;
use mpc_tls::{MpcTlsFollower, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
#[cfg(not(tlsn_insecure))]
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_garble_core::Delta;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Key, Mac},
};
#[cfg(not(tlsn_insecure))]
use mpz_ot::cot::{DerandCOTReceiver, DerandCOTSender};
use mpz_ot::{
chou_orlandi as co, ferret, kos,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use mpz_zk::{Prover, Verifier};
#[cfg(not(tlsn_insecure))]
use rand::Rng;
use tlsn_core::config::tls_commit::mpc::{MpcTlsConfig, NetworkSetting};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::transcript_internal::commit::encoding::{KeyStore, MacStore};
#[cfg(not(tlsn_insecure))]
pub(crate) type ProverMpc =
Garbler<DerandCOTSender<SharedRCOTSender<kos::Sender<co::Receiver>, Block>>>;
#[cfg(tlsn_insecure)]
pub(crate) type ProverMpc = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type ProverZk =
Prover<SharedRCOTReceiver<ferret::Receiver<kos::Receiver<co::Sender>>, bool, Block>>;
#[cfg(tlsn_insecure)]
pub(crate) type ProverZk = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type VerifierMpc =
Evaluator<DerandCOTReceiver<SharedRCOTReceiver<kos::Receiver<co::Sender>, bool, Block>>>;
#[cfg(tlsn_insecure)]
pub(crate) type VerifierMpc = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type VerifierZk =
Verifier<SharedRCOTSender<ferret::Sender<kos::Sender<co::Receiver>>, Block>>;
#[cfg(tlsn_insecure)]
pub(crate) type VerifierZk = mpz_ideal_vm::IdealVm;
pub(crate) struct ProverDeps {
pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
pub(crate) mpc_tls: MpcTlsLeader,
}
pub(crate) fn build_prover_deps(config: MpcTlsConfig, ctx: Context) -> ProverDeps {
let mut rng = rand::rng();
let delta = Delta::new(Block::random(&mut rng));
let base_ot_send = co::Sender::default();
let base_ot_recv = co::Receiver::default();
let rcot_send = kos::Sender::new(
kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
let rcot_recv = ferret::Receiver::new(
ferret::FerretConfig::builder()
.lpn_type(ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_recv,
);
let rcot_send = SharedRCOTSender::new(rcot_send);
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
#[cfg(not(tlsn_insecure))]
let mpc = ProverMpc::new(DerandCOTSender::new(rcot_send.clone()), rng.random(), delta);
#[cfg(tlsn_insecure)]
let mpc = mpz_ideal_vm::IdealVm::new();
#[cfg(not(tlsn_insecure))]
let zk = ProverZk::new(Default::default(), rcot_recv.clone());
#[cfg(tlsn_insecure)]
let zk = mpz_ideal_vm::IdealVm::new();
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
let mpc_tls = MpcTlsLeader::new(
build_mpc_tls_config(config),
ctx,
vm.clone(),
(rcot_send.clone(), rcot_send.clone(), rcot_send),
rcot_recv,
);
ProverDeps { vm, mpc_tls }
}
pub(crate) struct VerifierDeps {
pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
pub(crate) mpc_tls: MpcTlsFollower,
}
pub(crate) fn build_verifier_deps(config: MpcTlsConfig, ctx: Context) -> VerifierDeps {
let mut rng = rand::rng();
let delta = Delta::random(&mut rng);
let base_ot_send = co::Sender::default();
let base_ot_recv = co::Receiver::default();
let rcot_send = kos::Sender::new(
kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_send = ferret::Sender::new(
ferret::FerretConfig::builder()
.lpn_type(ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_send,
);
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
let rcot_send = SharedRCOTSender::new(rcot_send);
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
#[cfg(not(tlsn_insecure))]
let mpc = VerifierMpc::new(DerandCOTReceiver::new(rcot_recv.clone()));
#[cfg(tlsn_insecure)]
let mpc = mpz_ideal_vm::IdealVm::new();
#[cfg(not(tlsn_insecure))]
let zk = VerifierZk::new(Default::default(), delta, rcot_send.clone());
#[cfg(tlsn_insecure)]
let zk = mpz_ideal_vm::IdealVm::new();
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));
let mpc_tls = MpcTlsFollower::new(
build_mpc_tls_config(config),
ctx,
vm.clone(),
rcot_send,
(rcot_recv.clone(), rcot_recv.clone(), rcot_recv),
);
VerifierDeps { vm, mpc_tls }
}
fn build_mpc_tls_config(config: MpcTlsConfig) -> mpc_tls::Config {
let mut builder = mpc_tls::Config::builder();
builder
.defer_decryption(config.defer_decryption_from_start())
.max_sent(config.max_sent_data())
.max_recv_online(config.max_recv_data_online())
.max_recv(config.max_recv_data());
if let Some(max_sent_records) = config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Latency = config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
pub(crate) fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) {
keys.client_write_key = vm
.translate(keys.client_write_key)
.expect("VM memory should be consistent");
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.expect("VM memory should be consistent");
keys.server_write_key = vm
.translate(keys.server_write_key)
.expect("VM memory should be consistent");
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.expect("VM memory should be consistent");
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.expect("VM memory should be consistent");
}
impl<T> KeyStore for Verifier<T> {
fn delta(&self) -> &Delta {
self.delta()
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
impl<T> MacStore for Prover<T> {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[cfg(tlsn_insecure)]
mod insecure {
use super::*;
use mpz_ideal_vm::IdealVm;
impl KeyStore for IdealVm {
fn delta(&self) -> &Delta {
unimplemented!("encodings not supported in insecure mode")
}
fn get_keys(&self, _data: Vector<U8>) -> Option<&[Key]> {
unimplemented!("encodings not supported in insecure mode")
}
}
impl MacStore for IdealVm {
fn get_macs(&self, _data: Vector<U8>) -> Option<&[Mac]> {
unimplemented!("encodings not supported in insecure mode")
}
}
}

View File

@@ -1,14 +1,25 @@
use semver::Version; use semver::Version;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::config::ProtocolConfig; use tlsn_core::{
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
connection::{HandshakeData, ServerName},
transcript::PartialTranscript,
};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SetupRequest { pub(crate) struct TlsCommitRequestMsg {
pub(crate) config: ProtocolConfig, pub(crate) request: TlsCommitRequest,
pub(crate) version: Version, pub(crate) version: Version,
} }
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ProveRequestMsg {
pub(crate) request: ProveRequest,
pub(crate) handshake: Option<(ServerName, HandshakeData)>,
pub(crate) transcript: Option<PartialTranscript>,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Response { pub(crate) struct Response {
pub(crate) result: Result<(), RejectionReason>, pub(crate) result: Result<(), RejectionReason>,

View File

@@ -1,63 +1,45 @@
//! Prover. //! Prover.
mod config;
mod error; mod error;
mod future; mod future;
mod prove; mod prove;
pub mod state; pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError; pub use error::ProverError;
pub use future::ProverFuture; pub use future::ProverFuture;
use rustls_pki_types::CertificateDer; pub use tlsn_core::ProverOutput;
pub use tlsn_core::{
ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProveRequest, ProverOutput,
};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use webpki::anchor_from_trusted_cert;
use crate::{ use crate::{
Role, Role,
context::build_mt_context, context::build_mt_context,
msg::{Response, SetupRequest}, mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux, mux::attach_mux,
tag::verify_tags, tag::verify_tags,
}; };
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys}; use mpc_tls::LeaderCtrl;
use rand::Rng; use mpz_vm_core::prelude::*;
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt}; use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc; use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName}; use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client}; use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{ use tlsn_core::{
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, TlsCommitProtocolConfig},
},
connection::{HandshakeData, ServerName}, connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript}, transcript::{TlsTranscript, Transcript},
}; };
use tlsn_deap::Deap; use webpki::anchor_from_trusted_cert;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument}; use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
>;
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
bool,
mpz_core::Block,
>;
pub(crate) type Mpc =
mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
/// A prover instance. /// A prover instance.
#[derive(Debug)] #[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> { pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -81,19 +63,21 @@ impl Prover<state::Initialized> {
} }
} }
/// Sets up the prover. /// Starts the TLS commitment protocol.
/// ///
/// This performs all MPC setup prior to establishing the connection to the /// This initiates the TLS commitment protocol, including performing any
/// application server. /// necessary preprocessing operations.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier. /// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)] #[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self, self,
config: TlsCommitConfig,
socket: S, socket: S,
) -> Result<Prover<state::Setup>, ProverError> { ) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover); let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone()); let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?; let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
@@ -102,9 +86,9 @@ impl Prover<state::Initialized> {
mux_fut mux_fut
.poll_with(async { .poll_with(async {
ctx.io_mut() ctx.io_mut()
.send(SetupRequest { .send(TlsCommitRequestMsg {
config: self.config.protocol_config().clone(), request: config.to_request(),
version: crate::config::VERSION.clone(), version: crate::VERSION.clone(),
}) })
.await?; .await?;
@@ -116,12 +100,16 @@ impl Prover<state::Initialized> {
}) })
.await?; .await?;
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx); let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = config.protocol().clone() else {
unreachable!("only MPC TLS is supported");
};
let ProverDeps { vm, mut mpc_tls } = build_prover_deps(mpc_tls_config, ctx);
// Allocate resources for MPC-TLS in the VM. // Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?; let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked"); let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?; translate_keys(&mut keys, &vm_lock);
drop(vm_lock); drop(vm_lock);
debug!("setting up mpc-tls"); debug!("setting up mpc-tls");
@@ -133,7 +121,7 @@ impl Prover<state::Initialized> {
Ok(Prover { Ok(Prover {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Setup { state: state::CommitAccepted {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
mpc_tls, mpc_tls,
@@ -144,21 +132,24 @@ impl Prover<state::Initialized> {
} }
} }
impl Prover<state::Setup> { impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket. /// Connects to the server using the provided socket.
/// ///
/// Returns a handle to the TLS connection, a future which returns the /// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed. /// prover once the connection is closed and the TLS transcript is
/// committed.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server. /// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)] #[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self, self,
config: TlsClientConfig,
socket: S, socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> { ) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::Setup { let state::CommitAccepted {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mpc_tls, mpc_tls,
@@ -169,12 +160,13 @@ impl Prover<state::Setup> {
let (mpc_ctrl, mpc_fut) = mpc_tls.run(); let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = self.config.server_name(); let ServerName::Dns(server_name) = config.server_name();
let server_name = let server_name =
TlsServerName::try_from(server_name.as_ref()).expect("name was validated"); TlsServerName::try_from(server_name.as_ref()).expect("name was validated");
let root_store = if let Some(root_store) = self.config.tls_config().root_store() { let root_store = tls_client::RootCertStore {
let roots = root_store roots: config
.root_store()
.roots .roots
.iter() .iter()
.map(|cert| { .map(|cert| {
@@ -183,20 +175,15 @@ impl Prover<state::Setup> {
.map(|anchor| anchor.to_owned()) .map(|anchor| anchor.to_owned())
.map_err(ProverError::config) .map_err(ProverError::config)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?,
tls_client::RootCertStore { roots }
} else {
tls_client::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
}; };
let config = tls_client::ClientConfig::builder() let rustls_config = tls_client::ClientConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_root_certificates(root_store); .with_root_certificates(root_store);
let config = if let Some((cert, key)) = self.config.tls_config().client_auth() { let rustls_config = if let Some((cert, key)) = config.client_auth() {
config rustls_config
.with_single_cert( .with_single_cert(
cert.iter() cert.iter()
.map(|cert| tls_client::Certificate(cert.0.clone())) .map(|cert| tls_client::Certificate(cert.0.clone()))
@@ -205,12 +192,15 @@ impl Prover<state::Setup> {
) )
.map_err(ProverError::config)? .map_err(ProverError::config)?
} else { } else {
config.with_no_client_auth() rustls_config.with_no_client_auth()
}; };
let client = let client = ClientConnection::new(
ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name) Arc::new(rustls_config),
.map_err(ProverError::config)?; Box::new(mpc_ctrl.clone()),
server_name,
)
.map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client); let (conn, conn_fut) = bind_client(socket, client);
@@ -284,6 +274,7 @@ impl Prover<state::Setup> {
mux_fut, mux_fut,
ctx, ctx,
vm, vm,
server_name: config.server_name().clone(),
keys, keys,
tls_transcript, tls_transcript,
transcript, transcript,
@@ -326,40 +317,42 @@ impl Prover<state::Committed> {
ctx, ctx,
vm, vm,
keys, keys,
server_name,
tls_transcript, tls_transcript,
transcript, transcript,
.. ..
} = &mut self.state; } = &mut self.state;
let request = ProveRequest { let handshake = config.server_identity().then(|| {
handshake: config.server_identity().then(|| { (
( server_name.clone(),
self.config.server_name().clone(), HandshakeData {
HandshakeData { certs: tls_transcript
certs: tls_transcript .server_cert_chain()
.server_cert_chain() .expect("server cert chain is present")
.expect("server cert chain is present") .to_vec(),
.to_vec(), sig: tls_transcript
sig: tls_transcript .server_signature()
.server_signature() .expect("server signature is present")
.expect("server signature is present") .clone(),
.clone(), binding: tls_transcript.certificate_binding().clone(),
binding: tls_transcript.certificate_binding().clone(), },
}, )
) });
}),
transcript: config let partial_transcript = config
.reveal() .reveal()
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())), .map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone()));
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
let msg = ProveRequestMsg {
request: config.to_request(),
handshake,
transcript: partial_transcript,
}; };
let output = mux_fut let output = mux_fut
.poll_with(async { .poll_with(async {
ctx.io_mut() ctx.io_mut().send(msg).await.map_err(ProverError::from)?;
.send(request)
.await
.map_err(ProverError::from)?;
ctx.io_mut().expect_next::<Response>().await?.result?; ctx.io_mut().expect_next::<Response>().await?.result?;
@@ -387,53 +380,6 @@ impl Prover<state::Committed> {
} }
} }
fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
let mut rng = rand::rng();
let delta = Delta::new(Block::random(&mut rng));
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(
mpz_ot::kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_recv =
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
let rcot_recv = mpz_ot::ferret::Receiver::new(
mpz_ot::ferret::FerretConfig::builder()
.lpn_type(mpz_ot::ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_recv,
);
let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
let mpc = Mpc::new(
mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
rng.random(),
delta,
);
let zk = Zk::new(ZkProverConfig::default(), rcot_recv.clone());
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
(
vm.clone(),
MpcTlsLeader::new(
config.build_mpc_tls_config(),
ctx,
vm,
(rcot_send.clone(), rcot_send.clone(), rcot_send),
rcot_recv,
),
)
}
/// A controller for the prover. /// A controller for the prover.
#[derive(Clone)] #[derive(Clone)]
pub struct ProverControl { pub struct ProverControl {
@@ -459,24 +405,3 @@ impl ProverControl {
.map_err(ProverError::from) .map_err(ProverError::from)
} }
} }
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(ProverError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(ProverError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(ProverError::mpc)?;
Ok(())
}

View File

@@ -1,144 +0,0 @@
use mpc_tls::Config;
use serde::{Deserialize, Serialize};
use tlsn_core::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use crate::config::{NetworkSetting, ProtocolConfig};
/// Configuration for the prover.
#[derive(Debug, Clone, derive_builder::Builder, Serialize, Deserialize)]
pub struct ProverConfig {
/// The server DNS name.
#[builder(setter(into))]
server_name: ServerName,
/// Protocol configuration to be checked with the verifier.
protocol_config: ProtocolConfig,
/// TLS configuration.
#[builder(default)]
tls_config: TlsConfig,
}
impl ProverConfig {
/// Creates a new builder for `ProverConfig`.
pub fn builder() -> ProverConfigBuilder {
ProverConfigBuilder::default()
}
/// Returns the server DNS name.
pub fn server_name(&self) -> &ServerName {
&self.server_name
}
/// Returns the protocol configuration.
pub fn protocol_config(&self) -> &ProtocolConfig {
&self.protocol_config
}
/// Returns the TLS configuration.
pub fn tls_config(&self) -> &TlsConfig {
&self.tls_config
}
pub(crate) fn build_mpc_tls_config(&self) -> Config {
let mut builder = Config::builder();
builder
.defer_decryption(self.protocol_config.defer_decryption_from_start())
.max_sent(self.protocol_config.max_sent_data())
.max_recv_online(self.protocol_config.max_recv_data_online())
.max_recv(self.protocol_config.max_recv_data());
if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Latency = self.protocol_config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
}
/// Configuration for the prover's TLS connection.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
/// Root certificates.
root_store: Option<RootCertStore>,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfig {
/// Creates a new builder for `TlsConfig`.
pub fn builder() -> TlsConfigBuilder {
TlsConfigBuilder::default()
}
pub(crate) fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> &Option<(Vec<CertificateDer>, PrivateKeyDer)> {
&self.client_auth
}
}
/// Builder for [`TlsConfig`].
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
    /// Sets the root certificates used for verifying the server's
    /// certificate.
    pub fn root_store(&mut self, store: RootCertStore) -> &mut Self {
        self.root_store.replace(store);
        self
    }

    /// Sets a DER-encoded certificate chain and a matching private key for
    /// client authentication.
    ///
    /// Often the chain will consist of a single end-entity certificate.
    ///
    /// # Arguments
    ///
    /// * `cert_key` - A tuple containing the certificate chain and the private
    ///   key.
    ///
    /// - Each certificate in the chain must be in the X.509 format.
    /// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
    pub fn client_auth(&mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> &mut Self {
        self.client_auth.replace(cert_key);
        self
    }

    /// Builds the TLS configuration.
    ///
    /// Currently infallible; the `Result` return type is part of the
    /// public contract.
    pub fn build(self) -> Result<TlsConfig, TlsConfigError> {
        let Self {
            root_store,
            client_auth,
        } = self;
        Ok(TlsConfig {
            root_store,
            client_auth,
        })
    }
}
/// TLS configuration error.
///
/// Wraps the internal [`ErrorRepr`] and forwards its `Display`/`source`
/// via `#[error(transparent)]`.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsConfigError(#[from] ErrorRepr);
// Internal error representation. Currently uninhabited (no variants), so
// [`TlsConfigBuilder::build`] cannot actually fail — presumably reserved
// so variants can be added later without breaking the public error type.
#[derive(Debug, thiserror::Error)]
#[error("tls config error")]
enum ErrorRepr {}

View File

@@ -4,7 +4,7 @@ use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError; use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Prover`](crate::Prover). /// Error for [`Prover`](crate::prover::Prover).
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub struct ProverError { pub struct ProverError {
kind: ErrorKind, kind: ErrorKind,

View File

@@ -2,9 +2,10 @@ use mpc_tls::SessionKeys;
use mpz_common::Context; use mpz_common::Context;
use mpz_memory_core::binary::Binary; use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm; use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut}; use rangeset::set::RangeSet;
use tlsn_core::{ use tlsn_core::{
ProveConfig, ProverOutput, ProverOutput,
config::prove::ProveConfig,
transcript::{ transcript::{
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret, ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
}, },

View File

@@ -4,13 +4,16 @@ use std::sync::Arc;
use mpc_tls::{MpcTlsLeader, SessionKeys}; use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context; use mpz_common::Context;
use tlsn_core::transcript::{TlsTranscript, Transcript}; use tlsn_core::{
connection::ServerName,
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap; use tlsn_deap::Deap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::{ use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture}, mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
}; };
/// Entry state /// Entry state
@@ -18,23 +21,25 @@ pub struct Initialized;
opaque_debug::implement!(Initialized); opaque_debug::implement!(Initialized);
/// State after MPC setup has completed. /// State after the verifier has accepted the proposed TLS commitment protocol
pub struct Setup { /// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader, pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) keys: SessionKeys, pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>, pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
} }
opaque_debug::implement!(Setup); opaque_debug::implement!(CommitAccepted);
/// State after the TLS connection has been committed and closed. /// State after the TLS transcript has been committed.
pub struct Committed { pub struct Committed {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
pub(crate) vm: Zk, pub(crate) vm: ProverZk,
pub(crate) server_name: ServerName,
pub(crate) keys: SessionKeys, pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript, pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript, pub(crate) transcript: Transcript,
@@ -46,12 +51,12 @@ opaque_debug::implement!(Committed);
pub trait ProverState: sealed::Sealed {} pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {} impl ProverState for Initialized {}
impl ProverState for Setup {} impl ProverState for CommitAccepted {}
impl ProverState for Committed {} impl ProverState for Committed {}
mod sealed { mod sealed {
pub trait Sealed {} pub trait Sealed {}
impl Sealed for super::Initialized {} impl Sealed for super::Initialized {}
impl Sealed for super::Setup {} impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {} impl Sealed for super::Committed {}
} }

View File

@@ -112,7 +112,7 @@ impl TagProof {
.map_err(TagProofError::vm)? .map_err(TagProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?; .ok_or_else(|| ErrorRepr::NotDecoded)?;
let aad = make_tls12_aad(rec.seq, rec.typ, vers, rec.ciphertext.len()); let aad = make_tls12_aad(rec.seq, rec.typ.into(), vers, rec.ciphertext.len());
let ghash_tag = ghash(aad.as_ref(), &rec.ciphertext, &mac_key); let ghash_tag = ghash(aad.as_ref(), &rec.ciphertext, &mac_key);

View File

@@ -5,14 +5,14 @@ use ctr::{
Ctr32BE, Ctr32BE,
cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}, cipher::{KeyIvInit, StreamCipher, StreamCipherSeek},
}; };
use mpz_circuits::circuits::{AES128, xor}; use mpz_circuits::{AES128, circuits::xor};
use mpz_core::bitvec::BitVec; use mpz_core::bitvec::BitVec;
use mpz_memory_core::{ use mpz_memory_core::{
Array, DecodeFutureTyped, MemoryExt, Vector, ViewExt, Array, DecodeFutureTyped, MemoryExt, Vector, ViewExt,
binary::{Binary, U8}, binary::{Binary, U8},
}; };
use mpz_vm_core::{Call, CallableExt, Vm}; use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{Difference, RangeSet, Union}; use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
use tlsn_core::transcript::Record; use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap; use crate::transcript_internal::ReferenceMap;
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
commit.clone() commit.clone()
} else { } else {
// The plaintext is only partially revealed, so we need to authenticate in ZK. // The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal) commit.union(reveal).into_set()
}; };
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?; let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
vm.commit(*slice).map_err(PlaintextAuthError::vm)?; vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
} }
} else { } else {
let private = commit.difference(reveal); let private = commit.difference(reveal).into_set();
for (_, slice) in plaintext_refs for (_, slice) in plaintext_refs
.index(&private) .index(&private)
.expect("all ranges are allocated") .expect("all ranges are allocated")
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
commit.clone() commit.clone()
} else { } else {
// The plaintext is only partially revealed, so we need to authenticate in ZK. // The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal) commit.union(reveal).into_set()
}; };
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?; let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
ciphertext, ciphertext,
}) })
} else { } else {
let private = commit.difference(reveal); let private = commit.difference(reveal).into_set();
for (_, slice) in plaintext_refs for (_, slice) in plaintext_refs
.index(&private) .index(&private)
.expect("all ranges are allocated") .expect("all ranges are allocated")
@@ -175,15 +175,13 @@ fn alloc_plaintext(
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?; let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0; let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map( Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
move |range| { let chunk = plaintext
let chunk = plaintext .get(pos..pos + range.len())
.get(pos..pos + range.len()) .expect("length was checked");
.expect("length was checked"); pos += range.len();
pos += range.len(); (range.start, chunk)
(range.start, chunk) })))
},
)))
} }
fn alloc_ciphertext<'a>( fn alloc_ciphertext<'a>(
@@ -212,15 +210,13 @@ fn alloc_ciphertext<'a>(
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?; let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0; let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map( Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
move |range| { let chunk = ciphertext
let chunk = ciphertext .get(pos..pos + range.len())
.get(pos..pos + range.len()) .expect("length was checked");
.expect("length was checked"); pos += range.len();
pos += range.len(); (range.start, chunk)
(range.start, chunk) })))
},
)))
} }
fn alloc_keystream<'a>( fn alloc_keystream<'a>(
@@ -233,7 +229,7 @@ fn alloc_keystream<'a>(
let mut keystream = Vec::new(); let mut keystream = Vec::new();
let mut pos = 0; let mut pos = 0;
let mut range_iter = ranges.iter_ranges(); let mut range_iter = ranges.iter();
let mut current_range = range_iter.next(); let mut current_range = range_iter.next();
for record in records { for record in records {
let mut explicit_nonce = None; let mut explicit_nonce = None;
@@ -508,7 +504,7 @@ mod tests {
for record in records { for record in records {
let mut record_keystream = vec![0u8; record.len]; let mut record_keystream = vec![0u8; record.len];
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream); aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
for mut range in ranges.iter_ranges() { for mut range in ranges.iter() {
range.start = range.start.max(pos); range.start = range.start.max(pos);
range.end = range.end.min(pos + record.len); range.end = range.end.min(pos + record.len);
if range.start < range.end { if range.start < range.end {

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
correlated::{Delta, Key, Mac}, correlated::{Delta, Key, Mac},
}; };
use rand::Rng; use rand::Rng;
use rangeset::RangeSet; use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt}; use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{ use tlsn_core::{
@@ -177,26 +177,10 @@ pub(crate) trait KeyStore {
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>; fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
} }
impl KeyStore for crate::verifier::Zk {
fn delta(&self) -> &Delta {
crate::verifier::Zk::delta(self)
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
pub(crate) trait MacStore { pub(crate) trait MacStore {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>; fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
} }
impl MacStore for crate::prover::Zk {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[derive(Debug)] #[derive(Debug)]
struct Provider { struct Provider {
sent: RangeMap<EncodingSlice>, sent: RangeMap<EncodingSlice>,

View File

@@ -3,13 +3,13 @@
use std::collections::HashMap; use std::collections::HashMap;
use mpz_core::bitvec::BitVec; use mpz_core::bitvec::BitVec;
use mpz_hash::{blake3::Blake3, sha256::Sha256}; use mpz_hash::{blake3::Blake3, keccak256::Keccak256, sha256::Sha256};
use mpz_memory_core::{ use mpz_memory_core::{
DecodeFutureTyped, MemoryExt, Vector, DecodeFutureTyped, MemoryExt, Vector,
binary::{Binary, U8}, binary::{Binary, U8},
}; };
use mpz_vm_core::{Vm, VmError, prelude::*}; use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet; use rangeset::set::RangeSet;
use tlsn_core::{ use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash}, hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{ transcript::{
@@ -111,6 +111,7 @@ pub(crate) fn verify_hash(
enum Hasher { enum Hasher {
Sha256(Sha256), Sha256(Sha256),
Blake3(Blake3), Blake3(Blake3),
Keccak256(Keccak256),
} }
/// Commit plaintext hashes of the transcript. /// Commit plaintext hashes of the transcript.
@@ -154,7 +155,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv, Direction::Received => &refs.recv,
}; };
for range in idx.iter_ranges() { for range in idx.iter() {
hasher.update(&refs.get(range).expect("plaintext refs are valid")); hasher.update(&refs.get(range).expect("plaintext refs are valid"));
} }
@@ -175,7 +176,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv, Direction::Received => &refs.recv,
}; };
for range in idx.iter_ranges() { for range in idx.iter() {
hasher hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid")) .update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?; .map_err(HashCommitError::hasher)?;
@@ -185,6 +186,32 @@ fn hash_commit_inner(
.map_err(HashCommitError::hasher)?; .map_err(HashCommitError::hasher)?;
hasher.finalize(vm).map_err(HashCommitError::hasher)? hasher.finalize(vm).map_err(HashCommitError::hasher)?
} }
HashAlgId::KECCAK256 => {
let mut hasher = if let Some(Hasher::Keccak256(hasher)) = hashers.get(&alg).cloned()
{
hasher
} else {
let hasher = Keccak256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, Hasher::Keccak256(hasher.clone()));
hasher
};
let refs = match direction {
Direction::Sent => &refs.sent,
Direction::Received => &refs.recv,
};
for range in idx.iter() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
}
hasher
.update(vm, &blinder)
.map_err(HashCommitError::hasher)?;
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
alg => { alg => {
return Err(HashCommitError::unsupported_alg(alg)); return Err(HashCommitError::unsupported_alg(alg));
} }

View File

@@ -1,55 +1,37 @@
//! Verifier. //! Verifier.
pub(crate) mod config;
mod error; mod error;
pub mod state; pub mod state;
mod verify; mod verify;
use std::sync::Arc; use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError; pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier}; pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
use crate::{ use crate::{
Role, Role,
config::ProtocolConfig,
context::build_mt_context, context::build_mt_context,
msg::{Response, SetupRequest}, mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux, mux::attach_mux,
tag::verify_tags, tag::verify_tags,
}; };
use futures::{AsyncRead, AsyncWrite, TryFutureExt}; use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*; use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::{SinkExt, stream::IoStreamExt}; use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{ use tlsn_core::{
ProveRequest, config::{
prove::ProveRequest,
tls_commit::{TlsCommitProtocolConfig, TlsCommitRequest},
verifier::VerifierConfig,
},
connection::{ConnectionInfo, ServerName}, connection::{ConnectionInfo, ServerName},
transcript::TlsTranscript, transcript::TlsTranscript,
}; };
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument}; use tracing::{Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
>;
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>,
bool,
mpz_core::Block,
>;
pub(crate) type Mpc =
mpz_garble::protocol::semihonest::Evaluator<mpz_ot::cot::DerandCOTReceiver<RCOTReceiver>>;
pub(crate) type Zk = mpz_zk::Verifier<RCOTSender>;
/// Information about the TLS session. /// Information about the TLS session.
#[derive(Debug)] #[derive(Debug)]
pub struct SessionInfo { pub struct SessionInfo {
@@ -77,30 +59,31 @@ impl Verifier<state::Initialized> {
} }
} }
/// Sets up the verifier. /// Starts the TLS commitment protocol.
/// ///
/// This performs all MPC setup. /// This initiates the TLS commitment protocol, receiving the prover's
/// configuration and providing the opportunity to accept or reject it.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `socket` - The socket to the prover. /// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>( pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self, self,
socket: S, socket: S,
) -> Result<Verifier<state::Config>, VerifierError> { ) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier); let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone()); let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?; let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Receives protocol configuration from prover to perform compatibility check. // Receives protocol configuration from prover to perform compatibility check.
let SetupRequest { config, version } = let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?; mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
if version != *crate::config::VERSION { if version != *crate::VERSION {
let msg = format!( let msg = format!(
"prover version does not match with verifier: {version} != {}", "prover version does not match with verifier: {version} != {}",
*crate::config::VERSION *crate::VERSION
); );
mux_fut mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone())))) .poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
@@ -118,40 +101,44 @@ impl Verifier<state::Initialized> {
Ok(Verifier { Ok(Verifier {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Config { state: state::CommitStart {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
ctx, ctx,
config, request,
}, },
}) })
} }
} }
impl Verifier<state::Config> { impl Verifier<state::CommitStart> {
/// Returns the proposed protocol configuration. /// Returns the TLS commitment request.
pub fn config(&self) -> &ProtocolConfig { pub fn request(&self) -> &TlsCommitRequest {
&self.state.config &self.state.request
} }
/// Accepts the proposed protocol configuration. /// Accepts the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::Setup>, VerifierError> { pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::Config { let state::CommitStart {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
config, request,
} = self.state; } = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?; mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &config, ctx); let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol().clone() else {
unreachable!("only MPC TLS is supported");
};
let VerifierDeps { vm, mut mpc_tls } = build_verifier_deps(mpc_tls_config, ctx);
// Allocate resources for MPC-TLS in the VM. // Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?; let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked"); let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?; translate_keys(&mut keys, &vm_lock);
drop(vm_lock); drop(vm_lock);
debug!("setting up mpc-tls"); debug!("setting up mpc-tls");
@@ -163,7 +150,7 @@ impl Verifier<state::Config> {
Ok(Verifier { Ok(Verifier {
config: self.config, config: self.config,
span: self.span, span: self.span,
state: state::Setup { state: state::CommitAccepted {
mux_ctrl, mux_ctrl,
mux_fut, mux_fut,
mpc_tls, mpc_tls,
@@ -176,7 +163,7 @@ impl Verifier<state::Config> {
/// Rejects the proposed protocol configuration. /// Rejects the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> { pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> {
let state::Config { let state::CommitStart {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mut ctx, mut ctx,
@@ -197,11 +184,11 @@ impl Verifier<state::Config> {
} }
} }
impl Verifier<state::Setup> { impl Verifier<state::CommitAccepted> {
/// Runs the verifier until the TLS connection is closed. /// Runs the verifier until the TLS connection is closed.
#[instrument(parent = &self.span, level = "info", skip_all, err)] #[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> { pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Setup { let state::CommitAccepted {
mux_ctrl, mux_ctrl,
mut mux_fut, mut mux_fut,
mpc_tls, mpc_tls,
@@ -287,7 +274,11 @@ impl Verifier<state::Committed> {
tls_transcript, tls_transcript,
} = self.state; } = self.state;
let request = mux_fut let ProveRequestMsg {
request,
handshake,
transcript,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from)) .poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?; .await?;
@@ -302,6 +293,8 @@ impl Verifier<state::Committed> {
keys, keys,
tls_transcript, tls_transcript,
request, request,
handshake,
transcript,
}, },
}) })
} }
@@ -341,15 +334,14 @@ impl Verifier<state::Verify> {
keys, keys,
tls_transcript, tls_transcript,
request, request,
handshake,
transcript,
} = self.state; } = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?; mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
let cert_verifier = if let Some(root_store) = self.config.root_store() { let cert_verifier =
ServerCertVerifier::new(root_store).map_err(VerifierError::config)? ServerCertVerifier::new(self.config.root_store()).map_err(VerifierError::config)?;
} else {
ServerCertVerifier::mozilla()
};
let output = mux_fut let output = mux_fut
.poll_with(verify::verify( .poll_with(verify::verify(
@@ -359,6 +351,8 @@ impl Verifier<state::Verify> {
&cert_verifier, &cert_verifier,
&tls_transcript, &tls_transcript,
request, request,
handshake,
transcript,
)) ))
.await?; .await?;
@@ -412,74 +406,3 @@ impl Verifier<state::Verify> {
}) })
} }
} }
fn build_mpc_tls(
config: &VerifierConfig,
protocol_config: &ProtocolConfig,
ctx: Context,
) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
let mut rng = rand::rng();
let delta = Delta::random(&mut rng);
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(
mpz_ot::kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_send = mpz_ot::ferret::Sender::new(
mpz_ot::ferret::FerretConfig::builder()
.lpn_type(mpz_ot::ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_send,
);
let rcot_recv =
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
let mpc = Mpc::new(mpz_ot::cot::DerandCOTReceiver::new(rcot_recv.clone()));
let zk = Zk::new(ZkVerifierConfig::default(), delta, rcot_send.clone());
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));
(
vm.clone(),
MpcTlsFollower::new(
config.build_mpc_tls_config(protocol_config),
ctx,
vm,
rcot_send,
(rcot_recv.clone(), rcot_recv.clone(), rcot_recv),
),
)
}
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(
keys: &mut SessionKeys,
vm: &Deap<Mpc, Zk>,
) -> Result<(), VerifierError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(VerifierError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(VerifierError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(VerifierError::mpc)?;
Ok(())
}

View File

@@ -1,57 +0,0 @@
use std::fmt::{Debug, Formatter, Result};
use mpc_tls::Config;
use serde::{Deserialize, Serialize};
use tlsn_core::webpki::RootCertStore;
use crate::config::{NetworkSetting, ProtocolConfig};
/// Configuration for the [`Verifier`](crate::tls::Verifier).
#[allow(missing_docs)]
#[derive(derive_builder::Builder, Serialize, Deserialize)]
#[builder(pattern = "owned")]
pub struct VerifierConfig {
// Optional root certificate store; `strip_option` lets the generated
// builder setter take a `RootCertStore` directly.
#[builder(default, setter(strip_option))]
root_store: Option<RootCertStore>,
}
// Manual `Debug` impl: `finish_non_exhaustive` prints no fields, keeping
// the root store contents out of debug output.
impl Debug for VerifierConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.debug_struct("VerifierConfig").finish_non_exhaustive()
}
}
impl VerifierConfig {
/// Creates a new configuration builder.
pub fn builder() -> VerifierConfigBuilder {
VerifierConfigBuilder::default()
}
/// Returns the root certificate store, if one was configured.
pub fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
/// Builds the MPC-TLS configuration from the given protocol
/// configuration, forwarding data limits and optional record-count
/// limits, and enabling low-bandwidth mode for latency-bound networks.
pub(crate) fn build_mpc_tls_config(&self, protocol_config: &ProtocolConfig) -> Config {
let mut builder = Config::builder();
builder
.max_sent(protocol_config.max_sent_data())
.max_recv_online(protocol_config.max_recv_data_online())
.max_recv(protocol_config.max_recv_data());
// Record-count limits are optional; only forward them when set.
if let Some(max_sent_records) = protocol_config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
// Trade bandwidth for fewer round trips on latency-bound networks.
if let NetworkSetting::Latency = protocol_config.network() {
builder.low_bandwidth();
}
// The parameters come from `protocol_config`; a build failure here
// would indicate a bug.
builder.build().unwrap()
}
}

View File

@@ -4,7 +4,7 @@ use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError; use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Verifier`](crate::Verifier). /// Error for [`Verifier`](crate::verifier::Verifier).
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub struct VerifierError { pub struct VerifierError {
kind: ErrorKind, kind: ErrorKind,

View File

@@ -2,17 +2,18 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::mux::{MuxControl, MuxFuture};
config::ProtocolConfig,
mux::{MuxControl, MuxFuture},
};
use mpc_tls::{MpcTlsFollower, SessionKeys}; use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context; use mpz_common::Context;
use tlsn_core::{ProveRequest, transcript::TlsTranscript}; use tlsn_core::{
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
connection::{HandshakeData, ServerName},
transcript::{PartialTranscript, TlsTranscript},
};
use tlsn_deap::Deap; use tlsn_deap::Deap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::verifier::{Mpc, Zk}; use crate::mpz::{VerifierMpc, VerifierZk};
/// TLS Verifier state. /// TLS Verifier state.
pub trait VerifierState: sealed::Sealed {} pub trait VerifierState: sealed::Sealed {}
@@ -23,32 +24,33 @@ pub struct Initialized;
opaque_debug::implement!(Initialized); opaque_debug::implement!(Initialized);
/// State after receiving protocol configuration from the prover. /// State after receiving protocol configuration from the prover.
pub struct Config { pub struct CommitStart {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
pub(crate) config: ProtocolConfig, pub(crate) request: TlsCommitRequest,
} }
opaque_debug::implement!(Config); opaque_debug::implement!(CommitStart);
/// State after setup has completed. /// State after accepting the proposed TLS commitment protocol configuration and
pub struct Setup { /// performing preprocessing.
pub struct CommitAccepted {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower, pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) keys: SessionKeys, pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>, pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
} }
opaque_debug::implement!(Setup); opaque_debug::implement!(CommitAccepted);
/// State after the TLS connection has been closed. /// State after the TLS transcript has been committed.
pub struct Committed { pub struct Committed {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
pub(crate) vm: Zk, pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys, pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript, pub(crate) tls_transcript: TlsTranscript,
} }
@@ -60,25 +62,27 @@ pub struct Verify {
pub(crate) mux_ctrl: MuxControl, pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture, pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context, pub(crate) ctx: Context,
pub(crate) vm: Zk, pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys, pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript, pub(crate) tls_transcript: TlsTranscript,
pub(crate) request: ProveRequest, pub(crate) request: ProveRequest,
pub(crate) handshake: Option<(ServerName, HandshakeData)>,
pub(crate) transcript: Option<PartialTranscript>,
} }
opaque_debug::implement!(Verify); opaque_debug::implement!(Verify);
impl VerifierState for Initialized {} impl VerifierState for Initialized {}
impl VerifierState for Config {} impl VerifierState for CommitStart {}
impl VerifierState for Setup {} impl VerifierState for CommitAccepted {}
impl VerifierState for Committed {} impl VerifierState for Committed {}
impl VerifierState for Verify {} impl VerifierState for Verify {}
mod sealed { mod sealed {
pub trait Sealed {} pub trait Sealed {}
impl Sealed for super::Initialized {} impl Sealed for super::Initialized {}
impl Sealed for super::Config {} impl Sealed for super::CommitStart {}
impl Sealed for super::Setup {} impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {} impl Sealed for super::Committed {}
impl Sealed for super::Verify {} impl Sealed for super::Verify {}
} }

View File

@@ -2,9 +2,11 @@ use mpc_tls::SessionKeys;
use mpz_common::Context; use mpz_common::Context;
use mpz_memory_core::binary::Binary; use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm; use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut}; use rangeset::set::RangeSet;
use tlsn_core::{ use tlsn_core::{
ProveRequest, VerifierOutput, VerifierOutput,
config::prove::ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{ transcript::{
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment, ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
}, },
@@ -23,6 +25,7 @@ use crate::{
verifier::VerifierError, verifier::VerifierError,
}; };
#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>( pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context, ctx: &mut Context,
vm: &mut T, vm: &mut T,
@@ -30,18 +33,19 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
cert_verifier: &ServerCertVerifier, cert_verifier: &ServerCertVerifier,
tls_transcript: &TlsTranscript, tls_transcript: &TlsTranscript,
request: ProveRequest, request: ProveRequest,
handshake: Option<(ServerName, HandshakeData)>,
transcript: Option<PartialTranscript>,
) -> Result<VerifierOutput, VerifierError> { ) -> Result<VerifierOutput, VerifierError> {
let ProveRequest {
handshake,
transcript,
transcript_commit,
} = request;
let ciphertext_sent = collect_ciphertext(tls_transcript.sent()); let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
let ciphertext_recv = collect_ciphertext(tls_transcript.recv()); let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
let has_reveal = transcript.is_some(); let transcript = if let Some((auth_sent, auth_recv)) = request.reveal() {
let transcript = if let Some(transcript) = transcript { let Some(transcript) = transcript else {
return Err(VerifierError::verify(
"prover requested to reveal data but did not send transcript",
));
};
if transcript.len_sent() != ciphertext_sent.len() if transcript.len_sent() != ciphertext_sent.len()
|| transcript.len_received() != ciphertext_recv.len() || transcript.len_received() != ciphertext_recv.len()
{ {
@@ -50,6 +54,18 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
)); ));
} }
if transcript.sent_authed() != auth_sent {
return Err(VerifierError::verify(
"prover sent transcript with incorrect sent authed data",
));
}
if transcript.received_authed() != auth_recv {
return Err(VerifierError::verify(
"prover sent transcript with incorrect received authed data",
));
}
transcript transcript
} else { } else {
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len()) PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
@@ -71,7 +87,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
}; };
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default()); let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = transcript_commit.as_ref() { if let Some(commit_config) = request.transcript_commit() {
commit_config commit_config
.iter_hash() .iter_hash()
.for_each(|(direction, idx, _)| match direction { .for_each(|(direction, idx, _)| match direction {
@@ -121,7 +137,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
let mut transcript_commitments = Vec::new(); let mut transcript_commitments = Vec::new();
let mut hash_commitments = None; let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit.as_ref() if let Some(commit_config) = request.transcript_commit()
&& commit_config.has_hash() && commit_config.has_hash()
{ {
hash_commitments = Some( hash_commitments = Some(
@@ -136,7 +152,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
recv_proof.verify().map_err(VerifierError::verify)?; recv_proof.verify().map_err(VerifierError::verify)?;
let mut encoder_secret = None; let mut encoder_secret = None;
if let Some(commit_config) = transcript_commit if let Some(commit_config) = request.transcript_commit()
&& let Some((sent, recv)) = commit_config.encoding() && let Some((sent, recv)) = commit_config.encoding()
{ {
let sent_map = transcript_refs let sent_map = transcript_refs
@@ -161,7 +177,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
Ok(VerifierOutput { Ok(VerifierOutput {
server_name, server_name,
transcript: has_reveal.then_some(transcript), transcript: request.reveal().is_some().then_some(transcript),
encoder_secret, encoder_secret,
transcript_commitments, transcript_commitments,
}) })

View File

@@ -1,15 +1,22 @@
use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet; use rangeset::set::RangeSet;
use tlsn::{ use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore}, config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
verifier::VerifierConfig,
},
connection::ServerName, connection::ServerName,
hash::{HashAlgId, HashProvider}, hash::{HashAlgId, HashProvider},
prover::{ProveConfig, Prover, ProverConfig, TlsConfig}, prover::Prover,
transcript::{ transcript::{
Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment, Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment,
TranscriptCommitmentKind, TranscriptSecret, TranscriptCommitmentKind, TranscriptSecret,
}, },
verifier::{Verifier, VerifierConfig, VerifierOutput}, verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
}; };
use tlsn_core::ProverOutput; use tlsn_core::ProverOutput;
use tlsn_server_fixture::bind; use tlsn_server_fixture::bind;
@@ -44,19 +51,11 @@ async fn test() {
assert_eq!(server_name.as_str(), SERVER_DOMAIN); assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete()); assert!(!partial_transcript.is_complete());
assert_eq!( assert_eq!(
partial_transcript partial_transcript.sent_authed().iter().next().unwrap(),
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
0..10 0..10
); );
assert_eq!( assert_eq!(
partial_transcript partial_transcript.received_authed().iter().next().unwrap(),
.received_authed()
.iter_ranges()
.next()
.unwrap(),
0..10 0..10
); );
@@ -113,35 +112,38 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let server_task = tokio::spawn(bind(server_socket.compat())); let server_task = tokio::spawn(bind(server_socket.compat()));
let mut tls_config_builder = TlsConfig::builder(); let prover = Prover::new(ProverConfig::builder().build().unwrap())
tls_config_builder.root_store(RootCertStore { .commit(
roots: vec![CertificateDer(CA_CERT_DER.to_vec())], TlsCommitConfig::builder()
}); .protocol(
MpcTlsConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_sent_records(MAX_SENT_RECORDS)
.max_recv_data(MAX_RECV_DATA)
.max_recv_records_online(MAX_RECV_RECORDS)
.build()
.unwrap(),
)
.build()
.unwrap(),
verifier_socket.compat(),
)
.await
.unwrap();
let tls_config = tls_config_builder.build().unwrap(); let (mut tls_connection, prover_fut) = prover
.connect(
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()); TlsClientConfig::builder()
let prover = Prover::new( .server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
ProverConfig::builder() .root_store(RootCertStore {
.server_name(server_name) roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
.tls_config(tls_config) })
.protocol_config( .build()
ProtocolConfig::builder() .unwrap(),
.max_sent_data(MAX_SENT_DATA) client_socket.compat(),
.max_sent_records(MAX_SENT_RECORDS) )
.max_recv_data(MAX_RECV_DATA) .await
.max_recv_records_online(MAX_RECV_RECORDS) .unwrap();
.build()
.unwrap(),
)
.build()
.unwrap(),
)
.setup(verifier_socket.compat())
.await
.unwrap();
let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap();
let prover_task = tokio::spawn(prover_fut); let prover_task = tokio::spawn(prover_fut);
tls_connection tls_connection
@@ -214,7 +216,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
); );
let verifier = verifier let verifier = verifier
.setup(socket.compat()) .commit(socket.compat())
.await .await
.unwrap() .unwrap()
.accept() .accept()

View File

@@ -21,7 +21,7 @@ no-bundler = ["web-spawn/no-bundler"]
[dependencies] [dependencies]
tlsn-core = { workspace = true } tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web"] } tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true } tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true } tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true } tlsn-tls-core = { workspace = true }

View File

@@ -1,11 +1,6 @@
use crate::types::NetworkSetting; use crate::types::NetworkSetting;
use serde::Deserialize; use serde::Deserialize;
use tlsn::{
config::{CertificateDer, PrivateKeyDer, ProtocolConfig},
connection::ServerName,
};
use tsify_next::Tsify; use tsify_next::Tsify;
use wasm_bindgen::JsError;
#[derive(Debug, Tsify, Deserialize)] #[derive(Debug, Tsify, Deserialize)]
#[tsify(from_wasm_abi)] #[tsify(from_wasm_abi)]
@@ -20,66 +15,3 @@ pub struct ProverConfig {
pub network: NetworkSetting, pub network: NetworkSetting,
pub client_auth: Option<(Vec<Vec<u8>>, Vec<u8>)>, pub client_auth: Option<(Vec<Vec<u8>>, Vec<u8>)>,
} }
impl TryFrom<ProverConfig> for tlsn::prover::ProverConfig {
type Error = JsError;
fn try_from(value: ProverConfig) -> Result<Self, Self::Error> {
let mut builder = ProtocolConfig::builder();
builder.max_sent_data(value.max_sent_data);
builder.max_recv_data(value.max_recv_data);
if let Some(value) = value.max_recv_data_online {
builder.max_recv_data_online(value);
}
if let Some(value) = value.max_sent_records {
builder.max_sent_records(value);
}
if let Some(value) = value.max_recv_records_online {
builder.max_recv_records_online(value);
}
if let Some(value) = value.defer_decryption_from_start {
builder.defer_decryption_from_start(value);
}
builder.network(value.network.into());
let protocol_config = builder.build().unwrap();
let mut builder = tlsn::prover::TlsConfig::builder();
if let Some((certs, key)) = value.client_auth {
let certs = certs
.into_iter()
.map(|cert| {
// Try to parse as PEM-encoded, otherwise assume DER.
if let Ok(cert) = CertificateDer::from_pem_slice(&cert) {
cert
} else {
CertificateDer(cert)
}
})
.collect();
let key = PrivateKeyDer(key);
builder.client_auth((certs, key));
}
let tls_config = builder.build().unwrap();
let server_name = ServerName::Dns(
value
.server_name
.try_into()
.map_err(|_| JsError::new("invalid server name"))?,
);
let mut builder = tlsn::prover::ProverConfig::builder();
builder
.server_name(server_name)
.protocol_config(protocol_config)
.tls_config(tls_config);
Ok(builder.build().unwrap())
}
}

View File

@@ -7,7 +7,16 @@ use futures::TryFutureExt;
use http_body_util::{BodyExt, Full}; use http_body_util::{BodyExt, Full};
use hyper::body::Bytes; use hyper::body::Bytes;
use tls_client_async::TlsConnection; use tls_client_async::TlsConnection;
use tlsn::prover::{state, ProveConfig, Prover}; use tlsn::{
config::{
prove::ProveConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
connection::ServerName,
prover::{state, Prover},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tracing::info; use tracing::info;
use wasm_bindgen::{prelude::*, JsError}; use wasm_bindgen::{prelude::*, JsError};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
@@ -19,6 +28,7 @@ type Result<T> = std::result::Result<T, JsError>;
#[wasm_bindgen(js_name = Prover)] #[wasm_bindgen(js_name = Prover)]
pub struct JsProver { pub struct JsProver {
config: ProverConfig,
state: State, state: State,
} }
@@ -26,7 +36,7 @@ pub struct JsProver {
#[derive_err(Debug)] #[derive_err(Debug)]
enum State { enum State {
Initialized(Prover<state::Initialized>), Initialized(Prover<state::Initialized>),
Setup(Prover<state::Setup>), CommitAccepted(Prover<state::CommitAccepted>),
Committed(Prover<state::Committed>), Committed(Prover<state::Committed>),
Complete, Complete,
Error, Error,
@@ -43,7 +53,10 @@ impl JsProver {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new(config: ProverConfig) -> Result<JsProver> { pub fn new(config: ProverConfig) -> Result<JsProver> {
Ok(JsProver { Ok(JsProver {
state: State::Initialized(Prover::new(config.try_into()?)), config,
state: State::Initialized(Prover::new(
tlsn::config::prover::ProverConfig::builder().build()?,
)),
}) })
} }
@@ -54,15 +67,41 @@ impl JsProver {
pub async fn setup(&mut self, verifier_url: &str) -> Result<()> { pub async fn setup(&mut self, verifier_url: &str) -> Result<()> {
let prover = self.state.take().try_into_initialized()?; let prover = self.state.take().try_into_initialized()?;
let config = TlsCommitConfig::builder()
.protocol({
let mut builder = MpcTlsConfig::builder()
.max_sent_data(self.config.max_sent_data)
.max_recv_data(self.config.max_recv_data);
if let Some(value) = self.config.max_recv_data_online {
builder = builder.max_recv_data_online(value);
}
if let Some(value) = self.config.max_sent_records {
builder = builder.max_sent_records(value);
}
if let Some(value) = self.config.max_recv_records_online {
builder = builder.max_recv_records_online(value);
}
if let Some(value) = self.config.defer_decryption_from_start {
builder = builder.defer_decryption_from_start(value);
}
builder.network(self.config.network.into()).build()
}?)
.build()?;
info!("connecting to verifier"); info!("connecting to verifier");
let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?; let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?;
info!("connected to verifier"); info!("connected to verifier");
let prover = prover.setup(verifier_conn.into_io()).await?; let prover = prover.commit(config, verifier_conn.into_io()).await?;
self.state = State::Setup(prover); self.state = State::CommitAccepted(prover);
Ok(()) Ok(())
} }
@@ -73,7 +112,35 @@ impl JsProver {
ws_proxy_url: &str, ws_proxy_url: &str,
request: HttpRequest, request: HttpRequest,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let prover = self.state.take().try_into_setup()?; let prover = self.state.take().try_into_commit_accepted()?;
let mut builder = TlsClientConfig::builder()
.server_name(ServerName::Dns(
self.config
.server_name
.clone()
.try_into()
.map_err(|_| JsError::new("invalid server name"))?,
))
.root_store(RootCertStore::mozilla());
if let Some((certs, key)) = self.config.client_auth.clone() {
let certs = certs
.into_iter()
.map(|cert| {
// Try to parse as PEM-encoded, otherwise assume DER.
if let Ok(cert) = CertificateDer::from_pem_slice(&cert) {
cert
} else {
CertificateDer(cert)
}
})
.collect();
let key = PrivateKeyDer(key);
builder = builder.client_auth((certs, key));
}
let config = builder.build()?;
info!("connecting to server"); info!("connecting to server");
@@ -81,7 +148,7 @@ impl JsProver {
info!("connected to server"); info!("connected to server");
let (tls_conn, prover_fut) = prover.connect(server_conn.into_io()).await?; let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?;
info!("sending request"); info!("sending request");
@@ -137,14 +204,6 @@ impl JsProver {
} }
} }
impl From<Prover<state::Initialized>> for JsProver {
fn from(value: Prover<state::Initialized>) -> Self {
JsProver {
state: State::Initialized(value),
}
}
}
async fn send_request(conn: TlsConnection, request: HttpRequest) -> Result<HttpResponse> { async fn send_request(conn: TlsConnection, request: HttpRequest) -> Result<HttpResponse> {
let conn = FuturesIo::new(conn); let conn = FuturesIo::new(conn);
let request = hyper::Request::<Full<Bytes>>::try_from(request)?; let request = hyper::Request::<Full<Bytes>>::try_from(request)?;

View File

@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
fn from(value: tlsn::transcript::PartialTranscript) -> Self { fn from(value: tlsn::transcript::PartialTranscript) -> Self {
Self { Self {
sent: value.sent_unsafe().to_vec(), sent: value.sent_unsafe().to_vec(),
sent_authed: value.sent_authed().iter_ranges().collect(), sent_authed: value.sent_authed().iter().collect(),
recv: value.received_unsafe().to_vec(), recv: value.received_unsafe().to_vec(),
recv_authed: value.received_authed().iter_ranges().collect(), recv_authed: value.received_authed().iter().collect(),
} }
} }
} }
@@ -181,7 +181,7 @@ pub struct VerifierOutput {
pub transcript: Option<PartialTranscript>, pub transcript: Option<PartialTranscript>,
} }
#[derive(Debug, Tsify, Deserialize)] #[derive(Debug, Clone, Copy, Tsify, Deserialize)]
#[tsify(from_wasm_abi)] #[tsify(from_wasm_abi)]
pub enum NetworkSetting { pub enum NetworkSetting {
/// Prefers a bandwidth-heavy protocol. /// Prefers a bandwidth-heavy protocol.
@@ -190,7 +190,7 @@ pub enum NetworkSetting {
Latency, Latency,
} }
impl From<NetworkSetting> for tlsn::config::NetworkSetting { impl From<NetworkSetting> for tlsn::config::tls_commit::mpc::NetworkSetting {
fn from(value: NetworkSetting) -> Self { fn from(value: NetworkSetting) -> Self {
match value { match value {
NetworkSetting::Bandwidth => Self::Bandwidth, NetworkSetting::Bandwidth => Self::Bandwidth,

View File

@@ -3,10 +3,12 @@ mod config;
pub use config::VerifierConfig; pub use config::VerifierConfig;
use enum_try_as_inner::EnumTryAsInner; use enum_try_as_inner::EnumTryAsInner;
use tls_core::msgs::enums::ContentType;
use tlsn::{ use tlsn::{
config::tls_commit::TlsCommitProtocolConfig,
connection::{ConnectionInfo, ServerName, TranscriptLength}, connection::{ConnectionInfo, ServerName, TranscriptLength},
transcript::ContentType,
verifier::{state, Verifier}, verifier::{state, Verifier},
webpki::RootCertStore,
}; };
use tracing::info; use tracing::info;
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
@@ -47,7 +49,10 @@ impl State {
impl JsVerifier { impl JsVerifier {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new(config: VerifierConfig) -> JsVerifier { pub fn new(config: VerifierConfig) -> JsVerifier {
let tlsn_config = tlsn::verifier::VerifierConfig::builder().build().unwrap(); let tlsn_config = tlsn::config::verifier::VerifierConfig::builder()
.root_store(RootCertStore::mozilla())
.build()
.unwrap();
JsVerifier { JsVerifier {
state: State::Initialized(Verifier::new(tlsn_config)), state: State::Initialized(Verifier::new(tlsn_config)),
config, config,
@@ -73,16 +78,20 @@ impl JsVerifier {
pub async fn verify(&mut self) -> Result<VerifierOutput> { pub async fn verify(&mut self) -> Result<VerifierOutput> {
let (verifier, prover_conn) = self.state.take().try_into_connected()?; let (verifier, prover_conn) = self.state.take().try_into_connected()?;
let verifier = verifier.setup(prover_conn.into_io()).await?; let verifier = verifier.commit(prover_conn.into_io()).await?;
let config = verifier.config(); let request = verifier.request();
let reject = if config.max_sent_data() > self.config.max_sent_data { let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol() else {
unimplemented!("only MPC protocol is supported");
};
let reject = if mpc_tls_config.max_sent_data() > self.config.max_sent_data {
Some("max_sent_data is too large") Some("max_sent_data is too large")
} else if config.max_recv_data() > self.config.max_recv_data { } else if mpc_tls_config.max_recv_data() > self.config.max_recv_data {
Some("max_recv_data is too large") Some("max_recv_data is too large")
} else if config.max_sent_records() > self.config.max_sent_records { } else if mpc_tls_config.max_sent_records() > self.config.max_sent_records {
Some("max_sent_records is too large") Some("max_sent_records is too large")
} else if config.max_recv_records_online() > self.config.max_recv_records_online { } else if mpc_tls_config.max_recv_records_online() > self.config.max_recv_records_online {
Some("max_recv_records_online is too large") Some("max_recv_records_online is too large")
} else { } else {
None None