Compare commits

...

24 Commits

Author SHA1 Message Date
dan
c1ee91e63b verify predicate output 2025-12-12 09:21:44 +02:00
dan
57f85dcdf2 add missing example 2025-12-11 16:37:41 +02:00
dan
727dff3ac9 add missing file 2025-12-11 16:33:09 +02:00
dan
830d0ab0d5 feat: add support for proving predicates 2025-12-11 16:25:49 +02:00
dan
4fe5c1defd feat(harness): add reveal_all config (#1063) 2025-12-09 12:01:39 +00:00
dan
0e8e547300 chore: adapt for rangeset 0.4 (#1058) 2025-12-09 11:36:13 +00:00
dan
22cc88907a chore: bump mpz (#1057) 2025-12-04 10:27:43 +00:00
Hendrik Eeckhaut
cec4756e0e ci: set GITHUB_TOKEN env 2025-11-28 14:33:19 +01:00
Hendrik Eeckhaut
0919e1f2b3 clippy: allow deprecated aead::generic_array 2025-11-28 14:33:19 +01:00
Hendrik Eeckhaut
43b9f57e1f build: update Rust to version 1.91.1 2025-11-26 16:53:08 +01:00
dan
c51331d63d test: use ideal vm for testing (#1049) 2025-11-07 12:56:09 +00:00
dan
3905d9351c chore: clean up deps (#1048) 2025-11-07 10:36:41 +00:00
dan
f8a67bc8e7 feat(core): support proving keccak256 commitments (#1046) 2025-11-07 09:18:44 +00:00
dan
952a7011bf feat(cipher): use AES pre/post key schedule circuits (#1042) 2025-11-07 09:08:08 +00:00
Ram
0673818e4e chore: fix links to key exchange doc page (#1045) 2025-11-04 23:07:58 +01:00
dan
a5749d81f1 fix(attestation): verify sig during validation (#1037) 2025-10-30 07:59:57 +00:00
sinu.eth
f2e119bb66 refactor: move and rewrite configuration (#1034)
* refactor: move and rewrite configuration

* fix wasm
2025-10-27 11:47:42 -07:00
Hendrik Eeckhaut
271ac3771e Fix example (#1033)
* fix: provide encoder secret to attestation
* Add missing entry in example's README file
2025-10-24 10:33:32 +02:00
Benjamin Martinez Picech
f69dd7a239 refactor(tlsn-core): redeclaration of content type into core (#1026)
* redeclaration of content type into core

* fix compilation error

* comment removal

* Lint and format fixes

* fix wasm build

* Unknown content type

* format fix
2025-10-23 15:47:53 +02:00
sinu.eth
79f5160cae feat(tlsn): insecure mode (#1031) 2025-10-22 10:18:11 -07:00
sinu.eth
5fef2af698 fix(example): close prover (#1025) 2025-10-17 10:39:08 -07:00
sinu.eth
5b2083e211 refactor(tlsn): invert control of config validation (#1023)
* refactor(tlsn): invert control of config validation

* clippy
2025-10-17 10:19:02 -07:00
sinu.eth
d26bb02d2e chore: update to alpha.14-pre (#1022) 2025-10-15 11:11:43 -07:00
sinu
a766b64184 ci: add manual trigger to main update 2025-10-15 10:11:01 -07:00
102 changed files with 4327 additions and 2598 deletions

View File

@@ -21,7 +21,8 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.90.0
RUST_VERSION: 1.91.1
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
clippy:

View File

@@ -6,7 +6,7 @@ on:
tag:
description: 'Tag to publish to NPM'
required: true
default: 'v0.1.0-alpha.13'
default: 'v0.1.0-alpha.14-pre'
jobs:
release:

View File

@@ -1,6 +1,7 @@
name: Fast-forward main branch to published release tag
on:
workflow_dispatch:
release:
types: [published]

1056
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,26 +66,27 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-predicate = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
rangeset = { version = "0.2" }
rangeset = { version = "0.4" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
aead = { version = "0.4" }
aes = { version = "0.8" }

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-attestation"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2024"
[features]
@@ -9,7 +9,7 @@ fixtures = ["tlsn-core/fixtures", "dep:tlsn-data-fixtures"]
[dependencies]
tlsn-tls-core = { workspace = true }
tlsn-core = { workspace = true }
tlsn-core = { workspace = true, features = ["mozilla-certs"] }
tlsn-data-fixtures = { workspace = true, optional = true }
bcs = { workspace = true }

View File

@@ -1,5 +1,4 @@
//! Attestation fixtures.
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2},
fixtures::ConnectionFixture,
@@ -13,7 +12,10 @@ use tlsn_core::{
use crate::{
Attestation, AttestationConfig, CryptoProvider, Extension,
request::{Request, RequestConfig},
signing::SignatureAlgId,
signing::{
KeyAlgId, SignatureAlgId, SignatureVerifier, SignatureVerifierProvider, Signer,
SignerProvider,
},
};
/// A Request fixture used for testing.
@@ -102,7 +104,8 @@ pub fn attestation_fixture(
let mut provider = CryptoProvider::default();
match signature_alg {
SignatureAlgId::SECP256K1 => provider.signer.set_secp256k1(&[42u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[42u8; 32]).unwrap(),
SignatureAlgId::SECP256K1ETH => provider.signer.set_secp256k1eth(&[43u8; 32]).unwrap(),
SignatureAlgId::SECP256R1 => provider.signer.set_secp256r1(&[44u8; 32]).unwrap(),
_ => unimplemented!(),
};
@@ -122,3 +125,68 @@ pub fn attestation_fixture(
attestation_builder.build(&provider).unwrap()
}
/// Returns a crypto provider which supports only a custom signature alg.
pub fn custom_provider_fixture() -> CryptoProvider {
const CUSTOM_SIG_ALG_ID: SignatureAlgId = SignatureAlgId::new(128);
// A dummy signer.
struct DummySigner {}
impl Signer for DummySigner {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn sign(
&self,
msg: &[u8],
) -> Result<crate::signing::Signature, crate::signing::SignatureError> {
Ok(crate::signing::Signature {
alg: CUSTOM_SIG_ALG_ID,
data: msg.to_vec(),
})
}
fn verifying_key(&self) -> crate::signing::VerifyingKey {
crate::signing::VerifyingKey {
alg: KeyAlgId::new(128),
data: vec![1, 2, 3, 4],
}
}
}
// A dummy verifier.
struct DummyVerifier {}
impl SignatureVerifier for DummyVerifier {
fn alg_id(&self) -> SignatureAlgId {
CUSTOM_SIG_ALG_ID
}
fn verify(
&self,
_key: &crate::signing::VerifyingKey,
msg: &[u8],
sig: &[u8],
) -> Result<(), crate::signing::SignatureError> {
if msg == sig {
Ok(())
} else {
Err(crate::signing::SignatureError::from_str(
"invalid signature",
))
}
}
}
let mut provider = CryptoProvider::default();
let mut signer_provider = SignerProvider::default();
signer_provider.set_signer(Box::new(DummySigner {}));
provider.signer = signer_provider;
let mut verifier_provider = SignatureVerifierProvider::empty();
verifier_provider.set_verifier(Box::new(DummyVerifier {}));
provider.signature = verifier_provider;
provider
}

View File

@@ -20,7 +20,10 @@ use serde::{Deserialize, Serialize};
use tlsn_core::hash::HashAlgId;
use crate::{Attestation, Extension, connection::ServerCertCommitment, signing::SignatureAlgId};
use crate::{
Attestation, CryptoProvider, Extension, connection::ServerCertCommitment,
serialize::CanonicalSerialize, signing::SignatureAlgId,
};
pub use builder::{RequestBuilder, RequestBuilderError};
pub use config::{RequestConfig, RequestConfigBuilder, RequestConfigBuilderError};
@@ -41,44 +44,102 @@ impl Request {
}
/// Validates the content of the attestation against this request.
pub fn validate(&self, attestation: &Attestation) -> Result<(), InconsistentAttestation> {
pub fn validate(
&self,
attestation: &Attestation,
provider: &CryptoProvider,
) -> Result<(), AttestationValidationError> {
if attestation.signature.alg != self.signature_alg {
return Err(InconsistentAttestation(format!(
return Err(AttestationValidationError::inconsistent(format!(
"signature algorithm: expected {:?}, got {:?}",
self.signature_alg, attestation.signature.alg
)));
}
if attestation.header.root.alg != self.hash_alg {
return Err(InconsistentAttestation(format!(
return Err(AttestationValidationError::inconsistent(format!(
"hash algorithm: expected {:?}, got {:?}",
self.hash_alg, attestation.header.root.alg
)));
}
if attestation.body.cert_commitment() != &self.server_cert_commitment {
return Err(InconsistentAttestation(
"server certificate commitment does not match".to_string(),
return Err(AttestationValidationError::inconsistent(
"server certificate commitment does not match",
));
}
// TODO: improve the O(M*N) complexity of this check.
for extension in &self.extensions {
if !attestation.body.extensions().any(|e| e == extension) {
return Err(InconsistentAttestation(
"extension is missing from the attestation".to_string(),
return Err(AttestationValidationError::inconsistent(
"extension is missing from the attestation",
));
}
}
let verifier = provider
.signature
.get(&attestation.signature.alg)
.map_err(|_| {
AttestationValidationError::provider(format!(
"provider not configured for signature algorithm id {:?}",
attestation.signature.alg,
))
})?;
verifier
.verify(
&attestation.body.verifying_key.data,
&CanonicalSerialize::serialize(&attestation.header),
&attestation.signature.data,
)
.map_err(|_| {
AttestationValidationError::inconsistent("failed to verify the signature")
})?;
Ok(())
}
}
/// Error for [`Request::validate`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent attestation: {0}")]
pub struct InconsistentAttestation(String);
#[error("attestation validation error: {kind}: {message}")]
pub struct AttestationValidationError {
kind: ErrorKind,
message: String,
}
impl AttestationValidationError {
fn inconsistent(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Inconsistent,
message: msg.into(),
}
}
fn provider(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Provider,
message: msg.into(),
}
}
}
#[derive(Debug)]
enum ErrorKind {
Inconsistent,
Provider,
}
impl std::fmt::Display for ErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ErrorKind::Inconsistent => write!(f, "inconsistent"),
ErrorKind::Provider => write!(f, "provider"),
}
}
}
#[cfg(test)]
mod test {
@@ -93,7 +154,8 @@ mod test {
use crate::{
CryptoProvider,
connection::ServerCertOpening,
fixtures::{RequestFixture, attestation_fixture, request_fixture},
fixtures::{RequestFixture, attestation_fixture, custom_provider_fixture, request_fixture},
request::{AttestationValidationError, ErrorKind},
signing::SignatureAlgId,
};
@@ -113,7 +175,9 @@ mod test {
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
assert!(request.validate(&attestation).is_ok())
let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_ok())
}
#[test]
@@ -134,7 +198,9 @@ mod test {
request.signature_alg = SignatureAlgId::SECP256R1;
let res = request.validate(&attestation);
let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
assert!(res.is_err());
}
@@ -156,7 +222,9 @@ mod test {
request.hash_alg = HashAlgId::SHA256;
let res = request.validate(&attestation);
let provider = CryptoProvider::default();
let res = request.validate(&attestation, &provider);
assert!(res.is_err())
}
@@ -184,11 +252,62 @@ mod test {
});
let opening = ServerCertOpening::new(server_cert_data);
let crypto_provider = CryptoProvider::default();
let provider = CryptoProvider::default();
request.server_cert_commitment =
opening.commit(crypto_provider.hash.get(&HashAlgId::BLAKE3).unwrap());
opening.commit(provider.hash.get(&HashAlgId::BLAKE3).unwrap());
let res = request.validate(&attestation);
let res = request.validate(&attestation, &provider);
assert!(res.is_err())
}
#[test]
fn test_wrong_sig() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let mut attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
// Corrupt the signature.
attestation.signature.data[1] = attestation.signature.data[1].wrapping_add(1);
let provider = CryptoProvider::default();
assert!(request.validate(&attestation, &provider).is_err())
}
#[test]
fn test_wrong_provider() {
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
let connection = ConnectionFixture::tlsnotary(transcript.length());
let RequestFixture { request, .. } = request_fixture(
transcript,
encoding_provider(GET_WITH_HEADER, OK_JSON),
connection.clone(),
Blake3::default(),
Vec::new(),
);
let attestation =
attestation_fixture(request.clone(), connection, SignatureAlgId::SECP256K1, &[]);
let provider = custom_provider_fixture();
assert!(matches!(
request.validate(&attestation, &provider),
Err(AttestationValidationError {
kind: ErrorKind::Provider,
..
})
))
}
}

View File

@@ -202,6 +202,14 @@ impl SignatureVerifierProvider {
.map(|s| &**s)
.ok_or(UnknownSignatureAlgId(*alg))
}
/// Returns am empty provider.
#[cfg(any(test, feature = "fixtures"))]
pub fn empty() -> Self {
Self {
verifiers: HashMap::default(),
}
}
}
/// Signature verifier.
@@ -229,6 +237,14 @@ impl_domain_separator!(VerifyingKey);
#[error("signature verification failed: {0}")]
pub struct SignatureError(String);
impl SignatureError {
/// Creates a new error with the given message.
#[allow(clippy::should_implement_trait)]
pub fn from_str(msg: &str) -> Self {
Self(msg.to_string())
}
}
/// A signature.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Signature {

View File

@@ -101,7 +101,7 @@ fn test_api() {
let attestation = attestation_builder.build(&provider).unwrap();
// Prover validates the attestation is consistent with its request.
request.validate(&attestation).unwrap();
request.validate(&attestation, &provider).unwrap();
let mut transcript_proof_builder = secrets.transcript_proof_builder();

View File

@@ -5,7 +5,7 @@ description = "This crate provides implementations of ciphers for two parties"
keywords = ["tls", "mpc", "2pc", "aes"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -15,7 +15,7 @@ workspace = true
name = "cipher"
[dependencies]
mpz-circuits = { workspace = true }
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
@@ -24,11 +24,9 @@ thiserror = { workspace = true }
aes = { workspace = true }
[dev-dependencies]
mpz-garble = { workspace = true }
mpz-common = { workspace = true }
mpz-ot = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }

View File

@@ -2,7 +2,7 @@
use crate::{Cipher, CtrBlock, Keystream};
use async_trait::async_trait;
use mpz_circuits::circuits::AES128;
use mpz_circuits::{AES128_KS, AES128_POST_KS};
use mpz_memory_core::binary::{Binary, U8};
use mpz_vm_core::{prelude::*, Call, Vm};
use std::fmt::Debug;
@@ -12,13 +12,35 @@ mod error;
pub use error::AesError;
use error::ErrorKind;
/// AES key schedule: 11 round keys, 16 bytes each.
type KeySchedule = Array<U8, 176>;
/// Computes AES-128.
#[derive(Default, Debug)]
pub struct Aes128 {
key: Option<Array<U8, 16>>,
key_schedule: Option<KeySchedule>,
iv: Option<Array<U8, 4>>,
}
impl Aes128 {
// Allocates key schedule.
//
// Expects the key to be already set.
fn alloc_key_schedule(&self, vm: &mut dyn Vm<Binary>) -> Result<KeySchedule, AesError> {
let ks: KeySchedule = vm
.call(
Call::builder(AES128_KS.clone())
.arg(self.key.expect("key is set"))
.build()
.expect("call should be valid"),
)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
Ok(ks)
}
}
#[async_trait]
impl Cipher for Aes128 {
type Error = AesError;
@@ -45,18 +67,22 @@ impl Cipher for Aes128 {
}
fn alloc_block(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
input: Array<U8, 16>,
) -> Result<Self::Block, Self::Error> {
let key = self
.key
self.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128.clone())
.arg(key)
Call::builder(AES128_POST_KS.clone())
.arg(ks)
.arg(input)
.build()
.expect("call should be valid"),
@@ -67,11 +93,10 @@ impl Cipher for Aes128 {
}
fn alloc_ctr_block(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self
.key
self.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
@@ -89,10 +114,15 @@ impl Cipher for Aes128 {
vm.mark_public(counter)
.map_err(|err| AesError::new(ErrorKind::Vm, err))?;
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128.clone())
.arg(key)
Call::builder(AES128_POST_KS.clone())
.arg(ks)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
@@ -109,12 +139,11 @@ impl Cipher for Aes128 {
}
fn alloc_keystream(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error> {
let key = self
.key
self.key
.ok_or_else(|| AesError::new(ErrorKind::Key, "key not set"))?;
let iv = self
.iv
@@ -143,10 +172,15 @@ impl Cipher for Aes128 {
let blocks = inputs
.into_iter()
.map(|(explicit_nonce, counter)| {
if self.key_schedule.is_none() {
self.key_schedule = Some(self.alloc_key_schedule(vm)?);
}
let ks = *self.key_schedule.as_ref().expect("key schedule was set");
let output = vm
.call(
Call::builder(AES128.clone())
.arg(key)
Call::builder(AES128_POST_KS.clone())
.arg(ks)
.arg(iv)
.arg(explicit_nonce)
.arg(counter)
@@ -172,15 +206,12 @@ mod tests {
use super::*;
use crate::Cipher;
use mpz_common::context::test_st_context;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::{
binary::{Binary, U8},
correlated::Delta,
Array, MemoryExt, Vector, ViewExt,
};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{Execute, Vm};
use rand::{rngs::StdRng, SeedableRng};
#[tokio::test]
async fn test_aes_ctr() {
@@ -190,10 +221,11 @@ mod tests {
let start_counter = 3u32;
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let aes_gen = setup_ctr(key, iv, &mut gen);
let aes_ev = setup_ctr(key, iv, &mut ev);
let mut aes_gen = setup_ctr(key, iv, &mut gen);
let mut aes_ev = setup_ctr(key, iv, &mut ev);
let msg = vec![42u8; 128];
@@ -252,10 +284,11 @@ mod tests {
let input = [5_u8; 16];
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let aes_gen = setup_block(key, &mut gen);
let aes_ev = setup_block(key, &mut ev);
let mut aes_gen = setup_block(key, &mut gen);
let mut aes_ev = setup_block(key, &mut ev);
let block_ref_gen: Array<U8, 16> = gen.alloc().unwrap();
gen.mark_public(block_ref_gen).unwrap();
@@ -294,18 +327,6 @@ mod tests {
assert_eq!(ciphertext_gen, expected);
}
fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
fn setup_ctr(key: [u8; 16], iv: [u8; 4], vm: &mut dyn Vm<Binary>) -> Aes128 {
let key_ref: Array<U8, 16> = vm.alloc().unwrap();
vm.mark_public(key_ref).unwrap();

View File

@@ -55,7 +55,7 @@ pub trait Cipher {
/// Allocates a single block in ECB mode.
fn alloc_block(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
input: Self::Block,
) -> Result<Self::Block, Self::Error>;
@@ -63,7 +63,7 @@ pub trait Cipher {
/// Allocates a single block in counter mode.
#[allow(clippy::type_complexity)]
fn alloc_ctr_block(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
) -> Result<CtrBlock<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;
@@ -75,7 +75,7 @@ pub trait Cipher {
/// * `len` - Length of the stream in bytes.
#[allow(clippy::type_complexity)]
fn alloc_keystream(
&self,
&mut self,
vm: &mut dyn Vm<Binary>,
len: usize,
) -> Result<Keystream<Self::Nonce, Self::Counter, Self::Block>, Self::Error>;

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-deap"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -19,11 +19,8 @@ futures = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-ot = { workspace = true }
mpz-zk = { workspace = true }
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{Difference, RangeSet, UnionMut};
use rangeset::{ops::Set, set::RangeSet};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError;
@@ -210,10 +210,12 @@ where
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() {
for input in input_minus_follower {
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
@@ -266,7 +268,7 @@ where
mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -282,7 +284,7 @@ where
mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {
@@ -382,37 +384,27 @@ enum ErrorRepr {
#[cfg(test)]
mod tests {
use mpz_circuits::circuits::AES128;
use mpz_circuits::AES128;
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
memory::{binary::U8, Array},
prelude::*,
};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{rngs::StdRng, SeedableRng};
use super::*;
#[tokio::test]
async fn test_deap() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let (ct_leader, ct_follower) = futures::join!(
async {
@@ -478,21 +470,15 @@ mod tests {
#[tokio::test]
async fn test_deap_desync_memory() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
// Desynchronize the memories.
let _ = leader.zk().alloc_raw(1).unwrap();
@@ -564,21 +550,15 @@ mod tests {
// detection by the follower.
#[tokio::test]
async fn test_malicious() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(ProverConfig::default(), rcot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta_zk, rcot_send);
let leader_mpc = IdealVm::new();
let leader_zk = IdealVm::new();
let follower_mpc = IdealVm::new();
let follower_zk = IdealVm::new();
let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);
let mut leader = Deap::new(Role::Leader, leader_mpc, leader_zk);
let mut follower = Deap::new(Role::Follower, follower_mpc, follower_zk);
let (_, follower_res) = futures::join!(
async {

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::Subset;
use rangeset::ops::Set;
/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]

View File

@@ -5,7 +5,7 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -20,14 +20,13 @@ mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-hash = { workspace = true }
sha2 = { workspace = true, features = ["compress"] }
thiserror = { workspace = true }
tracing = { workspace = true }
sha2 = { workspace = true }
[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ideal-vm = { workspace = true }
criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -4,14 +4,12 @@ use criterion::{criterion_group, criterion_main, Criterion};
use hmac_sha256::{Mode, MpcPrf};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
memory::{binary::U8, Array},
prelude::*,
Execute,
};
use rand::{rngs::StdRng, SeedableRng};
#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
@@ -29,8 +27,6 @@ criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
async fn prf(mode: Mode) {
let mut rng = StdRng::seed_from_u64(0);
let pms = [42u8; 32];
let client_random = [69u8; 32];
let server_random: [u8; 32] = [96u8; 32];
@@ -39,11 +35,8 @@ async fn prf(mode: Mode) {
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);
let mut leader_vm = IdealVm::new();
let mut follower_vm = IdealVm::new();
let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
leader_vm.mark_public(leader_pms).unwrap();

View File

@@ -54,10 +54,11 @@ mod tests {
use crate::{
hmac::hmac_sha256,
sha256, state_to_bytes,
test_utils::{compute_inner_local, compute_outer_partial, mock_vm},
test_utils::{compute_inner_local, compute_outer_partial},
};
use mpz_common::context::test_st_context;
use mpz_hash::sha256::Sha256;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{
binary::{U32, U8},
@@ -83,7 +84,8 @@ mod tests {
#[tokio::test]
async fn test_hmac_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {

View File

@@ -72,10 +72,11 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
#[cfg(test)]
mod tests {
use crate::{
test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
test_utils::{prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
Mode, MpcPrf, SessionKeys,
};
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
@@ -123,7 +124,8 @@ mod tests {
// Set up vm and prf
let (mut ctx_a, mut ctx_b) = test_st_context(128);
let (mut leader, mut follower) = mock_vm();
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(leader_pms).unwrap();

View File

@@ -339,8 +339,9 @@ fn gen_merge_circ(size: usize) -> Arc<Circuit> {
#[cfg(test)]
mod tests {
use crate::{prf::merge_outputs, test_utils::mock_vm};
use crate::prf::merge_outputs;
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
@@ -349,7 +350,8 @@ mod tests {
#[tokio::test]
async fn test_merge_outputs() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);

View File

@@ -137,10 +137,11 @@ impl Prf {
mod tests {
use crate::{
prf::{compute_partial, function::Prf},
test_utils::{mock_vm, phash},
test_utils::phash,
Mode,
};
use mpz_common::context::test_st_context;
use mpz_ideal_vm::IdealVm;
use mpz_vm_core::{
memory::{binary::U8, Array, MemoryExt, ViewExt},
Execute,
@@ -166,7 +167,8 @@ mod tests {
let mut rng = ThreadRng::default();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();
let mut leader = IdealVm::new();
let mut follower = IdealVm::new();
let key: [u8; 32] = rng.random();
let start_seed: Vec<u8> = vec![42; 64];

View File

@@ -1,25 +1,10 @@
use crate::{sha256, state_to_bytes};
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_vm_core::memory::correlated::Delta;
use rand::{rngs::StdRng, Rng, SeedableRng};
pub(crate) const SHA256_IV: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
let mut label_start_seed = b"master secret".to_vec();
label_start_seed.extend_from_slice(&client_random);

View File

@@ -5,7 +5,7 @@ description = "Implementation of the 3-party key-exchange protocol"
keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -40,6 +40,7 @@ tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-garble = { workspace = true }
mpz-ideal-vm = { workspace = true }
rand_core = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }

View File

@@ -459,9 +459,7 @@ mod tests {
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_fields::UniformRand;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_memory_core::correlated::Delta;
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
use mpz_ideal_vm::IdealVm;
use mpz_share_conversion::ideal::{
ideal_share_convert, IdealShareConvertReceiver, IdealShareConvertSender,
};
@@ -484,7 +482,8 @@ mod tests {
async fn test_key_exchange() {
let mut rng = StdRng::seed_from_u64(0).compat();
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let leader_private_key = SecretKey::random(&mut rng);
let follower_private_key = SecretKey::random(&mut rng);
@@ -625,7 +624,8 @@ mod tests {
async fn test_malicious_key_exchange(#[case] malicious: Malicious) {
let mut rng = StdRng::seed_from_u64(0);
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut gen, mut ev) = mock_vm();
let mut gen = IdealVm::new();
let mut ev = IdealVm::new();
let leader_private_key = SecretKey::random(&mut rng.compat_by_ref());
let follower_private_key = SecretKey::random(&mut rng.compat_by_ref());
@@ -704,7 +704,8 @@ mod tests {
#[tokio::test]
async fn test_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (gen, ev) = mock_vm();
let gen = IdealVm::new();
let ev = IdealVm::new();
let share_a0_bytes = [5_u8; 32];
let share_a1_bytes = [2_u8; 32];
@@ -834,16 +835,4 @@ mod tests {
(leader, follower)
}
fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
let gen = Garbler::new(cot_send, [0u8; 16], delta);
let ev = Evaluator::new(cot_recv);
(gen, ev)
}
}

View File

@@ -8,7 +8,7 @@
//! with the server alone and forward all messages from and to the follower.
//!
//! A detailed description of this protocol can be found in our documentation
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>.
//! <https://tlsnotary.org/docs/mpc/key_exchange>.
#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]

View File

@@ -26,8 +26,7 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {
#[cfg(test)]
mod tests {
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};
use mpz_ideal_vm::IdealVm;
use super::*;
use crate::KeyExchange;
@@ -40,12 +39,12 @@ mod tests {
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Garbler<IdealCOTSender>,
IdealVm,
>(leader);
is_key_exchange::<
MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
Evaluator<IdealCOTReceiver>,
IdealVm,
>(follower);
}
}

View File

@@ -4,7 +4,7 @@
//! protocol has semi-honest security.
//!
//! The protocol is described in
//! <https://docs.tlsnotary.org/protocol/notarization/key_exchange.html>
//! <https://tlsnotary.org/docs/mpc/key_exchange>
use crate::{KeyExchangeError, Role};
use mpz_common::{Context, Flush};

View File

@@ -5,7 +5,7 @@ description = "Core types for TLSNotary"
keywords = ["tls", "mpc", "2pc", "types"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -13,6 +13,7 @@ workspace = true
[features]
default = []
mozilla-certs = ["dep:webpki-root-certs", "dep:webpki-roots"]
fixtures = [
"dep:hex",
"dep:tlsn-data-fixtures",
@@ -26,6 +27,7 @@ tlsn-data-fixtures = { workspace = true, optional = true }
tlsn-tls-core = { workspace = true, features = ["serde"] }
tlsn-utils = { workspace = true }
rangeset = { workspace = true, features = ["serde"] }
mpz-predicate = { workspace = true }
aead = { workspace = true, features = ["alloc"], optional = true }
aes-gcm = { workspace = true, optional = true }
@@ -44,7 +46,8 @@ sha2 = { workspace = true }
thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
web-time = { workspace = true }
webpki-roots = { workspace = true }
webpki-roots = { workspace = true, optional = true }
webpki-root-certs = { workspace = true, optional = true }
rustls-webpki = { workspace = true, features = ["ring"] }
rustls-pki-types = { workspace = true }
itybity = { workspace = true }
@@ -57,5 +60,7 @@ generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -0,0 +1,7 @@
//! Configuration types.
pub mod prove;
pub mod prover;
pub mod tls;
pub mod tls_commit;
pub mod verifier;

View File

@@ -0,0 +1,348 @@
//! Proving configuration.
use mpz_predicate::Pred;
use rangeset::set::{RangeSet, ToRangeSet};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
/// Configuration for a predicate to prove over transcript data.
///
/// A predicate is a boolean constraint that operates on transcript bytes.
/// The prover proves in ZK that the predicate evaluates to true.
///
/// The predicate itself encodes which byte indices it operates on via its
/// atomic comparisons (e.g., `gte(42, threshold)` operates on byte index 42).
#[derive(Debug, Clone)]
pub struct PredicateConfig {
/// Human-readable name for the predicate (sent to verifier as sanity
/// check).
name: String,
/// Direction of transcript data the predicate operates on.
direction: Direction,
/// The predicate to prove.
predicate: Pred,
}
impl PredicateConfig {
/// Creates a new predicate configuration.
///
/// # Arguments
///
/// * `name` - Human-readable name for the predicate.
/// * `direction` - Whether the predicate operates on sent or received data.
/// * `predicate` - The predicate to prove.
pub fn new(name: impl Into<String>, direction: Direction, predicate: Pred) -> Self {
Self {
name: name.into(),
direction,
predicate,
}
}
/// Returns the predicate name.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the direction of transcript data.
pub fn direction(&self) -> Direction {
self.direction
}
/// Returns the predicate.
pub fn predicate(&self) -> &Pred {
&self.predicate
}
/// Returns the transcript byte indices this predicate operates on.
pub fn indices(&self) -> Vec<usize> {
self.predicate.indices()
}
/// Converts to a request (wire format).
pub fn to_request(&self) -> PredicateRequest {
let indices: RangeSet<usize> = self
.predicate
.indices()
.into_iter()
.map(|idx| idx..idx + 1)
.collect();
PredicateRequest {
name: self.name.clone(),
direction: self.direction,
indices,
}
}
}
/// Wire format for predicate proving request.
///
/// Contains only the predicate name and indices - the verifier is expected
/// to know which predicate corresponds to the name from out-of-band agreement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateRequest {
/// Human-readable name for the predicate.
name: String,
/// Direction of transcript data the predicate operates on.
direction: Direction,
/// Transcript byte indices the predicate operates on.
indices: RangeSet<usize>,
}
impl PredicateRequest {
/// Returns the predicate name.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the direction of transcript data.
pub fn direction(&self) -> Direction {
self.direction
}
/// Returns the transcript byte indices as a RangeSet.
pub fn indices(&self) -> &RangeSet<usize> {
&self.indices
}
}
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone)]
pub struct ProveConfig {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
predicates: Vec<PredicateConfig>,
}
impl ProveConfig {
/// Creates a new builder.
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
ProveConfigBuilder::new(transcript)
}
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the sent and received ranges of the transcript to be revealed,
/// respectively.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
self.transcript_commit.as_ref()
}
/// Returns the predicate configurations.
pub fn predicates(&self) -> &[PredicateConfig] {
&self.predicates
}
/// Returns a request.
pub fn to_request(&self) -> ProveRequest {
ProveRequest {
server_identity: self.server_identity,
reveal: self.reveal.clone(),
transcript_commit: self
.transcript_commit
.clone()
.map(|config| config.to_request()),
predicates: self.predicates.iter().map(|p| p.to_request()).collect(),
}
}
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
transcript: &'a Transcript,
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
predicates: Vec<PredicateConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
/// Creates a new builder.
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
server_identity: false,
reveal: None,
transcript_commit: None,
predicates: Vec::new(),
}
}
/// Proves the server identity.
pub fn server_identity(&mut self) -> &mut Self {
self.server_identity = true;
self
}
/// Configures transcript commitments.
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
self.transcript_commit = Some(transcript_commit);
self
}
/// Reveals the given ranges of the transcript.
pub fn reveal(
&mut self,
direction: Direction,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
let idx = ranges.to_range_set();
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
direction,
actual: idx.end().unwrap_or(0),
len: self.transcript.len_of_direction(direction),
}));
}
let (sent, recv) = self.reveal.get_or_insert_default();
match direction {
Direction::Sent => sent.union_mut(&idx),
Direction::Received => recv.union_mut(&idx),
}
Ok(self)
}
/// Reveals the given ranges of the sent data transcript.
pub fn reveal_sent(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
self.reveal(Direction::Sent, ranges)
}
/// Reveals all of the sent data transcript.
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigError> {
let len = self.transcript.len_of_direction(Direction::Sent);
let (sent, _) = self.reveal.get_or_insert_default();
sent.union_mut(&(0..len));
Ok(self)
}
/// Reveals the given ranges of the received data transcript.
pub fn reveal_recv(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigError> {
self.reveal(Direction::Received, ranges)
}
/// Reveals all of the received data transcript.
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigError> {
let len = self.transcript.len_of_direction(Direction::Received);
let (_, recv) = self.reveal.get_or_insert_default();
recv.union_mut(&(0..len));
Ok(self)
}
/// Adds a predicate to prove over transcript data.
///
/// The predicate encodes which byte indices it operates on via its atomic
/// comparisons (e.g., `gte(42, threshold)` operates on byte index 42).
///
/// # Arguments
///
/// * `name` - Human-readable name for the predicate (sent to verifier as
/// sanity check).
/// * `direction` - Whether the predicate operates on sent or received data.
/// * `predicate` - The predicate to prove.
pub fn predicate(
&mut self,
name: impl Into<String>,
direction: Direction,
predicate: Pred,
) -> Result<&mut Self, ProveConfigError> {
let indices = predicate.indices();
// Predicate must reference at least one transcript byte.
let last_idx = *indices
.last()
.ok_or(ProveConfigError(ErrorRepr::EmptyPredicate))?;
// Since indices are sorted, only check the last one for bounds.
let transcript_len = self.transcript.len_of_direction(direction);
if last_idx >= transcript_len {
return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
direction,
actual: last_idx,
len: transcript_len,
}));
}
self.predicates
.push(PredicateConfig::new(name, direction, predicate));
Ok(self)
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigError> {
Ok(ProveConfig {
server_identity: self.server_identity,
reveal: self.reveal,
transcript_commit: self.transcript_commit,
predicates: self.predicates,
})
}
}
/// Request to prove statements about the connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveRequest {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitRequest>,
predicates: Vec<PredicateRequest>,
}
impl ProveRequest {
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the sent and received ranges of the transcript to be revealed,
/// respectively.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitRequest> {
self.transcript_commit.as_ref()
}
/// Returns the predicate requests.
pub fn predicates(&self) -> &[PredicateRequest] {
&self.predicates
}
}
/// Error for [`ProveConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("index out of bounds for {direction} transcript: {actual} >= {len}")]
IndexOutOfBounds {
direction: Direction,
actual: usize,
len: usize,
},
#[error("predicate must reference at least one transcript byte")]
EmptyPredicate,
}

View File

@@ -0,0 +1,33 @@
//! Prover configuration.
use serde::{Deserialize, Serialize};
/// Prover configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProverConfig {}
impl ProverConfig {
/// Creates a new builder.
pub fn builder() -> ProverConfigBuilder {
ProverConfigBuilder::default()
}
}
/// Builder for [`ProverConfig`].
#[derive(Debug, Default)]
pub struct ProverConfigBuilder {}
impl ProverConfigBuilder {
/// Builds the configuration.
pub fn build(self) -> Result<ProverConfig, ProverConfigError> {
Ok(ProverConfig {})
}
}
/// Error for [`ProverConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProverConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {}

View File

@@ -0,0 +1,111 @@
//! TLS client configuration.
use serde::{Deserialize, Serialize};
use crate::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
/// TLS client configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsClientConfig {
server_name: ServerName,
/// Root certificates.
root_store: RootCertStore,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsClientConfig {
/// Creates a new builder.
pub fn builder() -> TlsConfigBuilder {
TlsConfigBuilder::default()
}
/// Returns the server name.
pub fn server_name(&self) -> &ServerName {
&self.server_name
}
/// Returns the root certificates.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> Option<&(Vec<CertificateDer>, PrivateKeyDer)> {
self.client_auth.as_ref()
}
}
/// Builder for [`TlsClientConfig`].
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
server_name: Option<ServerName>,
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
/// Sets the server name.
pub fn server_name(mut self, server_name: ServerName) -> Self {
self.server_name = Some(server_name);
self
}
/// Sets the root certificates to use for verifying the server's
/// certificate.
pub fn root_store(mut self, store: RootCertStore) -> Self {
self.root_store = Some(store);
self
}
/// Sets a DER-encoded certificate chain and a matching private key for
/// client authentication.
///
/// Often the chain will consist of a single end-entity certificate.
///
/// # Arguments
///
/// * `cert_key` - A tuple containing the certificate chain and the private
/// key.
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth(mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> Self {
self.client_auth = Some(cert_key);
self
}
/// Builds the TLS configuration.
pub fn build(self) -> Result<TlsClientConfig, TlsConfigError> {
let server_name = self.server_name.ok_or(ErrorRepr::MissingField {
field: "server_name",
})?;
let root_store = self.root_store.ok_or(ErrorRepr::MissingField {
field: "root_store",
})?;
Ok(TlsClientConfig {
server_name,
root_store,
client_auth: self.client_auth,
})
}
}
/// TLS configuration error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("tls config error")]
enum ErrorRepr {
#[error("missing required field: {field}")]
MissingField { field: &'static str },
}

View File

@@ -0,0 +1,94 @@
//! TLS commitment configuration.
pub mod mpc;
use serde::{Deserialize, Serialize};
/// TLS commitment configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfig {
protocol: TlsCommitProtocolConfig,
}
impl TlsCommitConfig {
/// Creates a new builder.
pub fn builder() -> TlsCommitConfigBuilder {
TlsCommitConfigBuilder::default()
}
/// Returns the protocol configuration.
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
&self.protocol
}
/// Returns a TLS commitment request.
pub fn to_request(&self) -> TlsCommitRequest {
TlsCommitRequest {
config: self.protocol.clone(),
}
}
}
/// Builder for [`TlsCommitConfig`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TlsCommitConfigBuilder {
protocol: Option<TlsCommitProtocolConfig>,
}
impl TlsCommitConfigBuilder {
/// Sets the protocol configuration.
pub fn protocol<C>(mut self, protocol: C) -> Self
where
C: Into<TlsCommitProtocolConfig>,
{
self.protocol = Some(protocol.into());
self
}
/// Builds the configuration.
pub fn build(self) -> Result<TlsCommitConfig, TlsCommitConfigError> {
let protocol = self
.protocol
.ok_or(ErrorRepr::MissingField { name: "protocol" })?;
Ok(TlsCommitConfig { protocol })
}
}
/// TLS commitment protocol configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TlsCommitProtocolConfig {
/// MPC-TLS configuration.
Mpc(mpc::MpcTlsConfig),
}
impl From<mpc::MpcTlsConfig> for TlsCommitProtocolConfig {
fn from(config: mpc::MpcTlsConfig) -> Self {
Self::Mpc(config)
}
}
/// TLS commitment request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsCommitRequest {
config: TlsCommitProtocolConfig,
}
impl TlsCommitRequest {
/// Returns the protocol configuration.
pub fn protocol(&self) -> &TlsCommitProtocolConfig {
&self.config
}
}
/// Error for [`TlsCommitConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsCommitConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
}

View File

@@ -0,0 +1,241 @@
//! MPC-TLS commitment protocol configuration.
use serde::{Deserialize, Serialize};
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
/// MPC-TLS commitment protocol configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(try_from = "unchecked::MpcTlsConfigUnchecked")]
pub struct MpcTlsConfig {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
/// Maximum number of application data records that can be sent.
max_sent_records: Option<usize>,
/// Maximum number of bytes that can be decrypted online, i.e. while the
/// MPC-TLS connection is active.
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of received application data records that can be
/// decrypted online, i.e. while the MPC-TLS connection is active.
max_recv_records_online: Option<usize>,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
defer_decryption_from_start: bool,
/// Network settings.
network: NetworkSetting,
}
impl MpcTlsConfig {
/// Creates a new builder.
pub fn builder() -> MpcTlsConfigBuilder {
MpcTlsConfigBuilder::default()
}
/// Returns the maximum number of bytes that can be sent.
pub fn max_sent_data(&self) -> usize {
self.max_sent_data
}
/// Returns the maximum number of application data records that can
/// be sent.
pub fn max_sent_records(&self) -> Option<usize> {
self.max_sent_records
}
/// Returns the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(&self) -> usize {
self.max_recv_data_online
}
/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns the maximum number of received application data records that
/// can be decrypted online.
pub fn max_recv_records_online(&self) -> Option<usize> {
self.max_recv_records_online
}
/// Returns whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(&self) -> bool {
self.defer_decryption_from_start
}
/// Returns the network settings.
pub fn network(&self) -> NetworkSetting {
self.network
}
}
fn validate(config: MpcTlsConfig) -> Result<MpcTlsConfig, MpcTlsConfigError> {
if config.max_recv_data_online > config.max_recv_data {
return Err(ErrorRepr::InvalidValue {
name: "max_recv_data_online",
reason: format!(
"must be <= max_recv_data ({} > {})",
config.max_recv_data_online, config.max_recv_data
),
}
.into());
}
Ok(config)
}
/// Builder for [`MpcTlsConfig`].
#[derive(Debug, Default)]
pub struct MpcTlsConfigBuilder {
max_sent_data: Option<usize>,
max_sent_records: Option<usize>,
max_recv_data_online: Option<usize>,
max_recv_data: Option<usize>,
max_recv_records_online: Option<usize>,
defer_decryption_from_start: Option<bool>,
network: Option<NetworkSetting>,
}
impl MpcTlsConfigBuilder {
/// Sets the maximum number of bytes that can be sent.
pub fn max_sent_data(mut self, max_sent_data: usize) -> Self {
self.max_sent_data = Some(max_sent_data);
self
}
/// Sets the maximum number of application data records that can be sent.
pub fn max_sent_records(mut self, max_sent_records: usize) -> Self {
self.max_sent_records = Some(max_sent_records);
self
}
/// Sets the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(mut self, max_recv_data_online: usize) -> Self {
self.max_recv_data_online = Some(max_recv_data_online);
self
}
/// Sets the maximum number of bytes that can be received.
pub fn max_recv_data(mut self, max_recv_data: usize) -> Self {
self.max_recv_data = Some(max_recv_data);
self
}
/// Sets the maximum number of received application data records that can
/// be decrypted online.
pub fn max_recv_records_online(mut self, max_recv_records_online: usize) -> Self {
self.max_recv_records_online = Some(max_recv_records_online);
self
}
/// Sets whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(mut self, defer_decryption_from_start: bool) -> Self {
self.defer_decryption_from_start = Some(defer_decryption_from_start);
self
}
/// Sets the network settings.
pub fn network(mut self, network: NetworkSetting) -> Self {
self.network = Some(network);
self
}
/// Builds the configuration.
pub fn build(self) -> Result<MpcTlsConfig, MpcTlsConfigError> {
let Self {
max_sent_data,
max_sent_records,
max_recv_data_online,
max_recv_data,
max_recv_records_online,
defer_decryption_from_start,
network,
} = self;
let max_sent_data = max_sent_data.ok_or(ErrorRepr::MissingField {
name: "max_sent_data",
})?;
let max_recv_data_online = max_recv_data_online.unwrap_or(DEFAULT_MAX_RECV_ONLINE);
let max_recv_data = max_recv_data.ok_or(ErrorRepr::MissingField {
name: "max_recv_data",
})?;
let defer_decryption_from_start = defer_decryption_from_start.unwrap_or(true);
let network = network.unwrap_or_default();
validate(MpcTlsConfig {
max_sent_data,
max_sent_records,
max_recv_data_online,
max_recv_data,
max_recv_records_online,
defer_decryption_from_start,
network,
})
}
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub enum NetworkSetting {
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
#[default]
Latency,
}
/// Error for [`MpcTlsConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct MpcTlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
#[error("invalid value for field({name}): {reason}")]
InvalidValue { name: &'static str, reason: String },
}
mod unchecked {
use super::*;
#[derive(Deserialize)]
pub(super) struct MpcTlsConfigUnchecked {
max_sent_data: usize,
max_sent_records: Option<usize>,
max_recv_data_online: usize,
max_recv_data: usize,
max_recv_records_online: Option<usize>,
defer_decryption_from_start: bool,
network: NetworkSetting,
}
impl TryFrom<MpcTlsConfigUnchecked> for MpcTlsConfig {
type Error = MpcTlsConfigError;
fn try_from(value: MpcTlsConfigUnchecked) -> Result<Self, Self::Error> {
validate(MpcTlsConfig {
max_sent_data: value.max_sent_data,
max_sent_records: value.max_sent_records,
max_recv_data_online: value.max_recv_data_online,
max_recv_data: value.max_recv_data,
max_recv_records_online: value.max_recv_records_online,
defer_decryption_from_start: value.defer_decryption_from_start,
network: value.network,
})
}
}
}

View File

@@ -0,0 +1,56 @@
//! Verifier configuration.
use serde::{Deserialize, Serialize};
use crate::webpki::RootCertStore;
/// Verifier configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerifierConfig {
root_store: RootCertStore,
}
impl VerifierConfig {
/// Creates a new builder.
pub fn builder() -> VerifierConfigBuilder {
VerifierConfigBuilder::default()
}
/// Returns the root certificate store.
pub fn root_store(&self) -> &RootCertStore {
&self.root_store
}
}
/// Builder for [`VerifierConfig`].
#[derive(Debug, Default)]
pub struct VerifierConfigBuilder {
root_store: Option<RootCertStore>,
}
impl VerifierConfigBuilder {
/// Sets the root certificate store.
pub fn root_store(mut self, root_store: RootCertStore) -> Self {
self.root_store = Some(root_store);
self
}
/// Builds the configuration.
pub fn build(self) -> Result<VerifierConfig, VerifierConfigError> {
let root_store = self
.root_store
.ok_or(ErrorRepr::MissingField { name: "root_store" })?;
Ok(VerifierConfig { root_store })
}
}
/// Error for [`VerifierConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct VerifierConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("missing field: {name}")]
MissingField { name: &'static str },
}

View File

@@ -1,11 +1,11 @@
use rangeset::RangeSet;
use rangeset::set::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter_ranges() {
for range in self.0.iter() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;

View File

@@ -2,12 +2,13 @@
use aead::Payload as AeadPayload;
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
#[allow(deprecated)]
use generic_array::GenericArray;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_core::msgs::{
base::Payload,
codec::Codec,
enums::{ContentType, HandshakeType, ProtocolVersion},
enums::{HandshakeType, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
message::{OpaqueMessage, PlainMessage},
};
@@ -15,7 +16,7 @@ use tls_core::msgs::{
use crate::{
connection::{TranscriptLength, VerifyData},
fixtures::ConnectionFixture,
transcript::{Record, TlsTranscript},
transcript::{ContentType, Record, TlsTranscript},
};
/// The key used for encryption of the sent and received transcript.
@@ -103,7 +104,7 @@ impl TranscriptGenerator {
let explicit_nonce: [u8; 8] = seq.to_be_bytes();
let msg = PlainMessage {
typ: ContentType::ApplicationData,
typ: ContentType::ApplicationData.into(),
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext),
};
@@ -138,7 +139,7 @@ impl TranscriptGenerator {
handshake_message.encode(&mut plaintext);
let msg = PlainMessage {
typ: ContentType::Handshake,
typ: ContentType::Handshake.into(),
version: ProtocolVersion::TLSv1_2,
payload: Payload::new(plaintext.clone()),
};
@@ -180,6 +181,7 @@ fn aes_gcm_encrypt(
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();

View File

@@ -296,14 +296,14 @@ mod sha2 {
fn hash(&self, data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(data);
super::Hash::new(hasher.finalize().as_slice())
super::Hash::new(hasher.finalize().as_ref())
}
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(prefix);
hasher.update(data);
super::Hash::new(hasher.finalize().as_slice())
super::Hash::new(hasher.finalize().as_ref())
}
}
}

View File

@@ -12,212 +12,18 @@ pub mod merkle;
pub mod transcript;
pub mod webpki;
pub use rangeset;
pub mod config;
pub(crate) mod display;
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
connection::{HandshakeData, ServerName},
connection::ServerName,
transcript::{
encoding::EncoderSecret, Direction, PartialTranscript, Transcript, TranscriptCommitConfig,
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
encoding::EncoderSecret, PartialTranscript, TranscriptCommitment, TranscriptSecret,
},
};
/// Configuration to prove information to the verifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProveConfig {
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl ProveConfig {
/// Creates a new builder.
pub fn builder(transcript: &Transcript) -> ProveConfigBuilder<'_> {
ProveConfigBuilder::new(transcript)
}
/// Returns `true` if the server identity is to be proven.
pub fn server_identity(&self) -> bool {
self.server_identity
}
/// Returns the ranges of the transcript to be revealed.
pub fn reveal(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.reveal.as_ref()
}
/// Returns the transcript commitment configuration.
pub fn transcript_commit(&self) -> Option<&TranscriptCommitConfig> {
self.transcript_commit.as_ref()
}
}
/// Builder for [`ProveConfig`].
#[derive(Debug)]
pub struct ProveConfigBuilder<'a> {
transcript: &'a Transcript,
server_identity: bool,
reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
transcript_commit: Option<TranscriptCommitConfig>,
}
impl<'a> ProveConfigBuilder<'a> {
/// Creates a new builder.
pub fn new(transcript: &'a Transcript) -> Self {
Self {
transcript,
server_identity: false,
reveal: None,
transcript_commit: None,
}
}
/// Proves the server identity.
pub fn server_identity(&mut self) -> &mut Self {
self.server_identity = true;
self
}
/// Configures transcript commitments.
pub fn transcript_commit(&mut self, transcript_commit: TranscriptCommitConfig) -> &mut Self {
self.transcript_commit = Some(transcript_commit);
self
}
/// Reveals the given ranges of the transcript.
pub fn reveal(
&mut self,
direction: Direction,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
let idx = ranges.to_range_set();
if idx.end().unwrap_or(0) > self.transcript.len_of_direction(direction) {
return Err(ProveConfigBuilderError(
ProveConfigBuilderErrorRepr::IndexOutOfBounds {
direction,
actual: idx.end().unwrap_or(0),
len: self.transcript.len_of_direction(direction),
},
));
}
let (sent, recv) = self.reveal.get_or_insert_default();
match direction {
Direction::Sent => sent.union_mut(&idx),
Direction::Received => recv.union_mut(&idx),
}
Ok(self)
}
/// Reveals the given ranges of the sent data transcript.
pub fn reveal_sent(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
self.reveal(Direction::Sent, ranges)
}
/// Reveals all of the sent data transcript.
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Sent);
let (sent, _) = self.reveal.get_or_insert_default();
sent.union_mut(&(0..len));
Ok(self)
}
/// Reveals the given ranges of the received data transcript.
pub fn reveal_recv(
&mut self,
ranges: &dyn ToRangeSet<usize>,
) -> Result<&mut Self, ProveConfigBuilderError> {
self.reveal(Direction::Received, ranges)
}
/// Reveals all of the received data transcript.
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Received);
let (_, recv) = self.reveal.get_or_insert_default();
recv.union_mut(&(0..len));
Ok(self)
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
Ok(ProveConfig {
server_identity: self.server_identity,
reveal: self.reveal,
transcript_commit: self.transcript_commit,
})
}
}
/// Error for [`ProveConfigBuilder`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ProveConfigBuilderError(#[from] ProveConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum ProveConfigBuilderErrorRepr {
#[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
IndexOutOfBounds {
direction: Direction,
actual: usize,
len: usize,
},
}
/// Configuration to verify information from the prover.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct VerifyConfig {}
impl VerifyConfig {
/// Creates a new builder.
pub fn builder() -> VerifyConfigBuilder {
VerifyConfigBuilder::new()
}
}
/// Builder for [`VerifyConfig`].
#[derive(Debug, Default)]
pub struct VerifyConfigBuilder {}
impl VerifyConfigBuilder {
/// Creates a new builder.
pub fn new() -> Self {
Self {}
}
/// Builds the configuration.
pub fn build(self) -> Result<VerifyConfig, VerifyConfigBuilderError> {
Ok(VerifyConfig {})
}
}
/// Error for [`VerifyConfigBuilder`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Request to prove statements about the connection.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Transcript commitment configuration.
pub transcript_commit: Option<TranscriptCommitRequest>,
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {

View File

@@ -26,7 +26,11 @@ mod tls;
use std::{fmt, ops::Range};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -38,8 +42,7 @@ pub use commit::{
pub use proof::{
TranscriptProof, TranscriptProofBuilder, TranscriptProofBuilderError, TranscriptProofError,
};
pub use tls::{Record, TlsTranscript};
pub use tls_core::msgs::enums::ContentType;
pub use tls::{ContentType, Record, TlsTranscript};
/// A transcript contains the plaintext of all application data communicated
/// between the Prover and the Server.
@@ -107,7 +110,7 @@ impl Transcript {
}
Some(
Subsequence::new(idx.clone(), data.index_ranges(idx))
Subsequence::new(idx.clone(), data.index(idx).flatten().copied().collect())
.expect("data is same length as index"),
)
}
@@ -130,11 +133,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
for range in sent_idx.iter_ranges() {
for range in sent_idx.iter() {
sent[range.clone()].copy_from_slice(&self.sent[range]);
}
for range in recv_idx.iter_ranges() {
for range in recv_idx.iter() {
received[range.clone()].copy_from_slice(&self.received[range]);
}
@@ -189,10 +192,16 @@ impl From<PartialTranscript> for CompressedPartialTranscript {
Self {
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx),
.index(&uncompressed.sent_authed_idx)
.flatten()
.copied()
.collect(),
received_authed: uncompressed
.received
.index_ranges(&uncompressed.received_authed_idx),
.index(&uncompressed.received_authed_idx)
.flatten()
.copied()
.collect(),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -208,7 +217,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.sent_idx.iter_ranges() {
for range in compressed.sent_idx.iter() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
@@ -216,7 +225,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.recv_idx.iter_ranges() {
for range in compressed.recv_idx.iter() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
@@ -305,12 +314,16 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len()).difference(&self.sent_authed_idx)
(0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len()).difference(&self.received_authed_idx)
(0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -320,7 +333,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.iter().map(|i| data[i])
authed.iter_values().map(move |i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -340,24 +353,20 @@ impl PartialTranscript {
"received data are not the same length"
);
for range in other
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed_idx
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx
.union_mut(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -369,11 +378,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
self.sent_authed_idx.union_mut(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
self.received_authed_idx.union_mut(&seq.idx);
}
}
}
@@ -384,10 +393,10 @@ impl PartialTranscript {
///
/// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter_ranges() {
for range in self.sent_unauthed().iter() {
self.sent[range].fill(value);
}
for range in self.received_unauthed().iter_ranges() {
for range in self.received_unauthed().iter() {
self.received[range].fill(value);
}
}
@@ -402,13 +411,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
for r in range.difference(&self.sent_authed_idx) {
self.sent[r].fill(value);
}
}
Direction::Received => {
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
for r in range.difference(&self.received_authed_idx) {
self.received[r].fill(value);
}
}
}
@@ -486,7 +495,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0;
for range in self.idx.iter_ranges() {
for range in self.idx.iter() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len();
}
@@ -611,12 +620,7 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
partial_transcript.sent_total = end - 1;

View File

@@ -2,7 +2,7 @@
use std::{collections::HashSet, fmt};
use rangeset::{ToRangeSet, UnionMut};
use rangeset::set::ToRangeSet;
use serde::{Deserialize, Serialize};
use crate::{

View File

@@ -1,6 +1,6 @@
use std::{collections::HashMap, fmt};
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use crate::{
@@ -103,7 +103,7 @@ impl EncodingProof {
}
expected_leaf.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
}
expected_leaf.extend_from_slice(blinder.as_bytes());

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use crate::{
@@ -99,7 +99,7 @@ impl EncodingTree {
let blinder: Blinder = rand::random();
encoding.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
provider
.provide_encoding(direction, range, &mut encoding)
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;

View File

@@ -1,6 +1,10 @@
//! Transcript proofs.
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
@@ -25,6 +29,9 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::KECCAK256,
},
TranscriptCommitmentKind::Encoding,
];
@@ -141,7 +148,7 @@ impl TranscriptProof {
}
buffer.clear();
for range in idx.iter_ranges() {
for range in idx.iter() {
buffer.extend_from_slice(&plaintext[range]);
}
@@ -363,7 +370,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx);
} else {
let missing = idx.difference(committed);
let missing = idx.difference(committed).into_set();
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!(
@@ -579,7 +586,7 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rangeset::RangeSet;
use rangeset::prelude::*;
use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
@@ -656,6 +663,7 @@ mod tests {
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
@@ -704,6 +712,7 @@ mod tests {
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
#[case::keccak256(HashAlgId::KECCAK256)]
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();

View File

@@ -10,10 +10,53 @@ use crate::{
use tls_core::msgs::{
alert::AlertMessagePayload,
codec::{Codec, Reader},
enums::{AlertDescription, ContentType, ProtocolVersion},
enums::{AlertDescription, ProtocolVersion},
handshake::{HandshakeMessagePayload, HandshakePayload},
};
/// TLS record content type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ContentType {
/// Change cipher spec protocol.
ChangeCipherSpec,
/// Alert protocol.
Alert,
/// Handshake protocol.
Handshake,
/// Application data protocol.
ApplicationData,
/// Heartbeat protocol.
Heartbeat,
/// Unknown protocol.
Unknown(u8),
}
impl From<ContentType> for tls_core::msgs::enums::ContentType {
fn from(content_type: ContentType) -> Self {
match content_type {
ContentType::ChangeCipherSpec => tls_core::msgs::enums::ContentType::ChangeCipherSpec,
ContentType::Alert => tls_core::msgs::enums::ContentType::Alert,
ContentType::Handshake => tls_core::msgs::enums::ContentType::Handshake,
ContentType::ApplicationData => tls_core::msgs::enums::ContentType::ApplicationData,
ContentType::Heartbeat => tls_core::msgs::enums::ContentType::Heartbeat,
ContentType::Unknown(id) => tls_core::msgs::enums::ContentType::Unknown(id),
}
}
}
impl From<tls_core::msgs::enums::ContentType> for ContentType {
fn from(content_type: tls_core::msgs::enums::ContentType) -> Self {
match content_type {
tls_core::msgs::enums::ContentType::ChangeCipherSpec => ContentType::ChangeCipherSpec,
tls_core::msgs::enums::ContentType::Alert => ContentType::Alert,
tls_core::msgs::enums::ContentType::Handshake => ContentType::Handshake,
tls_core::msgs::enums::ContentType::ApplicationData => ContentType::ApplicationData,
tls_core::msgs::enums::ContentType::Heartbeat => ContentType::Heartbeat,
tls_core::msgs::enums::ContentType::Unknown(id) => ContentType::Unknown(id),
}
}
}
/// A transcript of TLS records sent and received by the prover.
#[derive(Debug, Clone)]
pub struct TlsTranscript {

View File

@@ -53,6 +53,21 @@ impl RootCertStore {
pub fn empty() -> Self {
Self { roots: Vec::new() }
}
/// Creates a root certificate store with Mozilla root certificates.
///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
pub fn mozilla() -> Self {
Self {
roots: webpki_root_certs::TLS_SERVER_ROOT_CERTS
.iter()
.map(|cert| CertificateDer(cert.to_vec()))
.collect(),
}
}
}
/// Server certificate verifier.
@@ -82,8 +97,12 @@ impl ServerCertVerifier {
Ok(Self { roots })
}
/// Creates a new server certificate verifier with Mozilla root
/// certificates.
/// Creates a server certificate verifier with Mozilla root certificates.
///
/// These certificates are sourced from [`webpki-root-certs`](https://docs.rs/webpki-root-certs/latest/webpki_root_certs/). It is not recommended to use these unless the
/// application binary can be recompiled and deployed on-demand in the case
/// that the root certificates need to be updated.
#[cfg(feature = "mozilla-certs")]
pub fn mozilla() -> Self {
Self {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),

View File

@@ -10,11 +10,15 @@ workspace = true
[dependencies]
tlsn = { workspace = true }
tlsn-formats = { workspace = true }
tlsn-core = { workspace = true }
tls-server-fixture = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
spansy = { workspace = true }
mpz-predicate = { workspace = true }
rangeset = { workspace = true }
anyhow = { workspace = true }
bincode = { workspace = true }
chrono = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
@@ -37,7 +41,9 @@ tokio = { workspace = true, features = [
tokio-util = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = ["barretenberg"] }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = [
"barretenberg",
] }
[[example]]
name = "interactive"
@@ -58,3 +64,7 @@ path = "attestation/present.rs"
[[example]]
name = "attestation_verify"
path = "attestation/verify.rs"
[[example]]
name = "interactive_predicate"
path = "interactive_predicate/interactive_predicate.rs"

View File

@@ -4,5 +4,8 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
* [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
* [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
* [Interactive_zk](./interactive_zk/README.md): Interactive Prover and Verifier session demonstrating zero-knowledge age verification using Noir.
* [Interactive_predicate](./interactive_predicate/README.md): Interactive session demonstrating predicate proving over transcript data (e.g., proving a JSON field is a valid string without revealing it).
Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.

View File

@@ -4,6 +4,7 @@
use std::env;
use anyhow::{anyhow, Result};
use clap::Parser;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode};
@@ -23,12 +24,17 @@ use tlsn::{
Attestation, AttestationConfig, CryptoProvider, Secrets,
},
config::{
CertificateDer, PrivateKeyDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore,
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
verifier::VerifierConfig,
},
connection::{ConnectionInfo, HandshakeData, ServerName, TranscriptLength},
prover::{state::Committed, ProveConfig, Prover, ProverConfig, ProverOutput, TlsConfig},
prover::{state::Committed, Prover, ProverOutput},
transcript::{ContentType, TranscriptCommitConfig},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tlsn_examples::ExampleType;
use tlsn_formats::http::{DefaultHttpCommitter, HttpCommit, HttpTranscript};
@@ -47,7 +53,7 @@ struct Args {
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let args = Args::parse();
@@ -87,64 +93,63 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
uri: &str,
extra_headers: Vec<(&str, &str)>,
example_type: &ExampleType,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<()> {
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
));
let tls_config = tls_config_builder.build().unwrap();
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
// We must configure the amount of data we expect to exchange beforehand, which will
// be preprocessed prior to the connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create a new prover and perform necessary setup.
let prover = Prover::new(prover_config).setup(socket.compat()).await?;
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
.build()?,
socket.compat(),
)
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect((server_host, server_port)).await?;
// Bind the prover to the server connection.
// The returned `mpc_tls_connection` is an MPC TLS connection to the server: all
// data written to/read from it will be encrypted/decrypted using MPC with
// the notary.
let (mpc_tls_connection, prover_fut) = prover.connect(client_socket.compat()).await?;
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
// (Optional) Set up TLS client authentication if required by the server.
.client_auth((
vec![CertificateDer(CLIENT_CERT_DER.to_vec())],
PrivateKeyDer(CLIENT_KEY_DER.to_vec()),
))
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the prover task to be run concurrently in the background.
let prover_task = tokio::spawn(prover_fut);
// Attach the hyper HTTP client to the connection.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the HTTP task to be run concurrently in the background.
tokio::spawn(connection);
@@ -165,7 +170,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}
let request = request_builder.body(Empty::<Bytes>::new())?;
info!("Starting an MPC TLS connection with the server");
info!("Starting connection with the server");
// Send the request to the server and wait for the response.
let response = request_sender.send_request(request).await?;
@@ -175,7 +180,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it.
let mut prover = prover_task.await??;
let prover = prover_task.await??;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?;
@@ -217,7 +222,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let request_config = builder.build()?;
let (attestation, secrets) = notarize(&mut prover, &request_config, req_tx, resp_rx).await?;
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
// Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
@@ -238,11 +243,11 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}
async fn notarize(
prover: &mut Prover<Committed>,
mut prover: Prover<Committed>,
config: &RequestConfig,
request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>,
) -> Result<(Attestation, Secrets), Box<dyn std::error::Error>> {
) -> Result<(Attestation, Secrets)> {
let mut builder = ProveConfig::builder(prover.transcript());
if let Some(config) = config.transcript_commit() {
@@ -257,25 +262,27 @@ async fn notarize(
..
} = prover.prove(&disclosure_config).await?;
let transcript = prover.transcript().clone();
let tls_transcript = prover.tls_transcript().clone();
prover.close().await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);
builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.handshake_data(HandshakeData {
certs: prover
.tls_transcript()
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: prover
.tls_transcript()
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: prover.tls_transcript().certificate_binding().clone(),
binding: tls_transcript.certificate_binding().clone(),
})
.transcript(prover.transcript().clone())
.transcript(transcript)
.transcript_commitments(transcript_secrets, transcript_commitments);
let (request, secrets) = builder.build(&CryptoProvider::default())?;
@@ -283,15 +290,18 @@ async fn notarize(
// Send attestation request to notary.
request_tx
.send(request.clone())
.map_err(|_| "notary is not receiving attestation request".to_string())?;
.map_err(|_| anyhow!("notary is not receiving attestation request"))?;
// Receive attestation from notary.
let attestation = attestation_rx
.await
.map_err(|err| format!("notary did not respond with attestation: {err}"))?;
.map_err(|err| anyhow!("notary did not respond with attestation: {err}"))?;
// Signature verifier for the signature algorithm in the request.
let provider = CryptoProvider::default();
// Check the attestation is consistent with the Prover's view.
request.validate(&attestation)?;
request.validate(&attestation, &provider)?;
Ok((attestation, secrets))
}
@@ -300,14 +310,7 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: S,
request_rx: Receiver<AttestationRequest>,
attestation_tx: Sender<Attestation>,
) -> Result<(), Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()
.unwrap();
) -> Result<()> {
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -315,20 +318,25 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
let mut verifier = Verifier::new(verifier_config)
.setup(socket.compat())
let verifier = Verifier::new(verifier_config)
.commit(socket.compat())
.await?
.accept()
.await?
.run()
.await?;
let VerifierOutput {
transcript_commitments,
..
} = verifier.verify(&VerifyConfig::default()).await?;
let (
VerifierOutput {
transcript_commitments,
encoder_secret,
..
},
verifier,
) = verifier.verify().await?.accept().await?;
let tls_transcript = verifier.tls_transcript().clone();
@@ -385,12 +393,16 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments);
if let Some(encoder_secret) = encoder_secret {
builder.encoder_secret(encoder_secret);
}
let attestation = builder.build(&provider)?;
// Send attestation to prover.
attestation_tx
.send(attestation)
.map_err(|_| "prover is not receiving attestation".to_string())?;
.map_err(|_| anyhow!("prover is not receiving attestation"))?;
Ok(())
}

View File

@@ -12,8 +12,8 @@ use tlsn::{
signing::VerifyingKey,
CryptoProvider,
},
config::{CertificateDer, RootCertStore},
verifier::ServerCertVerifier,
webpki::{CertificateDer, RootCertStore},
};
use tlsn_examples::ExampleType;
use tlsn_server_fixture_certs::CA_CERT_DER;

View File

@@ -3,6 +3,7 @@ use std::{
net::{IpAddr, SocketAddr},
};
use anyhow::Result;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
@@ -11,11 +12,18 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
use tlsn::{
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
verifier::VerifierConfig,
},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
prover::Prover,
transcript::PartialTranscript,
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -46,7 +54,7 @@ async fn main() {
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let prover = prover(prover_socket, &server_addr, &uri);
let verifier = verifier(verifier_socket);
let (_, transcript) = tokio::join!(prover, verifier);
let (_, transcript) = tokio::try_join!(prover, verifier).unwrap();
println!("Successfully verified {}", &uri);
println!(
@@ -64,61 +72,57 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
server_addr: &SocketAddr,
uri: &str,
) {
) -> Result<()> {
let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
.build()?,
verifier_socket.compat(),
)
.await?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into().unwrap()))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap(),
);
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
let prover_config = prover_config_builder.build().unwrap();
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await
.unwrap();
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await.unwrap();
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) =
prover.connect(tls_client_socket.compat()).await.unwrap();
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection)
.await
.unwrap();
hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
@@ -130,14 +134,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.header("Connection", "close")
.header("Secret", SECRET)
.method("GET")
.body(Empty::<Bytes>::new())
.unwrap();
let response = request_sender.send_request(request).await.unwrap();
.body(Empty::<Bytes>::new())?;
let response = request_sender.send_request(request).await?;
assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier.
let mut prover = prover_task.await.unwrap().unwrap();
let mut prover = prover_task.await??;
let mut builder = ProveConfig::builder(prover.transcript());
@@ -153,10 +156,8 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.expect("the secret should be in the sent data");
// Reveal everything except for the secret.
builder.reveal_sent(&(0..pos)).unwrap();
builder
.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))
.unwrap();
builder.reveal_sent(&(0..pos))?;
builder.reveal_sent(&(pos + SECRET.len()..prover.transcript().sent().len()))?;
// Find the substring "Dick".
let pos = prover
@@ -167,28 +168,21 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
.expect("the substring 'Dick' should be in the received data");
// Reveal everything except for the substring.
builder.reveal_recv(&(0..pos)).unwrap();
builder
.reveal_recv(&(pos + 4..prover.transcript().received().len()))
.unwrap();
builder.reveal_recv(&(0..pos))?;
builder.reveal_recv(&(pos + 4..prover.transcript().received().len()))?;
let config = builder.build().unwrap();
let config = builder.build()?;
prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
prover.prove(&config).await?;
prover.close().await?;
Ok(())
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> PartialTranscript {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();
) -> Result<PartialTranscript> {
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
@@ -196,20 +190,56 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap();
.build()?;
let verifier = Verifier::new(verifier_config);
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await
.unwrap();
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?;
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
{
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else {
Some("expecting to use MPC-TLS")
};
if reject.is_some() {
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify().await?;
if !verifier.request().server_identity() {
let verifier = verifier
.reject(Some("expecting to verify the server name"))
.await?;
verifier.close().await?;
return Err(anyhow::anyhow!("prover did not reveal the server name"));
}
let (
VerifierOutput {
server_name,
transcript,
..
},
verifier,
) = verifier.accept().await?;
verifier.close().await?;
let server_name = server_name.expect("prover should have revealed server name");
let transcript = transcript.expect("prover should have revealed transcript data");
@@ -232,7 +262,7 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let ServerName::Dns(server_name) = server_name;
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
transcript
Ok(transcript)
}
/// Render redacted bytes as `🙈`.

View File

@@ -0,0 +1,29 @@
## Interactive Predicate: Proving Predicates over Transcript Data
This example demonstrates how to use TLSNotary to prove predicates (boolean constraints) over transcript bytes in zero knowledge, without revealing the actual data.
In this example:
- The server returns JSON data containing a "name" field with a string value
- The Prover proves that the name value is a valid JSON string without revealing it
- The Verifier learns that the string is valid JSON, but not the actual content
This uses `mpz_predicate` to build predicates that operate on transcript bytes. The predicate is compiled to a circuit and executed in the ZK VM to prove satisfaction.
### Running the Example
First, start the test server from the root of this repository:
```shell
RUST_LOG=info PORT=4000 cargo run --bin tlsn-server-fixture
```
Next, run the interactive predicate example:
```shell
SERVER_PORT=4000 cargo run --release --example interactive_predicate
```
To view more detailed debug information:
```shell
RUST_LOG=debug,yamux=info,uid_mux=info SERVER_PORT=4000 cargo run --release --example interactive_predicate
```
> Note: In this example, the Prover and Verifier run on the same machine. In real-world scenarios, they would typically operate on separate machines.

View File

@@ -0,0 +1,368 @@
//! Example demonstrating predicate proving over transcript data.
//!
//! This example shows how a prover can prove a predicate (boolean constraint)
//! over transcript bytes in zero knowledge, without revealing the actual data.
//!
//! In this example:
//! - The server returns JSON data containing a "name" field with a string value
//! - The prover proves that the name value is a valid JSON string without
//! revealing it
//! - The verifier learns that the string is valid JSON, but not the actual
//! content
use std::{
env,
net::{IpAddr, SocketAddr},
};
use anyhow::Result;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use mpz_predicate::{json::validate_string, Pred};
use rangeset::prelude::RangeSet;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
use tlsn::{
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
verifier::VerifierConfig,
},
connection::ServerName,
prover::Prover,
transcript::Direction,
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
/// Predicate name for JSON string validation (both parties agree on this
/// out-of-band).
const JSON_STRING_PREDICATE: &str = "valid_json_string";
// Maximum number of bytes that can be sent from prover to server.
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server.
const MAX_RECV_DATA: usize = 1 << 14;
/// Builds a predicate that validates a JSON string at the given indices.
///
/// Uses mpz_predicate's `validate_string` to ensure the bytes form a valid
/// JSON string (proper escaping, valid UTF-8, no control characters, etc.).
fn build_json_string_predicate(indices: &RangeSet<usize>) -> Pred {
validate_string(indices.clone())
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// Use the JSON endpoint that returns data.
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/formats/json");
let server_ip: IpAddr = server_host.parse().expect("Invalid IP address");
let server_addr = SocketAddr::from((server_ip, server_port));
// Connect prover and verifier.
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let prover = prover(prover_socket, &server_addr, &uri);
let verifier = verifier(verifier_socket);
match tokio::try_join!(prover, verifier) {
Ok(_) => println!("\nSuccess! The prover proved that a JSON field contains a valid string without revealing it."),
Err(e) => eprintln!("Error: {e}"),
}
}
/// Finds the value of a JSON field in the response body.
/// Returns (start_index, end_index) of the value (excluding quotes for
/// strings).
fn find_json_string_value(data: &[u8], field_name: &str) -> Option<(usize, usize)> {
let search_pattern = format!("\"{}\":", field_name);
let pattern_bytes = search_pattern.as_bytes();
// Find the field name
let field_pos = data
.windows(pattern_bytes.len())
.position(|w| w == pattern_bytes)?;
// Skip past the field name and colon
let mut pos = field_pos + pattern_bytes.len();
// Skip whitespace
while pos < data.len() && (data[pos] == b' ' || data[pos] == b'\n' || data[pos] == b'\r') {
pos += 1;
}
// Check if it's a string (starts with quote)
if pos >= data.len() || data[pos] != b'"' {
return None;
}
// Skip opening quote
let start = pos + 1;
// Find closing quote (handling escapes)
let mut end = start;
while end < data.len() {
if data[end] == b'\\' {
// Skip escaped character
end += 2;
} else if data[end] == b'"' {
break;
} else {
end += 1;
}
}
Some((start, end))
}
#[instrument(skip(verifier_socket))]
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<()> {
let uri = uri.parse::<Uri>().unwrap();
assert_eq!(uri.scheme().unwrap().as_str(), "https");
let server_domain = uri.authority().unwrap().host();
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
.protocol(
MpcTlsConfig::builder()
.max_sent_data(tlsn_examples::MAX_SENT_DATA)
.max_recv_data(tlsn_examples::MAX_RECV_DATA)
.build()?,
)
.build()?,
verifier_socket.compat(),
)
.await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
// Send request for JSON data.
let request = Request::builder()
.uri(uri.clone())
.header("Host", server_domain)
.header("Connection", "close")
.method("GET")
.body(Empty::<Bytes>::new())?;
let response = request_sender.send_request(request).await?;
assert!(response.status() == StatusCode::OK);
// Create proof for the Verifier.
let mut prover = prover_task.await??;
// Find the "name" field value in the JSON response
let received = prover.transcript().received();
// Find the HTTP body (after \r\n\r\n)
let body_start = received
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|p| p + 4)
.unwrap_or(0);
// Find the "name" field's string value
let (value_start, value_end) =
find_json_string_value(&received[body_start..], "name").expect("should find name field");
// Adjust to absolute positions in transcript
let value_start = body_start + value_start;
let value_end = body_start + value_end;
let value_bytes = &received[value_start..value_end];
println!(
"Prover: Found 'name' field value: \"{}\" at positions {}..{}",
String::from_utf8_lossy(value_bytes),
value_start,
value_end
);
println!("Prover: Will prove this is a valid JSON string without revealing the actual content");
// Build indices for the predicate as a RangeSet
let indices: RangeSet<usize> = (value_start..value_end).into();
// Build the predicate using mpz_predicate
let predicate = build_json_string_predicate(&indices);
let mut builder = ProveConfig::builder(prover.transcript());
// Reveal the server identity.
builder.server_identity();
// Reveal the sent data (the request).
builder.reveal_sent(&(0..prover.transcript().sent().len()))?;
// Reveal everything EXCEPT the string value we're proving the predicate over.
if value_start > 0 {
builder.reveal_recv(&(0..value_start))?;
}
if value_end < prover.transcript().received().len() {
builder.reveal_recv(&(value_end..prover.transcript().received().len()))?;
}
// Add the predicate to prove the string is valid JSON without revealing the
// value.
builder.predicate(JSON_STRING_PREDICATE, Direction::Received, predicate)?;
let config = builder.build()?;
prover.prove(&config).await?;
prover.close().await?;
println!("Prover: Successfully proved the predicate!");
Ok(())
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> Result<()> {
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?;
let verifier = Verifier::new(verifier_config);
// Validate the proposed configuration and run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?;
// Validate configuration.
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
{
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else {
Some("expecting to use MPC-TLS")
};
if reject.is_some() {
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
// Run the TLS commitment protocol.
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request.
let verifier = verifier.verify().await?;
// Check that server identity is being proven.
if !verifier.request().server_identity() {
let verifier = verifier
.reject(Some("expecting to verify the server name"))
.await?;
verifier.close().await?;
return Err(anyhow::anyhow!("prover did not reveal the server name"));
}
// Check if predicates are requested and validate them.
let predicates = verifier.request().predicates();
if !predicates.is_empty() {
println!(
"Verifier: Prover requested {} predicate(s):",
predicates.len()
);
for pred in predicates {
println!(
" - '{}' on {:?} at {} indices",
pred.name(),
pred.direction(),
pred.indices().len()
);
}
}
// Define the predicate resolver - this maps predicate names to predicates.
// The resolver receives the predicate name and the indices from the prover's
// request.
let predicate_resolver = |name: &str, indices: &RangeSet<usize>| -> Option<Pred> {
match name {
JSON_STRING_PREDICATE => {
// Build the JSON string validation predicate with the provided indices
Some(build_json_string_predicate(indices))
}
_ => None,
}
};
// Accept with predicate verification.
let (
VerifierOutput {
server_name,
transcript,
..
},
verifier,
) = verifier
.accept_with_predicates(Some(&predicate_resolver))
.await?;
verifier.close().await?;
let server_name = server_name.expect("prover should have revealed server name");
let transcript = transcript.expect("prover should have revealed transcript data");
// Verify server name.
let ServerName::Dns(server_name) = server_name;
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
// The verifier can see the response but with the predicated string redacted.
let received = transcript.received_unsafe();
let redacted = String::from_utf8_lossy(received).replace('\0', "[REDACTED]");
println!("Verifier: Received data (string value redacted):\n{redacted}");
println!("Verifier: Predicate verified successfully!");
println!("Verifier: The hidden value is proven to be a valid JSON string");
Ok(())
}

View File

@@ -2,6 +2,7 @@ mod prover;
mod types;
mod verifier;
use anyhow::Result;
use prover::prover;
use std::{
env,
@@ -12,7 +13,7 @@ use tlsn_server_fixture_certs::SERVER_DOMAIN;
use verifier::verifier;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
@@ -25,7 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/elster");
let server_ip: IpAddr = server_host
.parse()
.map_err(|e| format!("Invalid IP address '{}': {}", server_host, e))?;
.map_err(|e| anyhow::anyhow!("Invalid IP address '{server_host}': {e}"))?;
let server_addr = SocketAddr::from((server_ip, server_port));
// Connect prover and verifier.

View File

@@ -4,6 +4,7 @@ use crate::types::received_commitments;
use super::types::ZKProofBundle;
use anyhow::Result;
use chrono::{Datelike, Local, NaiveDate};
use http_body_util::Empty;
use hyper::{body::Bytes, header, Request, StatusCode, Uri};
@@ -21,24 +22,27 @@ use spansy::{
http::{BodyContent, Requests, Responses},
Spanned,
};
use tls_server_fixture::CA_CERT_DER;
use tls_server_fixture::{CA_CERT_DER, SERVER_DOMAIN};
use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore},
config::{
prove::{ProveConfig, ProveConfigBuilder},
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig},
prover::Prover,
transcript::{
hash::{PlaintextHash, PlaintextHashSecret},
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
TranscriptSecret,
},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_examples::MAX_RECV_DATA;
use tokio::io::AsyncWriteExt;
use tlsn_examples::MAX_SENT_DATA;
use tokio::io::{AsyncRead, AsyncWrite};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
@@ -48,60 +52,64 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
mut verifier_extra_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<()> {
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err("URI must use HTTPS scheme".into());
return Err(anyhow::anyhow!("URI must use HTTPS scheme"));
}
let server_domain = uri.authority().ok_or("URI must have authority")?.host();
let server_domain = uri
.authority()
.ok_or_else(|| anyhow::anyhow!("URI must have authority"))?
.host();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into()?))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
// Create a new prover and perform necessary setup.
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
// Select the TLS commitment protocol.
.protocol(
MpcTlsConfig::builder()
// We must configure the amount of data we expect to exchange beforehand,
// which will be preprocessed prior to the
// connection. Reducing these limits will improve
// performance.
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?,
)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
verifier_socket.compat(),
)
.await?;
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Open a TCP connection to the server.
let client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?;
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Bind the prover to the server connection.
let (tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let tls_connection = TokioIo::new(tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
hyper::client::conn::http1::handshake(tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
@@ -118,7 +126,10 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let response = request_sender.send_request(request).await?;
if response.status() != StatusCode::OK {
return Err(format!("MPC-TLS request failed with status {}", response.status()).into());
return Err(anyhow::anyhow!(
"MPC-TLS request failed with status {}",
response.status()
));
}
// Create proof for the Verifier.
@@ -163,11 +174,11 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let received_commitments = received_commitments(&prover_output.transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("No received commitments found")?; // committed hash (of date of birth string)
.ok_or_else(|| anyhow::anyhow!("No received commitments found"))?; // committed hash (of date of birth string)
let received_secrets = received_secrets(&prover_output.transcript_secrets);
let received_secret = received_secrets
.first()
.ok_or("No received secrets found")?; // hash blinder
.ok_or_else(|| anyhow::anyhow!("No received secrets found"))?; // hash blinder
let proof_input = prepare_zk_proof_input(received, received_commitment, received_secret)?;
let proof_bundle = generate_zk_proof(&proof_input)?;
@@ -180,28 +191,30 @@ pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
}
// Reveal everything from the request, except for the authorization token.
fn reveal_request(
request: &[u8],
builder: &mut ProveConfigBuilder<'_>,
) -> Result<(), Box<dyn std::error::Error>> {
fn reveal_request(request: &[u8], builder: &mut ProveConfigBuilder<'_>) -> Result<()> {
let reqs = Requests::new_from_slice(request).collect::<Result<Vec<_>, _>>()?;
let req = reqs.first().ok_or("No requests found")?;
let req = reqs
.first()
.ok_or_else(|| anyhow::anyhow!("No requests found"))?;
if req.request.method.as_str() != "GET" {
return Err(format!("Expected GET method, found {}", req.request.method.as_str()).into());
return Err(anyhow::anyhow!(
"Expected GET method, found {}",
req.request.method.as_str()
));
}
let authorization_header = req
.headers_with_name(header::AUTHORIZATION.as_str())
.next()
.ok_or("Authorization header not found")?;
.ok_or_else(|| anyhow::anyhow!("Authorization header not found"))?;
let start_pos = authorization_header
.span()
.indices()
.min()
.ok_or("Could not find authorization header start position")?
.ok_or_else(|| anyhow::anyhow!("Could not find authorization header start position"))?
+ header::AUTHORIZATION.as_str().len()
+ 2;
let end_pos =
@@ -217,38 +230,43 @@ fn reveal_received(
received: &[u8],
builder: &mut ProveConfigBuilder<'_>,
transcript_commitment_builder: &mut TranscriptCommitConfigBuilder,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<()> {
let resp = Responses::new_from_slice(received).collect::<Result<Vec<_>, _>>()?;
let response = resp.first().ok_or("No responses found")?;
let body = response.body.as_ref().ok_or("Response body not found")?;
let response = resp
.first()
.ok_or_else(|| anyhow::anyhow!("No responses found"))?;
let body = response
.body
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Response body not found"))?;
let BodyContent::Json(json) = &body.content else {
return Err("Expected JSON body content".into());
return Err(anyhow::anyhow!("Expected JSON body content"));
};
// reveal tax year
let tax_year = json
.get("tax_year")
.ok_or("tax_year field not found in JSON")?;
.ok_or_else(|| anyhow::anyhow!("tax_year field not found in JSON"))?;
let start_pos = tax_year
.span()
.indices()
.min()
.ok_or("Could not find tax_year start position")?
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year start position"))?
- 11;
let end_pos = tax_year
.span()
.indices()
.max()
.ok_or("Could not find tax_year end position")?
.ok_or_else(|| anyhow::anyhow!("Could not find tax_year end position"))?
+ 1;
builder.reveal_recv(&(start_pos..end_pos))?;
// commit to hash of date of birth
let dob = json
.get("taxpayer.date_of_birth")
.ok_or("taxpayer.date_of_birth field not found in JSON")?;
.ok_or_else(|| anyhow::anyhow!("taxpayer.date_of_birth field not found in JSON"))?;
transcript_commitment_builder.commit_recv(dob.span())?;
@@ -279,7 +297,7 @@ fn prepare_zk_proof_input(
received: &[u8],
received_commitment: &PlaintextHash,
received_secret: &PlaintextHashSecret,
) -> Result<ZKProofInput, Box<dyn std::error::Error>> {
) -> Result<ZKProofInput> {
assert_eq!(received_commitment.direction, Direction::Received);
assert_eq!(received_commitment.hash.alg, HashAlgId::SHA256);
@@ -288,11 +306,11 @@ fn prepare_zk_proof_input(
let dob_start = received_commitment
.idx
.min()
.ok_or("No start index for DOB")?;
.ok_or_else(|| anyhow::anyhow!("No start index for DOB"))?;
let dob_end = received_commitment
.idx
.end()
.ok_or("No end index for DOB")?;
.ok_or_else(|| anyhow::anyhow!("No end index for DOB"))?;
let dob = received[dob_start..dob_end].to_vec();
let blinder = received_secret.blinder.as_bytes().to_vec();
let committed_hash = hash.value.as_bytes().to_vec();
@@ -306,8 +324,10 @@ fn prepare_zk_proof_input(
hasher.update(&blinder);
let computed_hash = hasher.finalize();
if committed_hash != computed_hash.as_slice() {
return Err("Computed hash does not match committed hash".into());
if committed_hash != computed_hash.as_ref() as &[u8] {
return Err(anyhow::anyhow!(
"Computed hash does not match committed hash"
));
}
Ok(ZKProofInput {
@@ -318,9 +338,7 @@ fn prepare_zk_proof_input(
})
}
fn generate_zk_proof(
proof_input: &ZKProofInput,
) -> Result<ZKProofBundle, Box<dyn std::error::Error>> {
fn generate_zk_proof(proof_input: &ZKProofInput) -> Result<ZKProofBundle> {
tracing::info!("🔒 Generating ZK proof with Noir...");
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
@@ -329,7 +347,7 @@ fn generate_zk_proof(
let json: Value = serde_json::from_str(PROGRAM_JSON)?;
let bytecode = json["bytecode"]
.as_str()
.ok_or("bytecode field not found in program.json")?;
.ok_or_else(|| anyhow::anyhow!("bytecode field not found in program.json"))?;
let mut inputs: Vec<String> = vec![];
inputs.push(proof_input.proof_date.day().to_string());
@@ -354,16 +372,17 @@ fn generate_zk_proof(
tracing::debug!("Witness inputs {:?}", inputs);
let input_refs: Vec<&str> = inputs.iter().map(String::as_str).collect();
let witness = from_vec_str_to_witness_map(input_refs)?;
let witness = from_vec_str_to_witness_map(input_refs).map_err(|e| anyhow::anyhow!(e))?;
// Setup SRS
setup_srs_from_bytecode(bytecode, None, false)?;
setup_srs_from_bytecode(bytecode, None, false).map_err(|e| anyhow::anyhow!(e))?;
// Verification key
let vk = get_ultra_honk_verification_key(bytecode, false)?;
let vk = get_ultra_honk_verification_key(bytecode, false).map_err(|e| anyhow::anyhow!(e))?;
// Generate proof
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)?;
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)
.map_err(|e| anyhow::anyhow!(e))?;
tracing::info!("✅ Proof generated ({} bytes)", proof.len());
let proof_bundle = ZKProofBundle { vk, proof };

View File

@@ -1,16 +1,18 @@
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use anyhow::Result;
use chrono::{Local, NaiveDate};
use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_honk};
use serde_json::Value;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
config::{tls_commit::TlsCommitProtocolConfig, verifier::VerifierConfig},
connection::ServerName,
hash::HashAlgId,
transcript::{Direction, PartialTranscript},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tlsn_server_fixture_certs::SERVER_DOMAIN;
@@ -22,56 +24,91 @@ use tracing::instrument;
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript, Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?;
) -> Result<PartialTranscript> {
let verifier = Verifier::new(
VerifierConfig::builder()
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
);
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()?;
// Validate the proposed configuration and then run the TLS commitment protocol.
let verifier = verifier.commit(socket.compat()).await?;
let verifier = Verifier::new(verifier_config);
// This is the opportunity to ensure the prover does not attempt to overload the
// verifier.
let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
{
if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
Some("max_sent_data is too large")
} else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
Some("max_recv_data is too large")
} else {
None
}
} else {
Some("expecting to use MPC-TLS")
};
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await?;
if reject.is_some() {
verifier.reject(reject).await?;
return Err(anyhow::anyhow!("protocol configuration rejected"));
}
let server_name = server_name.ok_or("Prover should have revealed server name")?;
let transcript = transcript.ok_or("Prover should have revealed transcript data")?;
// Runs the TLS commitment protocol to completion.
let verifier = verifier.accept().await?.run().await?;
// Validate the proving request and then verify.
let verifier = verifier.verify().await?;
let request = verifier.request();
if !request.server_identity() || request.reveal().is_none() {
let verifier = verifier
.reject(Some(
"expecting to verify the server name and transcript data",
))
.await?;
verifier.close().await?;
return Err(anyhow::anyhow!(
"prover did not reveal the server name and transcript data"
));
}
let (
VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
},
verifier,
) = verifier.accept().await?;
verifier.close().await?;
let server_name = server_name.expect("server name should be present");
let transcript = transcript.expect("transcript should be present");
// Create hash commitment for the date of birth field from the response
let sent = transcript.sent_unsafe().to_vec();
let sent_data = String::from_utf8(sent.clone())
.map_err(|e| format!("Verifier expected valid UTF-8 sent data: {}", e))?;
.map_err(|e| anyhow::anyhow!("Verifier expected valid UTF-8 sent data: {e}"))?;
if !sent_data.contains(SERVER_DOMAIN) {
return Err(format!(
"Verification failed: Expected host {} not found in sent data",
SERVER_DOMAIN
)
.into());
return Err(anyhow::anyhow!(
"Verification failed: Expected host {SERVER_DOMAIN} not found in sent data"
));
}
// Check received data.
let received_commitments = received_commitments(&transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("Missing received hash commitment")?;
.ok_or_else(|| anyhow::anyhow!("Missing hash commitment"))?;
assert!(received_commitment.direction == Direction::Received);
assert!(received_commitment.hash.alg == HashAlgId::SHA256);
@@ -81,12 +118,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
if server_name.as_str() != SERVER_DOMAIN {
return Err(format!(
"Server name mismatch: expected {}, got {}",
SERVER_DOMAIN,
return Err(anyhow::anyhow!(
"Server name mismatch: expected {SERVER_DOMAIN}, got {}",
server_name.as_str()
)
.into());
));
}
// Receive ZKProof information from prover
@@ -94,26 +129,28 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
extra_socket.read_to_end(&mut buf).await?;
if buf.is_empty() {
return Err("No ZK proof data received from prover".into());
return Err(anyhow::anyhow!("No ZK proof data received from prover"));
}
let msg: ZKProofBundle = bincode::deserialize(&buf)
.map_err(|e| format!("Failed to deserialize ZK proof bundle: {}", e))?;
.map_err(|e| anyhow::anyhow!("Failed to deserialize ZK proof bundle: {e}"))?;
// Verify zk proof
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
let json: Value = serde_json::from_str(PROGRAM_JSON)
.map_err(|e| format!("Failed to parse Noir circuit: {}", e))?;
.map_err(|e| anyhow::anyhow!("Failed to parse Noir circuit: {e}"))?;
let bytecode = json["bytecode"]
.as_str()
.ok_or("Bytecode field missing in noir.json")?;
.ok_or_else(|| anyhow::anyhow!("Bytecode field missing in noir.json"))?;
let vk = get_ultra_honk_verification_key(bytecode, false)
.map_err(|e| format!("Failed to get verification key: {}", e))?;
.map_err(|e| anyhow::anyhow!("Failed to get verification key: {e}"))?;
if vk != msg.vk {
return Err("Verification key mismatch between computed and provided by prover".into());
return Err(anyhow::anyhow!(
"Verification key mismatch between computed and provided by prover"
));
}
let proof = msg.proof.clone();
@@ -125,12 +162,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// * and 32*32 bytes for the hash
let min_bytes = (32 + 3) * 32;
if proof.len() < min_bytes {
return Err(format!(
"Proof too short: expected at least {} bytes, got {}",
min_bytes,
return Err(anyhow::anyhow!(
"Proof too short: expected at least {min_bytes} bytes, got {}",
proof.len()
)
.into());
));
}
// Check that the proof date is correctly included in the proof
@@ -139,14 +174,12 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
let proof_date_year: i32 = i32::from_be_bytes(proof[92..96].try_into()?);
let proof_date_from_proof =
NaiveDate::from_ymd_opt(proof_date_year, proof_date_month, proof_date_day)
.ok_or("Invalid proof date in proof")?;
.ok_or_else(|| anyhow::anyhow!("Invalid proof date in proof"))?;
let today = Local::now().date_naive();
if (today - proof_date_from_proof).num_days() < 0 {
return Err(format!(
"The proof date can only be today or in the past: provided {}, today {}",
proof_date_from_proof, today
)
.into());
return Err(anyhow::anyhow!(
"The proof date can only be today or in the past: provided {proof_date_from_proof}, today {today}"
));
}
// Check that the committed hash in the proof matches the hash from the
@@ -164,7 +197,9 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
hex::encode(&committed_hash_in_proof),
hex::encode(&expected_hash)
);
return Err("Hash in proof does not match committed hash in MPC-TLS".into());
return Err(anyhow::anyhow!(
"Hash in proof does not match committed hash in MPC-TLS"
));
}
tracing::info!(
"✅ The hash in the proof matches the committed hash in MPC-TLS ({})",
@@ -173,10 +208,10 @@ pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>
// Finally verify the proof
let is_valid = verify_ultra_honk(msg.proof, msg.vk)
.map_err(|e| format!("ZKProof Verification failed: {}", e))?;
.map_err(|e| anyhow::anyhow!("ZKProof Verification failed: {e}"))?;
if !is_valid {
tracing::error!("❌ Age verification ZKProof failed to verify");
return Err("Age verification ZKProof failed to verify".into());
return Err(anyhow::anyhow!("Age verification ZKProof failed to verify"));
}
tracing::info!("✅ Age verification ZKProof successfully verified");

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-formats"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]

View File

@@ -9,6 +9,7 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench {
group: None,
@@ -20,6 +21,7 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096,
defer_decryption: true,
memory_profile: false,
reveal_all: true,
};
#[derive(Deserialize)]
@@ -79,6 +81,8 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -97,6 +101,8 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
impl BenchItem {
@@ -132,6 +138,10 @@ impl BenchItem {
if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile;
}
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
}
pub fn into_bench(&self) -> Bench {
@@ -145,6 +155,7 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
}
}
}
@@ -164,6 +175,8 @@ pub struct Bench {
pub defer_decryption: bool,
#[serde(rename = "memory-profile")]
pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -5,9 +5,15 @@ use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
use harness_core::bench::{Bench, ProverMetrics};
use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore},
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
},
connection::ServerName,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
prover::Prover,
webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -22,41 +28,47 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let sent = verifier_io.sent();
let recv = verifier_io.recv();
let mut builder = ProtocolConfig::builder();
builder.max_sent_data(config.upload_size);
builder.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder.max_recv_data(config.download_size + RECV_PADDING);
let protocol_config = builder.build()?;
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
let prover = Prover::new(
ProverConfig::builder()
.tls_config(tls_config)
.protocol_config(protocol_config)
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.build()?,
);
let prover = Prover::new(ProverConfig::builder().build()?);
let time_start = web_time::Instant::now();
let prover = prover.setup(verifier_io).await?;
let prover = prover
.commit(
TlsCommitConfig::builder()
.protocol({
let mut builder = MpcTlsConfig::builder()
.max_sent_data(config.upload_size)
.defer_decryption_from_start(config.defer_decryption);
if !config.defer_decryption {
builder = builder.max_recv_data_online(config.download_size + RECV_PADDING);
}
builder
.max_recv_data(config.download_size + RECV_PADDING)
.build()
}?)
.build()?,
verifier_io,
)
.await?;
let time_preprocess = time_start.elapsed().as_millis();
let time_start_online = web_time::Instant::now();
let uploaded_preprocess = sent.load(Ordering::Relaxed);
let downloaded_preprocess = recv.load(Ordering::Relaxed);
let (mut conn, prover_fut) = prover.connect(provider.provide_server_io().await?).await?;
let (mut conn, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
provider.provide_server_io().await?,
)
.await?;
let (_, mut prover) = futures::try_join!(
async {
@@ -86,14 +98,27 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder
.server_identity()
.reveal_sent(&(0..sent_len))?
.reveal_recv(&(0..recv_len))?;
.reveal_sent(&reveal_sent_range)?
.reveal_recv(&reveal_recv_range)?;
let config = builder.build()?;
let prove_config = builder.build()?;
prover.prove(&config).await?;
prover.prove(&prove_config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -2,33 +2,31 @@ use anyhow::Result;
use harness_core::bench::Bench;
use tlsn::{
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
verifier::{Verifier, VerifierConfig, VerifyConfig},
config::verifier::VerifierConfig,
verifier::Verifier,
webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture_certs::CA_CERT_DER;
use crate::{IoProvider, bench::RECV_PADDING};
pub async fn bench_verifier(provider: &IoProvider, config: &Bench) -> Result<()> {
let mut builder = ProtocolConfigValidator::builder();
builder
.max_sent_data(config.upload_size)
.max_recv_data(config.download_size + RECV_PADDING);
let protocol_config = builder.build()?;
use crate::IoProvider;
pub async fn bench_verifier(provider: &IoProvider, _config: &Bench) -> Result<()> {
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(protocol_config)
.build()?,
);
let verifier = verifier.setup(provider.provide_proto_io().await?).await?;
let mut verifier = verifier.run().await?;
verifier.verify(&VerifyConfig::default()).await?;
let verifier = verifier
.commit(provider.provide_proto_io().await?)
.await?
.accept()
.await?
.run()
.await?;
let (_, verifier) = verifier.verify().await?.accept().await?;
verifier.close().await?;
Ok(())

View File

@@ -1,10 +1,17 @@
use tlsn::{
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
verifier::VerifierConfig,
},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
prover::Prover,
transcript::{TranscriptCommitConfig, TranscriptCommitment, TranscriptCommitmentKind},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -21,35 +28,35 @@ const MAX_RECV_DATA: usize = 1 << 11;
crate::test!("basic", prover, verifier);
async fn prover(provider: &IoProvider) {
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.defer_decryption_from_start(true)
.build()
.unwrap(),
)
.build()
.unwrap(),
)
.setup(provider.provide_proto_io().await.unwrap())
.await
.unwrap();
let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit(
TlsCommitConfig::builder()
.protocol(
MpcTlsConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.defer_decryption_from_start(true)
.build()
.unwrap(),
)
.build()
.unwrap(),
provider.provide_proto_io().await.unwrap(),
)
.await
.unwrap();
let (tls_connection, prover_fut) = prover
.connect(provider.provide_server_io().await.unwrap())
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap(),
provider.provide_server_io().await.unwrap(),
)
.await
.unwrap();
@@ -113,33 +120,34 @@ async fn prover(provider: &IoProvider) {
async fn verifier(provider: &IoProvider) {
let config = VerifierConfig::builder()
.protocol_config_validator(
ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap(),
)
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap();
let verifier = Verifier::new(config);
let VerifierOutput {
server_name,
transcript_commitments,
..
} = verifier
.verify(
provider.provide_proto_io().await.unwrap(),
&VerifyConfig::default(),
)
let verifier = Verifier::new(config)
.commit(provider.provide_proto_io().await.unwrap())
.await
.unwrap()
.accept()
.await
.unwrap()
.run()
.await
.unwrap();
let (
VerifierOutput {
server_name,
transcript_commitments,
..
},
verifier,
) = verifier.verify().await.unwrap().accept().await.unwrap();
verifier.close().await.unwrap();
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);

View File

@@ -32,18 +32,13 @@ use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
pub enum Target {
#[default]
Native,
Browser,
}
impl Default for Target {
fn default() -> Self {
Self::Native
}
}
struct Runner {
network: Network,
server_fixture: ServerFixture,

View File

@@ -5,7 +5,7 @@ description = "TLSNotary MPC-TLS protocol"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]
@@ -33,7 +33,6 @@ mpz-ole = { workspace = true }
mpz-share-conversion = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-circuits = { workspace = true }
ludi = { git = "https://github.com/sinui0/ludi", rev = "e511c3b", default-features = false }
serio = { workspace = true }
@@ -57,9 +56,9 @@ pin-project-lite = { workspace = true }
web-time = { workspace = true }
[dev-dependencies]
mpz-ole = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true }
mpz-garble = { workspace = true }
mpz-common = { workspace = true, features = ["test-utils"] }
mpz-ot = { workspace = true, features = ["ideal"] }
mpz-ideal-vm = { workspace = true }
cipher-crate = { package = "cipher", version = "0.4" }
generic-array = { workspace = true }

View File

@@ -487,7 +487,7 @@ impl RecordLayer {
sent_records.push(Record {
seq: op.seq,
typ: op.typ,
typ: op.typ.into(),
plaintext: op.plaintext,
explicit_nonce: op.explicit_nonce,
ciphertext,
@@ -505,7 +505,7 @@ impl RecordLayer {
recv_records.push(Record {
seq: op.seq,
typ: op.typ,
typ: op.typ.into(),
plaintext,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,
@@ -578,7 +578,7 @@ impl RecordLayer {
recv_records.push(Record {
seq: op.seq,
typ: op.typ,
typ: op.typ.into(),
plaintext,
explicit_nonce: op.explicit_nonce,
ciphertext: op.ciphertext,

View File

@@ -456,9 +456,8 @@ mod tests {
};
use mpz_common::context::test_st_context;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_memory_core::{binary::U8, correlated::Delta};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::binary::U8;
use mpz_share_conversion::ideal::ideal_share_convert;
use rand::{rngs::StdRng, SeedableRng};
use rstest::*;
@@ -574,13 +573,8 @@ mod tests {
}
fn create_vm(key: [u8; 16], iv: [u8; 4]) -> ((impl Vm<Binary>, Vars), (impl Vm<Binary>, Vars)) {
let mut rng = StdRng::seed_from_u64(0);
let block = Block::random(&mut rng);
let (sender, receiver) = ideal_cot(block);
let delta = Delta::new(block);
let mut vm_0 = Garbler::new(sender, [0u8; 16], delta);
let mut vm_1 = Evaluator::new(receiver);
let mut vm_0 = IdealVm::new();
let mut vm_1 = IdealVm::new();
let key_ref_0 = vm_0.alloc::<Array<U8, 16>>().unwrap();
vm_0.mark_public(key_ref_0).unwrap();

View File

@@ -4,14 +4,13 @@ use futures::{AsyncReadExt, AsyncWriteExt};
use mpc_tls::{Config, MpcTlsFollower, MpcTlsLeader};
use mpz_common::context::test_mt_context;
use mpz_core::Block;
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ideal_vm::IdealVm;
use mpz_memory_core::correlated::Delta;
use mpz_ot::{
cot::{DerandCOTReceiver, DerandCOTSender},
ideal::rcot::ideal_rcot,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand::{rngs::StdRng, SeedableRng};
use rustls_pki_types::CertificateDer;
use tls_client::RootCertStore;
use tls_client_async::bind_client;
@@ -23,7 +22,6 @@ use webpki::anchor_from_trusted_cert;
const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "expensive"]
async fn mpc_tls_test() {
tracing_subscriber::fmt::init();
@@ -139,14 +137,8 @@ fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) {
let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a);
let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b);
let mpc_a = Arc::new(Mutex::new(Garbler::new(
DerandCOTSender::new(rcot_send_a.clone()),
rand::rng().random(),
delta_a,
)));
let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new(
rcot_recv_b.clone(),
))));
let mpc_a = Arc::new(Mutex::new(IdealVm::new()));
let mpc_b = Arc::new(Mutex::new(IdealVm::new()));
let leader = MpcTlsLeader::new(
config.clone(),

View File

@@ -5,7 +5,7 @@ description = "A TLS backend trait for TLSNotary"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "A TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "sync"]
categories = ["cryptography"]
license = "Apache-2.0 OR ISC OR MIT"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
autobenches = false

View File

@@ -1,5 +1,6 @@
use super::{Backend, BackendError};
use crate::{DecryptMode, EncryptMode, Error};
#[allow(deprecated)]
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead, Payload},
Aes128Gcm,
@@ -507,6 +508,7 @@ impl Encrypter {
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap();
// ciphertext will have the MAC appended
@@ -568,6 +570,7 @@ impl Decrypter {
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(&m.payload.0[0..8]);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let plaintext = cipher
.decrypt(nonce, aes_payload)

View File

@@ -5,7 +5,7 @@ description = "Cryptographic operations for the TLSNotary TLS client"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "Apache-2.0 OR ISC OR MIT"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
[lints]

View File

@@ -4,7 +4,7 @@ authors = ["TLSNotary Team"]
keywords = ["tls", "mpc", "2pc", "prover"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2024"
[lints]
@@ -12,6 +12,7 @@ workspace = true
[features]
default = ["rayon"]
mozilla-certs = ["tlsn-core/mozilla-certs"]
rayon = ["mpz-zk/rayon", "mpz-garble/rayon"]
web = ["dep:web-spawn"]
@@ -29,9 +30,10 @@ serio = { workspace = true, features = ["compat"] }
uid-mux = { workspace = true, features = ["serio"] }
web-spawn = { workspace = true, optional = true }
mpz-circuits = { workspace = true, features = ["aes"] }
mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-predicate = { workspace = true }
mpz-garble = { workspace = true }
mpz-garble-core = { workspace = true }
mpz-hash = { workspace = true }
@@ -40,10 +42,10 @@ mpz-ole = { workspace = true }
mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
mpz-ideal-vm = { workspace = true }
aes = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }
rand = { workspace = true }

3
crates/tlsn/build.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
println!("cargo:rustc-check-cfg=cfg(tlsn_insecure)");
}

View File

@@ -1,368 +0,0 @@
//! TLSNotary protocol config and config utilities.
use core::fmt;
use once_cell::sync::Lazy;
use semver::Version;
use serde::{Deserialize, Serialize};
use std::error::Error;
pub use tlsn_core::webpki::{CertificateDer, PrivateKeyDer, RootCertStore};
// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Default maximum number of TLS records to allow.
//
// This would allow for up to 50Mb upload from prover to verifier.
const DEFAULT_RECORDS_LIMIT: usize = 256;
// Current version that is running.
static VERSION: Lazy<Version> = Lazy::new(|| {
Version::parse(env!("CARGO_PKG_VERSION"))
.map_err(|err| ProtocolConfigError::new(ErrorKind::Version, err))
.unwrap()
});
/// Protocol configuration to be set up initially by prover and verifier.
#[derive(derive_builder::Builder, Clone, Debug, Deserialize, Serialize)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct ProtocolConfig {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
/// Maximum number of application data records that can be sent.
#[builder(setter(strip_option), default)]
max_sent_records: Option<usize>,
/// Maximum number of bytes that can be decrypted online, i.e. while the
/// MPC-TLS connection is active.
#[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
max_recv_data_online: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of received application data records that can be
/// decrypted online, i.e. while the MPC-TLS connection is active.
#[builder(setter(strip_option), default)]
max_recv_records_online: Option<usize>,
/// Whether the `deferred decryption` feature is toggled on from the start
/// of the MPC-TLS connection.
#[builder(default = "true")]
defer_decryption_from_start: bool,
/// Network settings.
#[builder(default)]
network: NetworkSetting,
/// Version that is being run by prover/verifier.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
}
impl ProtocolConfigBuilder {
fn validate(&self) -> Result<(), String> {
if self.max_recv_data_online > self.max_recv_data {
return Err(
"max_recv_data_online must be smaller or equal to max_recv_data".to_string(),
);
}
Ok(())
}
}
impl ProtocolConfig {
/// Creates a new builder for `ProtocolConfig`.
pub fn builder() -> ProtocolConfigBuilder {
ProtocolConfigBuilder::default()
}
/// Returns the maximum number of bytes that can be sent.
pub fn max_sent_data(&self) -> usize {
self.max_sent_data
}
/// Returns the maximum number of application data records that can
/// be sent.
pub fn max_sent_records(&self) -> Option<usize> {
self.max_sent_records
}
/// Returns the maximum number of bytes that can be decrypted online.
pub fn max_recv_data_online(&self) -> usize {
self.max_recv_data_online
}
/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns the maximum number of received application data records that
/// can be decrypted online.
pub fn max_recv_records_online(&self) -> Option<usize> {
self.max_recv_records_online
}
/// Returns whether the `deferred decryption` feature is toggled on from the
/// start of the MPC-TLS connection.
pub fn defer_decryption_from_start(&self) -> bool {
self.defer_decryption_from_start
}
/// Returns the network settings.
pub fn network(&self) -> NetworkSetting {
self.network
}
}
/// Protocol configuration validator used by checker (i.e. verifier) to perform
/// compatibility check with the peer's (i.e. the prover's) configuration.
#[derive(derive_builder::Builder, Clone, Debug, Serialize, Deserialize)]
pub struct ProtocolConfigValidator {
/// Maximum number of bytes that can be sent.
max_sent_data: usize,
/// Maximum number of application data records that can be sent.
#[builder(default = "DEFAULT_RECORDS_LIMIT")]
max_sent_records: usize,
/// Maximum number of bytes that can be received.
max_recv_data: usize,
/// Maximum number of application data records that can be received online.
#[builder(default = "DEFAULT_RECORDS_LIMIT")]
max_recv_records_online: usize,
/// Version that is being run by checker.
#[builder(setter(skip), default = "VERSION.clone()")]
version: Version,
}
impl ProtocolConfigValidator {
/// Creates a new builder for `ProtocolConfigValidator`.
pub fn builder() -> ProtocolConfigValidatorBuilder {
ProtocolConfigValidatorBuilder::default()
}
/// Returns the maximum number of bytes that can be sent.
pub fn max_sent_data(&self) -> usize {
self.max_sent_data
}
/// Returns the maximum number of application data records that can
/// be sent.
pub fn max_sent_records(&self) -> usize {
self.max_sent_records
}
/// Returns the maximum number of bytes that can be received.
pub fn max_recv_data(&self) -> usize {
self.max_recv_data
}
/// Returns the maximum number of application data records that can
/// be received online.
pub fn max_recv_records_online(&self) -> usize {
self.max_recv_records_online
}
/// Performs compatibility check of the protocol configuration between
/// prover and verifier.
pub fn validate(&self, config: &ProtocolConfig) -> Result<(), ProtocolConfigError> {
self.check_max_transcript_size(config.max_sent_data, config.max_recv_data)?;
self.check_max_records(config.max_sent_records, config.max_recv_records_online)?;
self.check_version(&config.version)?;
Ok(())
}
// Checks if both the sent and recv data are within limits.
fn check_max_transcript_size(
&self,
max_sent_data: usize,
max_recv_data: usize,
) -> Result<(), ProtocolConfigError> {
if max_sent_data > self.max_sent_data {
return Err(ProtocolConfigError::max_transcript_size(format!(
"max_sent_data {:?} is greater than the configured limit {:?}",
max_sent_data, self.max_sent_data,
)));
}
if max_recv_data > self.max_recv_data {
return Err(ProtocolConfigError::max_transcript_size(format!(
"max_recv_data {:?} is greater than the configured limit {:?}",
max_recv_data, self.max_recv_data,
)));
}
Ok(())
}
fn check_max_records(
&self,
max_sent_records: Option<usize>,
max_recv_records_online: Option<usize>,
) -> Result<(), ProtocolConfigError> {
if let Some(max_sent_records) = max_sent_records
&& max_sent_records > self.max_sent_records
{
return Err(ProtocolConfigError::max_record_count(format!(
"max_sent_records {} is greater than the configured limit {}",
max_sent_records, self.max_sent_records,
)));
}
if let Some(max_recv_records_online) = max_recv_records_online
&& max_recv_records_online > self.max_recv_records_online
{
return Err(ProtocolConfigError::max_record_count(format!(
"max_recv_records_online {} is greater than the configured limit {}",
max_recv_records_online, self.max_recv_records_online,
)));
}
Ok(())
}
// Checks if both versions are the same (might support check for different but
// compatible versions in the future).
fn check_version(&self, peer_version: &Version) -> Result<(), ProtocolConfigError> {
if *peer_version != self.version {
return Err(ProtocolConfigError::version(format!(
"prover's version {:?} is different from verifier's version {:?}",
peer_version, self.version
)));
}
Ok(())
}
}
/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
Latency,
}
impl Default for NetworkSetting {
fn default() -> Self {
Self::Latency
}
}
/// A ProtocolConfig error.
#[derive(thiserror::Error, Debug)]
pub struct ProtocolConfigError {
kind: ErrorKind,
#[source]
source: Option<Box<dyn Error + Send + Sync>>,
}
impl ProtocolConfigError {
fn new<E>(kind: ErrorKind, source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync>>,
{
Self {
kind,
source: Some(source.into()),
}
}
fn max_transcript_size(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::MaxTranscriptSize,
source: Some(msg.into().into()),
}
}
fn max_record_count(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::MaxRecordCount,
source: Some(msg.into().into()),
}
}
fn version(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Version,
source: Some(msg.into().into()),
}
}
}
impl fmt::Display for ProtocolConfigError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
ErrorKind::MaxTranscriptSize => write!(f, "max transcript size exceeded")?,
ErrorKind::MaxRecordCount => write!(f, "max record count exceeded")?,
ErrorKind::Version => write!(f, "version error")?,
}
if let Some(ref source) = self.source {
write!(f, " caused by: {source}")?;
}
Ok(())
}
}
#[derive(Debug)]
enum ErrorKind {
MaxTranscriptSize,
MaxRecordCount,
Version,
}
#[cfg(test)]
mod test {
use super::*;
use rstest::{fixture, rstest};
const TEST_MAX_SENT_LIMIT: usize = 1 << 12;
const TEST_MAX_RECV_LIMIT: usize = 1 << 14;
#[fixture]
#[once]
fn config_validator() -> ProtocolConfigValidator {
ProtocolConfigValidator::builder()
.max_sent_data(TEST_MAX_SENT_LIMIT)
.max_recv_data(TEST_MAX_RECV_LIMIT)
.build()
.unwrap()
}
#[rstest]
#[case::same_max_sent_recv_data(TEST_MAX_SENT_LIMIT, TEST_MAX_RECV_LIMIT)]
#[case::smaller_max_sent_data(1 << 11, TEST_MAX_RECV_LIMIT)]
#[case::smaller_max_recv_data(TEST_MAX_SENT_LIMIT, 1 << 13)]
#[case::smaller_max_sent_recv_data(1 << 7, 1 << 9)]
fn test_check_success(
config_validator: &ProtocolConfigValidator,
#[case] max_sent_data: usize,
#[case] max_recv_data: usize,
) {
let peer_config = ProtocolConfig::builder()
.max_sent_data(max_sent_data)
.max_recv_data(max_recv_data)
.build()
.unwrap();
assert!(config_validator.validate(&peer_config).is_ok())
}
#[rstest]
#[case::bigger_max_sent_data(1 << 13, TEST_MAX_RECV_LIMIT)]
#[case::bigger_max_recv_data(1 << 10, 1 << 16)]
#[case::bigger_max_sent_recv_data(1 << 14, 1 << 21)]
fn test_check_fail(
config_validator: &ProtocolConfigValidator,
#[case] max_sent_data: usize,
#[case] max_recv_data: usize,
) {
let peer_config = ProtocolConfig::builder()
.max_sent_data(max_sent_data)
.max_recv_data(max_recv_data)
.build()
.unwrap();
assert!(config_validator.validate(&peer_config).is_err())
}
}

View File

@@ -4,10 +4,11 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub mod config;
pub(crate) mod context;
pub(crate) mod ghash;
pub(crate) mod map;
pub(crate) mod mpz;
pub(crate) mod msg;
pub(crate) mod mux;
pub mod prover;
pub(crate) mod tag;
@@ -15,7 +16,16 @@ pub(crate) mod transcript_internal;
pub mod verifier;
pub use tlsn_attestation as attestation;
pub use tlsn_core::{connection, hash, transcript};
pub use tlsn_core::{config, connection, hash, transcript, webpki};
use std::sync::LazyLock;
use semver::Version;
// Package version.
pub(crate) static VERSION: LazyLock<Version> = LazyLock::new(|| {
Version::parse(env!("CARGO_PKG_VERSION")).expect("cargo pkg version should be a valid semver")
});
/// The party's role in the TLSN protocol.
///

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet;
use rangeset::set::RangeSet;
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
@@ -77,7 +77,7 @@ where
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new();
for idx in idx.iter_ranges() {
for idx in idx.iter() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i,
Err(0) => return None,

233
crates/tlsn/src/mpz.rs Normal file
View File

@@ -0,0 +1,233 @@
use std::sync::Arc;
use mpc_tls::{MpcTlsFollower, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
#[cfg(not(tlsn_insecure))]
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_garble_core::Delta;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Key, Mac},
};
#[cfg(not(tlsn_insecure))]
use mpz_ot::cot::{DerandCOTReceiver, DerandCOTSender};
use mpz_ot::{
chou_orlandi as co, ferret, kos,
rcot::shared::{SharedRCOTReceiver, SharedRCOTSender},
};
use mpz_zk::{Prover, Verifier};
#[cfg(not(tlsn_insecure))]
use rand::Rng;
use tlsn_core::config::tls_commit::mpc::{MpcTlsConfig, NetworkSetting};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::transcript_internal::commit::encoding::{KeyStore, MacStore};
#[cfg(not(tlsn_insecure))]
pub(crate) type ProverMpc =
Garbler<DerandCOTSender<SharedRCOTSender<kos::Sender<co::Receiver>, Block>>>;
#[cfg(tlsn_insecure)]
pub(crate) type ProverMpc = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type ProverZk =
Prover<SharedRCOTReceiver<ferret::Receiver<kos::Receiver<co::Sender>>, bool, Block>>;
#[cfg(tlsn_insecure)]
pub(crate) type ProverZk = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type VerifierMpc =
Evaluator<DerandCOTReceiver<SharedRCOTReceiver<kos::Receiver<co::Sender>, bool, Block>>>;
#[cfg(tlsn_insecure)]
pub(crate) type VerifierMpc = mpz_ideal_vm::IdealVm;
#[cfg(not(tlsn_insecure))]
pub(crate) type VerifierZk =
Verifier<SharedRCOTSender<ferret::Sender<kos::Sender<co::Receiver>>, Block>>;
#[cfg(tlsn_insecure)]
pub(crate) type VerifierZk = mpz_ideal_vm::IdealVm;
pub(crate) struct ProverDeps {
pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
pub(crate) mpc_tls: MpcTlsLeader,
}
pub(crate) fn build_prover_deps(config: MpcTlsConfig, ctx: Context) -> ProverDeps {
let mut rng = rand::rng();
let delta = Delta::new(Block::random(&mut rng));
let base_ot_send = co::Sender::default();
let base_ot_recv = co::Receiver::default();
let rcot_send = kos::Sender::new(
kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
let rcot_recv = ferret::Receiver::new(
ferret::FerretConfig::builder()
.lpn_type(ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_recv,
);
let rcot_send = SharedRCOTSender::new(rcot_send);
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
#[cfg(not(tlsn_insecure))]
let mpc = ProverMpc::new(DerandCOTSender::new(rcot_send.clone()), rng.random(), delta);
#[cfg(tlsn_insecure)]
let mpc = mpz_ideal_vm::IdealVm::new();
#[cfg(not(tlsn_insecure))]
let zk = ProverZk::new(Default::default(), rcot_recv.clone());
#[cfg(tlsn_insecure)]
let zk = mpz_ideal_vm::IdealVm::new();
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
let mpc_tls = MpcTlsLeader::new(
build_mpc_tls_config(config),
ctx,
vm.clone(),
(rcot_send.clone(), rcot_send.clone(), rcot_send),
rcot_recv,
);
ProverDeps { vm, mpc_tls }
}
pub(crate) struct VerifierDeps {
pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
pub(crate) mpc_tls: MpcTlsFollower,
}
pub(crate) fn build_verifier_deps(config: MpcTlsConfig, ctx: Context) -> VerifierDeps {
let mut rng = rand::rng();
let delta = Delta::random(&mut rng);
let base_ot_send = co::Sender::default();
let base_ot_recv = co::Receiver::default();
let rcot_send = kos::Sender::new(
kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_send = ferret::Sender::new(
ferret::FerretConfig::builder()
.lpn_type(ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_send,
);
let rcot_recv = kos::Receiver::new(kos::ReceiverConfig::default(), base_ot_send);
let rcot_send = SharedRCOTSender::new(rcot_send);
let rcot_recv = SharedRCOTReceiver::new(rcot_recv);
#[cfg(not(tlsn_insecure))]
let mpc = VerifierMpc::new(DerandCOTReceiver::new(rcot_recv.clone()));
#[cfg(tlsn_insecure)]
let mpc = mpz_ideal_vm::IdealVm::new();
#[cfg(not(tlsn_insecure))]
let zk = VerifierZk::new(Default::default(), delta, rcot_send.clone());
#[cfg(tlsn_insecure)]
let zk = mpz_ideal_vm::IdealVm::new();
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));
let mpc_tls = MpcTlsFollower::new(
build_mpc_tls_config(config),
ctx,
vm.clone(),
rcot_send,
(rcot_recv.clone(), rcot_recv.clone(), rcot_recv),
);
VerifierDeps { vm, mpc_tls }
}
fn build_mpc_tls_config(config: MpcTlsConfig) -> mpc_tls::Config {
let mut builder = mpc_tls::Config::builder();
builder
.defer_decryption(config.defer_decryption_from_start())
.max_sent(config.max_sent_data())
.max_recv_online(config.max_recv_data_online())
.max_recv(config.max_recv_data());
if let Some(max_sent_records) = config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Latency = config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
pub(crate) fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) {
keys.client_write_key = vm
.translate(keys.client_write_key)
.expect("VM memory should be consistent");
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.expect("VM memory should be consistent");
keys.server_write_key = vm
.translate(keys.server_write_key)
.expect("VM memory should be consistent");
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.expect("VM memory should be consistent");
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.expect("VM memory should be consistent");
}
impl<T> KeyStore for Verifier<T> {
fn delta(&self) -> &Delta {
self.delta()
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
impl<T> MacStore for Prover<T> {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[cfg(tlsn_insecure)]
mod insecure {
use super::*;
use mpz_ideal_vm::IdealVm;
impl KeyStore for IdealVm {
fn delta(&self) -> &Delta {
unimplemented!("encodings not supported in insecure mode")
}
fn get_keys(&self, _data: Vector<U8>) -> Option<&[Key]> {
unimplemented!("encodings not supported in insecure mode")
}
}
impl MacStore for IdealVm {
fn get_macs(&self, _data: Vector<U8>) -> Option<&[Mac]> {
unimplemented!("encodings not supported in insecure mode")
}
}
}

51
crates/tlsn/src/msg.rs Normal file
View File

@@ -0,0 +1,51 @@
use semver::Version;
use serde::{Deserialize, Serialize};
use tlsn_core::{
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
connection::{HandshakeData, ServerName},
transcript::PartialTranscript,
};
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct TlsCommitRequestMsg {
pub(crate) request: TlsCommitRequest,
pub(crate) version: Version,
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ProveRequestMsg {
pub(crate) request: ProveRequest,
pub(crate) handshake: Option<(ServerName, HandshakeData)>,
pub(crate) transcript: Option<PartialTranscript>,
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Response {
pub(crate) result: Result<(), RejectionReason>,
}
impl Response {
pub(crate) fn ok() -> Self {
Self { result: Ok(()) }
}
pub(crate) fn err(msg: Option<impl Into<String>>) -> Self {
Self {
result: Err(RejectionReason(msg.map(Into::into))),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct RejectionReason(Option<String>);
impl From<RejectionReason> for crate::prover::ProverError {
fn from(value: RejectionReason) -> Self {
if let Some(msg) = value.0 {
crate::prover::ProverError::config(format!("verifier rejected with reason: {msg}"))
} else {
crate::prover::ProverError::config("verifier rejected without providing a reason")
}
}
}

View File

@@ -1,55 +1,45 @@
//! Prover.
mod config;
mod error;
mod future;
mod prove;
pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError;
pub use future::ProverFuture;
use rustls_pki_types::CertificateDer;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
pub use tlsn_core::ProverOutput;
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use webpki::anchor_from_trusted_cert;
use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
use crate::{
Role,
context::build_mt_context,
mpz::{ProverDeps, build_prover_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use serio::SinkExt;
use mpc_tls::LeaderCtrl;
use mpz_vm_core::prelude::*;
use rustls_pki_types::CertificateDer;
use serio::{SinkExt, stream::IoStreamExt};
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
connection::ServerName,
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, TlsCommitProtocolConfig},
},
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use webpki::anchor_from_trusted_cert;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
>;
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
mpz_ot::ferret::Receiver<mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>>,
bool,
mpz_core::Block,
>;
pub(crate) type Mpc =
mpz_garble::protocol::semihonest::Garbler<mpz_ot::cot::DerandCOTSender<RCOTSender>>;
pub(crate) type Zk = mpz_zk::Prover<RCOTReceiver>;
/// A prover instance.
#[derive(Debug)]
pub struct Prover<T: state::ProverState = state::Initialized> {
@@ -73,34 +63,53 @@ impl Prover<state::Initialized> {
}
}
/// Sets up the prover.
/// Starts the TLS commitment protocol.
///
/// This performs all MPC setup prior to establishing the connection to the
/// application server.
/// This initiates the TLS commitment protocol, including performing any
/// necessary preprocessing operations.
///
/// # Arguments
///
/// * `config` - The TLS commitment configuration.
/// * `socket` - The socket to the TLS verifier.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsCommitConfig,
socket: S,
) -> Result<Prover<state::Setup>, ProverError> {
) -> Result<Prover<state::CommitAccepted>, ProverError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Sends protocol configuration to verifier for compatibility check.
mux_fut
.poll_with(ctx.io_mut().send(self.config.protocol_config().clone()))
.poll_with(async {
ctx.io_mut()
.send(TlsCommitRequestMsg {
request: config.to_request(),
version: crate::VERSION.clone(),
})
.await?;
ctx.io_mut()
.expect_next::<Response>()
.await?
.result
.map_err(ProverError::from)
})
.await?;
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx);
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = config.protocol().clone() else {
unreachable!("only MPC TLS is supported");
};
let ProverDeps { vm, mut mpc_tls } = build_prover_deps(mpc_tls_config, ctx);
// Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
translate_keys(&mut keys, &vm_lock);
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -112,7 +121,7 @@ impl Prover<state::Initialized> {
Ok(Prover {
config: self.config,
span: self.span,
state: state::Setup {
state: state::CommitAccepted {
mux_ctrl,
mux_fut,
mpc_tls,
@@ -123,21 +132,24 @@ impl Prover<state::Initialized> {
}
}
impl Prover<state::Setup> {
impl Prover<state::CommitAccepted> {
/// Connects to the server using the provided socket.
///
/// Returns a handle to the TLS connection, a future which returns the
/// prover once the connection is closed.
/// prover once the connection is closed and the TLS transcript is
/// committed.
///
/// # Arguments
///
/// * `config` - The TLS client configuration.
/// * `socket` - The socket to the server.
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
pub async fn connect<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
config: TlsClientConfig,
socket: S,
) -> Result<(TlsConnection, ProverFuture), ProverError> {
let state::Setup {
let state::CommitAccepted {
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -148,12 +160,13 @@ impl Prover<state::Setup> {
let (mpc_ctrl, mpc_fut) = mpc_tls.run();
let ServerName::Dns(server_name) = self.config.server_name();
let ServerName::Dns(server_name) = config.server_name();
let server_name =
TlsServerName::try_from(server_name.as_ref()).expect("name was validated");
let root_store = if let Some(root_store) = self.config.tls_config().root_store() {
let roots = root_store
let root_store = tls_client::RootCertStore {
roots: config
.root_store()
.roots
.iter()
.map(|cert| {
@@ -162,20 +175,15 @@ impl Prover<state::Setup> {
.map(|anchor| anchor.to_owned())
.map_err(ProverError::config)
})
.collect::<Result<Vec<_>, _>>()?;
tls_client::RootCertStore { roots }
} else {
tls_client::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
}
.collect::<Result<Vec<_>, _>>()?,
};
let config = tls_client::ClientConfig::builder()
let rustls_config = tls_client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store);
let config = if let Some((cert, key)) = self.config.tls_config().client_auth() {
config
let rustls_config = if let Some((cert, key)) = config.client_auth() {
rustls_config
.with_single_cert(
cert.iter()
.map(|cert| tls_client::Certificate(cert.0.clone()))
@@ -184,12 +192,15 @@ impl Prover<state::Setup> {
)
.map_err(ProverError::config)?
} else {
config.with_no_client_auth()
rustls_config.with_no_client_auth()
};
let client =
ClientConnection::new(Arc::new(config), Box::new(mpc_ctrl.clone()), server_name)
.map_err(ProverError::config)?;
let client = ClientConnection::new(
Arc::new(rustls_config),
Box::new(mpc_ctrl.clone()),
server_name,
)
.map_err(ProverError::config)?;
let (conn, conn_fut) = bind_client(socket, client);
@@ -263,6 +274,7 @@ impl Prover<state::Setup> {
mux_fut,
ctx,
vm,
server_name: config.server_name().clone(),
keys,
tls_transcript,
transcript,
@@ -305,21 +317,47 @@ impl Prover<state::Committed> {
ctx,
vm,
keys,
server_name,
tls_transcript,
transcript,
..
} = &mut self.state;
let handshake = config.server_identity().then(|| {
(
server_name.clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
});
let partial_transcript = config
.reveal()
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone()));
let msg = ProveRequestMsg {
request: config.to_request(),
handshake,
transcript: partial_transcript,
};
let output = mux_fut
.poll_with(prove::prove(
ctx,
vm,
keys,
self.config.server_name(),
transcript,
tls_transcript,
config,
))
.poll_with(async {
ctx.io_mut().send(msg).await.map_err(ProverError::from)?;
ctx.io_mut().expect_next::<Response>().await?.result?;
prove::prove(ctx, vm, keys, transcript, tls_transcript, config).await
})
.await?;
Ok(output)
@@ -342,53 +380,6 @@ impl Prover<state::Committed> {
}
}
fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsLeader) {
let mut rng = rand::rng();
let delta = Delta::new(Block::random(&mut rng));
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(
mpz_ot::kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_recv =
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
let rcot_recv = mpz_ot::ferret::Receiver::new(
mpz_ot::ferret::FerretConfig::builder()
.lpn_type(mpz_ot::ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_recv,
);
let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
let mpc = Mpc::new(
mpz_ot::cot::DerandCOTSender::new(rcot_send.clone()),
rng.random(),
delta,
);
let zk = Zk::new(ZkProverConfig::default(), rcot_recv.clone());
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Leader, mpc, zk)));
(
vm.clone(),
MpcTlsLeader::new(
config.build_mpc_tls_config(),
ctx,
vm,
(rcot_send.clone(), rcot_send.clone(), rcot_send),
rcot_recv,
),
)
}
/// A controller for the prover.
#[derive(Clone)]
pub struct ProverControl {
@@ -414,24 +405,3 @@ impl ProverControl {
.map_err(ProverError::from)
}
}
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(keys: &mut SessionKeys, vm: &Deap<Mpc, Zk>) -> Result<(), ProverError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(ProverError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(ProverError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(ProverError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(ProverError::mpc)?;
Ok(())
}

View File

@@ -1,144 +0,0 @@
use mpc_tls::Config;
use serde::{Deserialize, Serialize};
use tlsn_core::{
connection::ServerName,
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use crate::config::{NetworkSetting, ProtocolConfig};
/// Configuration for the prover.
#[derive(Debug, Clone, derive_builder::Builder, Serialize, Deserialize)]
pub struct ProverConfig {
/// The server DNS name.
#[builder(setter(into))]
server_name: ServerName,
/// Protocol configuration to be checked with the verifier.
protocol_config: ProtocolConfig,
/// TLS configuration.
#[builder(default)]
tls_config: TlsConfig,
}
impl ProverConfig {
/// Creates a new builder for `ProverConfig`.
pub fn builder() -> ProverConfigBuilder {
ProverConfigBuilder::default()
}
/// Returns the server DNS name.
pub fn server_name(&self) -> &ServerName {
&self.server_name
}
/// Returns the protocol configuration.
pub fn protocol_config(&self) -> &ProtocolConfig {
&self.protocol_config
}
/// Returns the TLS configuration.
pub fn tls_config(&self) -> &TlsConfig {
&self.tls_config
}
pub(crate) fn build_mpc_tls_config(&self) -> Config {
let mut builder = Config::builder();
builder
.defer_decryption(self.protocol_config.defer_decryption_from_start())
.max_sent(self.protocol_config.max_sent_data())
.max_recv_online(self.protocol_config.max_recv_data_online())
.max_recv(self.protocol_config.max_recv_data());
if let Some(max_sent_records) = self.protocol_config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = self.protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Latency = self.protocol_config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
}
/// Configuration for the prover's TLS connection.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
/// Root certificates.
root_store: Option<RootCertStore>,
/// Certificate chain and a matching private key for client
/// authentication.
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfig {
/// Creates a new builder for `TlsConfig`.
pub fn builder() -> TlsConfigBuilder {
TlsConfigBuilder::default()
}
pub(crate) fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
/// Returns a certificate chain and a matching private key for client
/// authentication.
pub fn client_auth(&self) -> &Option<(Vec<CertificateDer>, PrivateKeyDer)> {
&self.client_auth
}
}
/// Builder for [`TlsConfig`].
#[derive(Debug, Default)]
pub struct TlsConfigBuilder {
root_store: Option<RootCertStore>,
client_auth: Option<(Vec<CertificateDer>, PrivateKeyDer)>,
}
impl TlsConfigBuilder {
/// Sets the root certificates to use for verifying the server's
/// certificate.
pub fn root_store(&mut self, store: RootCertStore) -> &mut Self {
self.root_store = Some(store);
self
}
/// Sets a DER-encoded certificate chain and a matching private key for
/// client authentication.
///
/// Often the chain will consist of a single end-entity certificate.
///
/// # Arguments
///
/// * `cert_key` - A tuple containing the certificate chain and the private
/// key.
///
/// - Each certificate in the chain must be in the X.509 format.
/// - The key must be in the ASN.1 format (either PKCS#8 or PKCS#1).
pub fn client_auth(&mut self, cert_key: (Vec<CertificateDer>, PrivateKeyDer)) -> &mut Self {
self.client_auth = Some(cert_key);
self
}
/// Builds the TLS configuration.
pub fn build(self) -> Result<TlsConfig, TlsConfigError> {
Ok(TlsConfig {
root_store: self.root_store,
client_auth: self.client_auth,
})
}
}
/// TLS configuration error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TlsConfigError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("tls config error")]
enum ErrorRepr {}

View File

@@ -4,7 +4,7 @@ use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Prover`](crate::Prover).
/// Error for [`Prover`](crate::prover::Prover).
#[derive(Debug, thiserror::Error)]
pub struct ProverError {
kind: ErrorKind,
@@ -49,6 +49,13 @@ impl ProverError {
{
Self::new(ErrorKind::Commit, source)
}
pub(crate) fn predicate<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Predicate, source)
}
}
#[derive(Debug)]
@@ -58,6 +65,7 @@ enum ErrorKind {
Zk,
Config,
Commit,
Predicate,
}
impl fmt::Display for ProverError {
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Config => f.write_str("config error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Predicate => f.write_str("predicate error")?,
}
if let Some(source) = &self.source {

View File

@@ -2,11 +2,10 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use serio::SinkExt;
use rangeset::set::RangeSet;
use tlsn_core::{
ProveConfig, ProveRequest, ProverOutput,
connection::{HandshakeData, ServerName},
ProverOutput,
config::prove::ProveConfig,
transcript::{
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
},
@@ -21,6 +20,7 @@ use crate::{
encoding::{self, MacStore},
hash::prove_hash,
},
predicate::prove_predicates,
},
};
@@ -28,7 +28,6 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
server_name: &ServerName,
transcript: &Transcript,
tls_transcript: &TlsTranscript,
config: &ProveConfig,
@@ -38,34 +37,6 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
transcript_secrets: Vec::default(),
};
let request = ProveRequest {
handshake: config.server_identity().then(|| {
(
server_name.clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: config
.reveal()
.map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())),
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
ctx.io_mut()
.send(request)
.await
.map_err(ProverError::from)?;
let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = config.transcript_commit() {
@@ -84,6 +55,20 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
});
}
// Build predicate ranges from config
let (mut predicate_sent, mut predicate_recv) = (RangeSet::default(), RangeSet::default());
for predicate in config.predicates() {
let indices: RangeSet<usize> = predicate
.indices()
.into_iter()
.map(|idx| idx..idx + 1)
.collect();
match predicate.direction() {
Direction::Sent => predicate_sent.union_mut(&indices),
Direction::Received => predicate_recv.union_mut(&indices),
}
}
let transcript_refs = TranscriptRefs {
sent: prove_plaintext(
vm,
@@ -96,6 +81,7 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
.filter(|record| record.typ == ContentType::ApplicationData),
&reveal_sent,
&commit_sent,
&predicate_sent,
)
.map_err(ProverError::commit)?,
recv: prove_plaintext(
@@ -109,6 +95,7 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
.filter(|record| record.typ == ContentType::ApplicationData),
&reveal_recv,
&commit_recv,
&predicate_recv,
)
.map_err(ProverError::commit)?,
};
@@ -130,6 +117,12 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
None
};
// Prove predicates over transcript data
if !config.predicates().is_empty() {
prove_predicates(vm, &transcript_refs, config.predicates())
.map_err(ProverError::predicate)?;
}
vm.execute_all(ctx).await.map_err(ProverError::zk)?;
if let Some(commit_config) = config.transcript_commit()

View File

@@ -4,13 +4,16 @@ use std::sync::Arc;
use mpc_tls::{MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use tlsn_core::transcript::{TlsTranscript, Transcript};
use tlsn_core::{
connection::ServerName,
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
mpz::{ProverMpc, ProverZk},
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
};
/// Entry state
@@ -18,23 +21,25 @@ pub struct Initialized;
opaque_debug::implement!(Initialized);
/// State after MPC setup has completed.
pub struct Setup {
/// State after the verifier has accepted the proposed TLS commitment protocol
/// configuration and preprocessing has completed.
pub struct CommitAccepted {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
pub(crate) vm: Arc<Mutex<Deap<ProverMpc, ProverZk>>>,
}
opaque_debug::implement!(Setup);
opaque_debug::implement!(CommitAccepted);
/// State after the TLS connection has been committed and closed.
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) vm: ProverZk,
pub(crate) server_name: ServerName,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
@@ -46,12 +51,12 @@ opaque_debug::implement!(Committed);
pub trait ProverState: sealed::Sealed {}
impl ProverState for Initialized {}
impl ProverState for Setup {}
impl ProverState for CommitAccepted {}
impl ProverState for Committed {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::Setup {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {}
}

View File

@@ -112,7 +112,7 @@ impl TagProof {
.map_err(TagProofError::vm)?
.ok_or_else(|| ErrorRepr::NotDecoded)?;
let aad = make_tls12_aad(rec.seq, rec.typ, vers, rec.ciphertext.len());
let aad = make_tls12_aad(rec.seq, rec.typ.into(), vers, rec.ciphertext.len());
let ghash_tag = ghash(aad.as_ref(), &rec.ciphertext, &mac_key);

View File

@@ -1,5 +1,6 @@
pub(crate) mod auth;
pub(crate) mod commit;
pub(crate) mod predicate;
use mpz_memory_core::{Vector, binary::U8};

View File

@@ -5,14 +5,14 @@ use ctr::{
Ctr32BE,
cipher::{KeyIvInit, StreamCipher, StreamCipherSeek},
};
use mpz_circuits::circuits::{AES128, xor};
use mpz_circuits::{AES128, circuits::xor};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
Array, DecodeFutureTyped, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{Difference, RangeSet, Union};
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
@@ -25,14 +25,15 @@ pub(crate) fn prove_plaintext<'a>(
records: impl IntoIterator<Item = &'a Record>,
reveal: &RangeSet<usize>,
commit: &RangeSet<usize>,
predicate: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
let is_reveal_all = reveal == (0..plaintext.len());
let is_reveal_all = reveal == (0..plaintext.len()) && predicate.is_empty();
let alloc_ranges = if is_reveal_all {
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
commit.union(reveal).union(predicate).into_set()
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +50,8 @@ pub(crate) fn prove_plaintext<'a>(
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
} else {
let private = commit.difference(reveal);
// Private ranges: committed but not revealed, plus predicate ranges
let private = commit.difference(reveal).union(predicate).into_set();
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -91,14 +93,15 @@ pub(crate) fn verify_plaintext<'a>(
records: impl IntoIterator<Item = &'a Record>,
reveal: &RangeSet<usize>,
commit: &RangeSet<usize>,
predicate: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof<'a>), PlaintextAuthError> {
let is_reveal_all = reveal == (0..plaintext.len());
let is_reveal_all = reveal == (0..plaintext.len()) && predicate.is_empty();
let alloc_ranges = if is_reveal_all {
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
commit.union(reveal).union(predicate).into_set()
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,9 +126,10 @@ pub(crate) fn verify_plaintext<'a>(
ciphertext,
})
} else {
let private = commit.difference(reveal);
// Blind ranges: committed but not revealed, plus predicate ranges
let blind = commit.difference(reveal).union(predicate).into_set();
for (_, slice) in plaintext_refs
.index(&private)
.index(&blind)
.expect("all ranges are allocated")
.iter()
{
@@ -175,15 +179,13 @@ fn alloc_plaintext(
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
}
fn alloc_ciphertext<'a>(
@@ -212,15 +214,13 @@ fn alloc_ciphertext<'a>(
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
}
fn alloc_keystream<'a>(
@@ -233,7 +233,7 @@ fn alloc_keystream<'a>(
let mut keystream = Vec::new();
let mut pos = 0;
let mut range_iter = ranges.iter_ranges();
let mut range_iter = ranges.iter();
let mut current_range = range_iter.next();
for record in records {
let mut explicit_nonce = None;
@@ -508,7 +508,7 @@ mod tests {
for record in records {
let mut record_keystream = vec![0u8; record.len];
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
for mut range in ranges.iter_ranges() {
for mut range in ranges.iter() {
range.start = range.start.max(pos);
range.end = range.end.min(pos + record.len);
if range.start < range.end {

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use rangeset::set::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
@@ -177,26 +177,10 @@ pub(crate) trait KeyStore {
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
}
impl KeyStore for crate::verifier::Zk {
fn delta(&self) -> &Delta {
crate::verifier::Zk::delta(self)
}
fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
self.get_keys(data).ok()
}
}
pub(crate) trait MacStore {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
}
impl MacStore for crate::prover::Zk {
fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
self.get_macs(data).ok()
}
}
#[derive(Debug)]
struct Provider {
sent: RangeMap<EncodingSlice>,

View File

@@ -3,13 +3,13 @@
use std::collections::HashMap;
use mpz_core::bitvec::BitVec;
use mpz_hash::{blake3::Blake3, sha256::Sha256};
use mpz_hash::{blake3::Blake3, keccak256::Keccak256, sha256::Sha256};
use mpz_memory_core::{
DecodeFutureTyped, MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet;
use rangeset::set::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
@@ -111,6 +111,7 @@ pub(crate) fn verify_hash(
enum Hasher {
Sha256(Sha256),
Blake3(Blake3),
Keccak256(Keccak256),
}
/// Commit plaintext hashes of the transcript.
@@ -154,7 +155,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
for range in idx.iter() {
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
}
@@ -175,7 +176,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter_ranges() {
for range in idx.iter() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
@@ -185,6 +186,32 @@ fn hash_commit_inner(
.map_err(HashCommitError::hasher)?;
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
HashAlgId::KECCAK256 => {
let mut hasher = if let Some(Hasher::Keccak256(hasher)) = hashers.get(&alg).cloned()
{
hasher
} else {
let hasher = Keccak256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, Hasher::Keccak256(hasher.clone()));
hasher
};
let refs = match direction {
Direction::Sent => &refs.sent,
Direction::Received => &refs.recv,
};
for range in idx.iter() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
}
hasher
.update(vm, &blinder)
.map_err(HashCommitError::hasher)?;
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
alg => {
return Err(HashCommitError::unsupported_alg(alg));
}

View File

@@ -0,0 +1,175 @@
//! Predicate proving and verification over transcript data.
use std::sync::Arc;
use mpz_circuits::Circuit;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
DecodeFutureTyped, MemoryExt,
binary::{Binary, Bool},
};
use mpz_predicate::{Pred, compiler::Compiler};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::set::RangeSet;
use tlsn_core::{config::prove::PredicateConfig, transcript::Direction};
use super::{ReferenceMap, TranscriptRefs};
/// Error during predicate proving/verification.
#[derive(Debug, thiserror::Error)]
pub(crate) enum PredicateError {
/// Indices not found in transcript references.
#[error("predicate indices {0:?} not found in transcript references")]
IndicesNotFound(RangeSet<usize>),
/// VM error.
#[error("VM error: {0}")]
Vm(#[from] mpz_vm_core::VmError),
/// Circuit call error.
#[error("circuit call error: {0}")]
Call(#[from] mpz_vm_core::CallError),
/// Decode error.
#[error("decode error: {0}")]
Decode(#[from] mpz_memory_core::DecodeError),
/// Missing decoding.
#[error("missing decoding")]
MissingDecoding,
/// Predicate not satisfied.
#[error("predicate evaluated to false")]
PredicateNotSatisfied,
}
/// Converts a slice of indices to a RangeSet (each index becomes a single-byte
/// range).
fn indices_to_rangeset(indices: &[usize]) -> RangeSet<usize> {
indices.iter().map(|&idx| idx..idx + 1).collect()
}
/// Proves predicates over transcript data (prover side).
///
/// Each predicate is compiled to a circuit and executed with the corresponding
/// transcript bytes as input. The circuit outputs a single bit that must be
/// true.
pub(crate) fn prove_predicates<T: Vm<Binary>>(
vm: &mut T,
transcript_refs: &TranscriptRefs,
predicates: &[PredicateConfig],
) -> Result<(), PredicateError> {
let mut compiler = Compiler::new();
for predicate in predicates {
let refs = match predicate.direction() {
Direction::Sent => &transcript_refs.sent,
Direction::Received => &transcript_refs.recv,
};
// Compile predicate to circuit
let circuit = compiler.compile(predicate.predicate());
// Get indices from the predicate and convert to RangeSet
let indices = indices_to_rangeset(&predicate.indices());
// Prover doesn't need to verify output - they know their data satisfies the predicate
let _ = execute_predicate(vm, refs, &indices, &circuit)?;
}
Ok(())
}
/// Proof that predicates were satisfied.
///
/// Must be verified after `vm.execute_all()` completes.
#[must_use]
pub(crate) struct PredicateProof {
/// Decode futures for each predicate output.
outputs: Vec<DecodeFutureTyped<BitVec, bool>>,
}
impl PredicateProof {
/// Verifies that all predicates evaluated to true.
///
/// Must be called after `vm.execute_all()` completes.
pub(crate) fn verify(self) -> Result<(), PredicateError> {
for mut output in self.outputs {
let result = output
.try_recv()
.map_err(PredicateError::Decode)?
.ok_or(PredicateError::MissingDecoding)?;
if !result {
return Err(PredicateError::PredicateNotSatisfied);
}
}
Ok(())
}
}
/// Verifies predicates over transcript data (verifier side).
///
/// The verifier must provide the same predicates that the prover used,
/// looked up by predicate name from out-of-band agreement.
///
/// Returns a [`PredicateProof`] that must be verified after `vm.execute_all()`.
///
/// # Arguments
///
/// * `vm` - The zkVM.
/// * `transcript_refs` - References to transcript data in the VM.
/// * `predicates` - Iterator of (direction, indices, predicate) tuples.
pub(crate) fn verify_predicates<T: Vm<Binary>>(
vm: &mut T,
transcript_refs: &TranscriptRefs,
predicates: impl IntoIterator<Item = (Direction, RangeSet<usize>, Pred)>,
) -> Result<PredicateProof, PredicateError> {
let mut compiler = Compiler::new();
let mut outputs = Vec::new();
for (direction, indices, predicate) in predicates {
let refs = match direction {
Direction::Sent => &transcript_refs.sent,
Direction::Received => &transcript_refs.recv,
};
// Compile predicate to circuit
let circuit = compiler.compile(&predicate);
let output_fut = execute_predicate(vm, refs, &indices, &circuit)?;
outputs.push(output_fut);
}
Ok(PredicateProof { outputs })
}
/// Executes a predicate circuit with transcript bytes as input.
///
/// Returns a decode future for the circuit output.
fn execute_predicate<T: Vm<Binary>>(
vm: &mut T,
refs: &ReferenceMap,
indices: &RangeSet<usize>,
circuit: &Circuit,
) -> Result<DecodeFutureTyped<BitVec, bool>, PredicateError> {
// Get the transcript bytes for the predicate indices
let indexed_refs = refs
.index(indices)
.ok_or_else(|| PredicateError::IndicesNotFound(indices.clone()))?;
// Build the circuit call with transcript bytes as inputs
let circuit = Arc::new(circuit.clone());
let mut call_builder = Call::builder(circuit);
// Add each byte in the range as an input to the circuit
// The predicate circuit expects bytes in order, so we iterate through
// the indexed refs which maintains ordering
for (_range, vector) in indexed_refs.iter() {
call_builder = call_builder.arg(*vector);
}
let call = call_builder.build()?;
// Execute the circuit - output is a single bit (true/false)
// Both parties must call decode() on the output to reveal it
let output: Bool = vm.call(call)?;
// Return decode future - caller must verify output == true after execute_all
Ok(vm.decode(output)?)
}

View File

@@ -1,52 +1,38 @@
//! Verifier.
pub(crate) mod config;
mod error;
pub mod state;
mod verify;
use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
pub use tlsn_core::{
VerifierOutput, VerifyConfig, VerifyConfigBuilder, VerifyConfigBuilderError,
webpki::ServerCertVerifier,
};
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
pub use verify::PredicateResolver;
use crate::{
Role, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, tag::verify_tags,
Role,
context::build_mt_context,
mpz::{VerifierDeps, build_verifier_deps, translate_keys},
msg::{ProveRequestMsg, Response, TlsCommitRequestMsg},
mux::attach_mux,
tag::verify_tags,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
config::{
prove::ProveRequest,
tls_commit::{TlsCommitProtocolConfig, TlsCommitRequest},
verifier::VerifierConfig,
},
connection::{ConnectionInfo, ServerName},
transcript::TlsTranscript,
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
>;
pub(crate) type RCOTReceiver = mpz_ot::rcot::shared::SharedRCOTReceiver<
mpz_ot::kos::Receiver<mpz_ot::chou_orlandi::Sender>,
bool,
mpz_core::Block,
>;
pub(crate) type Mpc =
mpz_garble::protocol::semihonest::Evaluator<mpz_ot::cot::DerandCOTReceiver<RCOTReceiver>>;
pub(crate) type Zk = mpz_zk::Verifier<RCOTSender>;
/// Information about the TLS session.
#[derive(Debug)]
pub struct SessionInfo {
@@ -74,40 +60,86 @@ impl Verifier<state::Initialized> {
}
}
/// Sets up the verifier.
/// Starts the TLS commitment protocol.
///
/// This performs all MPC setup.
/// This initiates the TLS commitment protocol, receiving the prover's
/// configuration and providing the opportunity to accept or reject it.
///
/// # Arguments
///
/// * `socket` - The socket to the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn setup<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
pub async fn commit<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
) -> Result<Verifier<state::Setup>, VerifierError> {
) -> Result<Verifier<state::CommitStart>, VerifierError> {
let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier);
let mut mt = build_mt_context(mux_ctrl.clone());
let mut ctx = mux_fut.poll_with(mt.new_context()).await?;
// Receives protocol configuration from prover to perform compatibility check.
let protocol_config = mux_fut
.poll_with(async {
let peer_configuration: ProtocolConfig = ctx.io_mut().expect_next().await?;
self.config
.protocol_config_validator()
.validate(&peer_configuration)?;
let TlsCommitRequestMsg { request, version } =
mux_fut.poll_with(ctx.io_mut().expect_next()).await?;
Ok::<_, VerifierError>(peer_configuration)
})
.await?;
if version != *crate::VERSION {
let msg = format!(
"prover version does not match with verifier: {version} != {}",
*crate::VERSION
);
mux_fut
.poll_with(ctx.io_mut().send(Response::err(Some(msg.clone()))))
.await?;
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, ctx);
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
return Err(VerifierError::config(msg));
}
Ok(Verifier {
config: self.config,
span: self.span,
state: state::CommitStart {
mux_ctrl,
mux_fut,
ctx,
request,
},
})
}
}
impl Verifier<state::CommitStart> {
/// Returns the TLS commitment request.
pub fn request(&self) -> &TlsCommitRequest {
&self.state.request
}
/// Accepts the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn accept(self) -> Result<Verifier<state::CommitAccepted>, VerifierError> {
let state::CommitStart {
mux_ctrl,
mut mux_fut,
mut ctx,
request,
} = self.state;
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = request.protocol().clone() else {
unreachable!("only MPC TLS is supported");
};
let VerifierDeps { vm, mut mpc_tls } = build_verifier_deps(mpc_tls_config, ctx);
// Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
translate_keys(&mut keys, &vm_lock);
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -119,7 +151,7 @@ impl Verifier<state::Initialized> {
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Setup {
state: state::CommitAccepted {
mux_ctrl,
mux_fut,
mpc_tls,
@@ -129,35 +161,35 @@ impl Verifier<state::Initialized> {
})
}
/// Runs the TLS verifier to completion, verifying the TLS session.
///
/// This is a convenience method which runs all the steps needed for
/// verification.
///
/// # Arguments
///
/// * `socket` - The socket to the prover.
/// Rejects the proposed protocol configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify<S: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
self,
socket: S,
config: &VerifyConfig,
) -> Result<VerifierOutput, VerifierError> {
let mut verifier = self.setup(socket).await?.run().await?;
pub async fn reject(self, msg: Option<&str>) -> Result<(), VerifierError> {
let state::CommitStart {
mux_ctrl,
mut mux_fut,
mut ctx,
..
} = self.state;
let output = verifier.verify(config).await?;
mux_fut
.poll_with(ctx.io_mut().send(Response::err(msg)))
.await?;
verifier.close().await?;
// Wait for the prover to correctly close the connection.
if !mux_fut.is_complete() {
mux_ctrl.close();
mux_fut.await?;
}
Ok(output)
Ok(())
}
}
impl Verifier<state::Setup> {
impl Verifier<state::CommitAccepted> {
/// Runs the verifier until the TLS connection is closed.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn run(self) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Setup {
let state::CommitAccepted {
mux_ctrl,
mut mux_fut,
mpc_tls,
@@ -231,47 +263,41 @@ impl Verifier<state::Committed> {
&self.state.tls_transcript
}
/// Verifies information from the prover.
///
/// # Arguments
///
/// * `config` - Verification configuration.
/// Begins verification of statements from the prover.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn verify(
&mut self,
#[allow(unused_variables)] config: &VerifyConfig,
) -> Result<VerifierOutput, VerifierError> {
pub async fn verify(self) -> Result<Verifier<state::Verify>, VerifierError> {
let state::Committed {
mux_fut,
ctx,
mux_ctrl,
mut mux_fut,
mut ctx,
vm,
keys,
tls_transcript,
..
} = &mut self.state;
} = self.state;
let cert_verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let request = mux_fut
let ProveRequestMsg {
request,
handshake,
transcript,
} = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let output = mux_fut
.poll_with(verify::verify(
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Verify {
mux_ctrl,
mux_fut,
ctx,
vm,
keys,
&cert_verifier,
tls_transcript,
request,
))
.await?;
Ok(output)
handshake,
transcript,
},
})
}
/// Closes the connection with the prover.
@@ -291,73 +317,110 @@ impl Verifier<state::Committed> {
}
}
fn build_mpc_tls(
config: &VerifierConfig,
protocol_config: &ProtocolConfig,
ctx: Context,
) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
let mut rng = rand::rng();
impl Verifier<state::Verify> {
/// Returns the proving request.
pub fn request(&self) -> &ProveRequest {
&self.state.request
}
let delta = Delta::random(&mut rng);
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(
mpz_ot::kos::SenderConfig::default(),
delta.into_inner(),
base_ot_recv,
);
let rcot_send = mpz_ot::ferret::Sender::new(
mpz_ot::ferret::FerretConfig::builder()
.lpn_type(mpz_ot::ferret::LpnType::Regular)
.build()
.expect("ferret config is valid"),
Block::random(&mut rng),
rcot_send,
);
let rcot_recv =
mpz_ot::kos::Receiver::new(mpz_ot::kos::ReceiverConfig::default(), base_ot_send);
/// Accepts the proving request.
///
/// Note: If the prover requests predicate verification, use
/// [`accept_with_predicates`](Self::accept_with_predicates) instead.
pub async fn accept(
self,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
self.accept_with_predicates(None).await
}
let rcot_send = mpz_ot::rcot::shared::SharedRCOTSender::new(rcot_send);
let rcot_recv = mpz_ot::rcot::shared::SharedRCOTReceiver::new(rcot_recv);
/// Accepts the proving request with predicate verification support.
///
/// # Arguments
///
/// * `predicate_resolver` - A function that resolves predicate names to
/// circuits. Required if the prover requests any predicates.
pub async fn accept_with_predicates(
self,
predicate_resolver: Option<&verify::PredicateResolver>,
) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
let state::Verify {
mux_ctrl,
mut mux_fut,
mut ctx,
mut vm,
keys,
tls_transcript,
request,
handshake,
transcript,
} = self.state;
let mpc = Mpc::new(mpz_ot::cot::DerandCOTReceiver::new(rcot_recv.clone()));
mux_fut.poll_with(ctx.io_mut().send(Response::ok())).await?;
let zk = Zk::new(ZkVerifierConfig::default(), delta, rcot_send.clone());
let cert_verifier =
ServerCertVerifier::new(self.config.root_store()).map_err(VerifierError::config)?;
let vm = Arc::new(Mutex::new(Deap::new(tlsn_deap::Role::Follower, mpc, zk)));
let output = mux_fut
.poll_with(verify::verify(
&mut ctx,
&mut vm,
&keys,
&cert_verifier,
&tls_transcript,
request,
handshake,
transcript,
predicate_resolver,
))
.await?;
(
vm.clone(),
MpcTlsFollower::new(
config.build_mpc_tls_config(protocol_config),
ctx,
Ok((
output,
Verifier {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
keys,
tls_transcript,
},
},
))
}
/// Rejects the proving request.
pub async fn reject(
self,
msg: Option<&str>,
) -> Result<Verifier<state::Committed>, VerifierError> {
let state::Verify {
mux_ctrl,
mut mux_fut,
mut ctx,
vm,
rcot_send,
(rcot_recv.clone(), rcot_recv.clone(), rcot_recv),
),
)
}
keys,
tls_transcript,
..
} = self.state;
/// Translates VM references to the ZK address space.
fn translate_keys<Mpc, Zk>(
keys: &mut SessionKeys,
vm: &Deap<Mpc, Zk>,
) -> Result<(), VerifierError> {
keys.client_write_key = vm
.translate(keys.client_write_key)
.map_err(VerifierError::mpc)?;
keys.client_write_iv = vm
.translate(keys.client_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_key = vm
.translate(keys.server_write_key)
.map_err(VerifierError::mpc)?;
keys.server_write_iv = vm
.translate(keys.server_write_iv)
.map_err(VerifierError::mpc)?;
keys.server_write_mac_key = vm
.translate(keys.server_write_mac_key)
.map_err(VerifierError::mpc)?;
mux_fut
.poll_with(ctx.io_mut().send(Response::err(msg)))
.await?;
Ok(())
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
ctx,
vm,
keys,
tls_transcript,
},
})
}
}

View File

@@ -1,65 +0,0 @@
use std::fmt::{Debug, Formatter, Result};
use mpc_tls::Config;
use serde::{Deserialize, Serialize};
use tlsn_core::webpki::RootCertStore;
use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
/// Configuration for the [`Verifier`](crate::tls::Verifier).
#[allow(missing_docs)]
#[derive(derive_builder::Builder, Serialize, Deserialize)]
#[builder(pattern = "owned")]
pub struct VerifierConfig {
protocol_config_validator: ProtocolConfigValidator,
#[builder(default, setter(strip_option))]
root_store: Option<RootCertStore>,
}
impl Debug for VerifierConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.debug_struct("VerifierConfig")
.field("protocol_config_validator", &self.protocol_config_validator)
.finish_non_exhaustive()
}
}
impl VerifierConfig {
/// Creates a new configuration builder.
pub fn builder() -> VerifierConfigBuilder {
VerifierConfigBuilder::default()
}
/// Returns the protocol configuration validator.
pub fn protocol_config_validator(&self) -> &ProtocolConfigValidator {
&self.protocol_config_validator
}
/// Returns the root certificate store.
pub fn root_store(&self) -> Option<&RootCertStore> {
self.root_store.as_ref()
}
pub(crate) fn build_mpc_tls_config(&self, protocol_config: &ProtocolConfig) -> Config {
let mut builder = Config::builder();
builder
.max_sent(protocol_config.max_sent_data())
.max_recv_online(protocol_config.max_recv_data_online())
.max_recv(protocol_config.max_recv_data());
if let Some(max_sent_records) = protocol_config.max_sent_records() {
builder.max_sent_records(max_sent_records);
}
if let Some(max_recv_records_online) = protocol_config.max_recv_records_online() {
builder.max_recv_records_online(max_recv_records_online);
}
if let NetworkSetting::Latency = protocol_config.network() {
builder.low_bandwidth();
}
builder.build().unwrap()
}
}

View File

@@ -4,7 +4,7 @@ use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
/// Error for [`Verifier`](crate::Verifier).
/// Error for [`Verifier`](crate::verifier::Verifier).
#[derive(Debug, thiserror::Error)]
pub struct VerifierError {
kind: ErrorKind,
@@ -49,6 +49,13 @@ impl VerifierError {
{
Self::new(ErrorKind::Verify, source)
}
pub(crate) fn predicate<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Predicate, source)
}
}
#[derive(Debug)]
@@ -59,6 +66,7 @@ enum ErrorKind {
Zk,
Commit,
Verify,
Predicate,
}
impl fmt::Display for VerifierError {
@@ -72,6 +80,7 @@ impl fmt::Display for VerifierError {
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Verify => f.write_str("verification error")?,
ErrorKind::Predicate => f.write_str("predicate error")?,
}
if let Some(source) = &self.source {
@@ -88,12 +97,6 @@ impl From<std::io::Error> for VerifierError {
}
}
impl From<crate::config::ProtocolConfigError> for VerifierError {
fn from(e: crate::config::ProtocolConfigError) -> Self {
Self::new(ErrorKind::Config, e)
}
}
impl From<uid_mux::yamux::ConnectionError> for VerifierError {
fn from(e: uid_mux::yamux::ConnectionError) -> Self {
Self::new(ErrorKind::Io, e)

View File

@@ -5,11 +5,15 @@ use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::transcript::TlsTranscript;
use tlsn_core::{
config::{prove::ProveRequest, tls_commit::TlsCommitRequest},
connection::{HandshakeData, ServerName},
transcript::{PartialTranscript, TlsTranscript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::verifier::{Mpc, Zk};
use crate::mpz::{VerifierMpc, VerifierZk};
/// TLS Verifier state.
pub trait VerifierState: sealed::Sealed {}
@@ -19,34 +23,66 @@ pub struct Initialized;
opaque_debug::implement!(Initialized);
/// State after setup has completed.
pub struct Setup {
/// State after receiving protocol configuration from the prover.
pub struct CommitStart {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) request: TlsCommitRequest,
}
opaque_debug::implement!(CommitStart);
/// State after accepting the proposed TLS commitment protocol configuration and
/// performing preprocessing.
pub struct CommitAccepted {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
pub(crate) vm: Arc<Mutex<Deap<VerifierMpc, VerifierZk>>>,
}
/// State after the TLS connection has been closed.
opaque_debug::implement!(CommitAccepted);
/// State after the TLS transcript has been committed.
pub struct Committed {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
}
opaque_debug::implement!(Committed);
/// State after receiving a proving request.
pub struct Verify {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) vm: VerifierZk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) request: ProveRequest,
pub(crate) handshake: Option<(ServerName, HandshakeData)>,
pub(crate) transcript: Option<PartialTranscript>,
}
opaque_debug::implement!(Verify);
impl VerifierState for Initialized {}
impl VerifierState for Setup {}
impl VerifierState for CommitStart {}
impl VerifierState for CommitAccepted {}
impl VerifierState for Committed {}
impl VerifierState for Verify {}
mod sealed {
pub trait Sealed {}
impl Sealed for super::Initialized {}
impl Sealed for super::Setup {}
impl Sealed for super::CommitStart {}
impl Sealed for super::CommitAccepted {}
impl Sealed for super::Committed {}
impl Sealed for super::Verify {}
}

View File

@@ -1,10 +1,13 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_predicate::Pred;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use rangeset::set::RangeSet;
use tlsn_core::{
ProveRequest, VerifierOutput,
VerifierOutput,
config::prove::ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
},
@@ -19,10 +22,25 @@ use crate::{
encoding::{self, KeyStore},
hash::verify_hash,
},
predicate::verify_predicates,
},
verifier::VerifierError,
};
/// A function that resolves predicate names to predicates.
///
/// The verifier must provide this to look up predicates by name,
/// based on out-of-band agreement with the prover.
///
/// The function receives:
/// - The predicate name
/// - The byte indices the predicate operates on (from the prover's request)
///
/// The verifier should validate that the indices make sense for the predicate
/// and return the appropriate predicate built with those indices.
pub type PredicateResolver = dyn Fn(&str, &RangeSet<usize>) -> Option<Pred> + Send + Sync;
#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
@@ -30,18 +48,20 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
cert_verifier: &ServerCertVerifier,
tls_transcript: &TlsTranscript,
request: ProveRequest,
handshake: Option<(ServerName, HandshakeData)>,
transcript: Option<PartialTranscript>,
predicate_resolver: Option<&PredicateResolver>,
) -> Result<VerifierOutput, VerifierError> {
let ProveRequest {
handshake,
transcript,
transcript_commit,
} = request;
let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
let has_reveal = transcript.is_some();
let transcript = if let Some(transcript) = transcript {
let transcript = if let Some((auth_sent, auth_recv)) = request.reveal() {
let Some(transcript) = transcript else {
return Err(VerifierError::verify(
"prover requested to reveal data but did not send transcript",
));
};
if transcript.len_sent() != ciphertext_sent.len()
|| transcript.len_received() != ciphertext_recv.len()
{
@@ -50,6 +70,18 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
));
}
if transcript.sent_authed() != auth_sent {
return Err(VerifierError::verify(
"prover sent transcript with incorrect sent authed data",
));
}
if transcript.received_authed() != auth_recv {
return Err(VerifierError::verify(
"prover sent transcript with incorrect received authed data",
));
}
transcript
} else {
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
@@ -71,7 +103,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
};
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = transcript_commit.as_ref() {
if let Some(commit_config) = request.transcript_commit() {
commit_config
.iter_hash()
.for_each(|(direction, idx, _)| match direction {
@@ -85,6 +117,15 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
}
}
// Build predicate ranges from request
let (mut predicate_sent, mut predicate_recv) = (RangeSet::default(), RangeSet::default());
for predicate_req in request.predicates() {
match predicate_req.direction() {
Direction::Sent => predicate_sent.union_mut(predicate_req.indices()),
Direction::Received => predicate_recv.union_mut(predicate_req.indices()),
}
}
let (sent_refs, sent_proof) = verify_plaintext(
vm,
keys.client_write_key,
@@ -97,6 +138,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
.filter(|record| record.typ == ContentType::ApplicationData),
transcript.sent_authed(),
&commit_sent,
&predicate_sent,
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = verify_plaintext(
@@ -111,6 +153,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
.filter(|record| record.typ == ContentType::ApplicationData),
transcript.received_authed(),
&commit_recv,
&predicate_recv,
)
.map_err(VerifierError::zk)?;
@@ -121,7 +164,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit.as_ref()
if let Some(commit_config) = request.transcript_commit()
&& commit_config.has_hash()
{
hash_commitments = Some(
@@ -130,13 +173,40 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
);
}
// Verify predicates if any were requested
let predicate_proof = if !request.predicates().is_empty() {
let resolver = predicate_resolver.ok_or_else(|| {
VerifierError::predicate("predicates requested but no resolver provided")
})?;
let predicates = request
.predicates()
.iter()
.map(|req| {
let predicate = resolver(req.name(), req.indices()).ok_or_else(|| {
VerifierError::predicate(format!("unknown predicate: {}", req.name()))
})?;
Ok((req.direction(), req.indices().clone(), predicate))
})
.collect::<Result<Vec<_>, VerifierError>>()?;
Some(verify_predicates(vm, &transcript_refs, predicates).map_err(VerifierError::predicate)?)
} else {
None
};
vm.execute_all(ctx).await.map_err(VerifierError::zk)?;
sent_proof.verify().map_err(VerifierError::verify)?;
recv_proof.verify().map_err(VerifierError::verify)?;
// Verify predicate outputs after ZK execution
if let Some(proof) = predicate_proof {
proof.verify().map_err(VerifierError::predicate)?;
}
let mut encoder_secret = None;
if let Some(commit_config) = transcript_commit
if let Some(commit_config) = request.transcript_commit()
&& let Some((sent, recv)) = commit_config.encoding()
{
let sent_map = transcript_refs
@@ -161,7 +231,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
Ok(VerifierOutput {
server_name,
transcript: has_reveal.then_some(transcript),
transcript: request.reveal().is_some().then_some(transcript),
encoder_secret,
transcript_commitments,
})

View File

@@ -1,15 +1,23 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet;
use mpz_predicate::{Pred, eq};
use rangeset::set::RangeSet;
use tlsn::{
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
config::{
prove::ProveConfig,
prover::ProverConfig,
tls::TlsClientConfig,
tls_commit::{TlsCommitConfig, mpc::MpcTlsConfig},
verifier::VerifierConfig,
},
connection::ServerName,
hash::{HashAlgId, HashProvider},
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
prover::Prover,
transcript::{
Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment,
TranscriptCommitmentKind, TranscriptSecret,
},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
verifier::{Verifier, VerifierOutput},
webpki::{CertificateDer, RootCertStore},
};
use tlsn_core::ProverOutput;
use tlsn_server_fixture::bind;
@@ -44,19 +52,11 @@ async fn test() {
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete());
assert_eq!(
partial_transcript
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
partial_transcript.sent_authed().iter().next().unwrap(),
0..10
);
assert_eq!(
partial_transcript
.received_authed()
.iter_ranges()
.next()
.unwrap(),
partial_transcript.received_authed().iter().next().unwrap(),
0..10
);
@@ -113,35 +113,38 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let server_task = tokio::spawn(bind(server_socket.compat()));
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let prover = Prover::new(ProverConfig::builder().build().unwrap())
.commit(
TlsCommitConfig::builder()
.protocol(
MpcTlsConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_sent_records(MAX_SENT_RECORDS)
.max_recv_data(MAX_RECV_DATA)
.max_recv_records_online(MAX_RECV_RECORDS)
.build()
.unwrap(),
)
.build()
.unwrap(),
verifier_socket.compat(),
)
.await
.unwrap();
let tls_config = tls_config_builder.build().unwrap();
let server_name = ServerName::Dns(SERVER_DOMAIN.try_into().unwrap());
let prover = Prover::new(
ProverConfig::builder()
.server_name(server_name)
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_sent_records(MAX_SENT_RECORDS)
.max_recv_data(MAX_RECV_DATA)
.max_recv_records_online(MAX_RECV_RECORDS)
.build()
.unwrap(),
)
.build()
.unwrap(),
)
.setup(verifier_socket.compat())
.await
.unwrap();
let (mut tls_connection, prover_fut) = prover.connect(client_socket.compat()).await.unwrap();
let (mut tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()
.unwrap(),
client_socket.compat(),
)
.await
.unwrap();
let prover_task = tokio::spawn(prover_fut);
tls_connection
@@ -204,32 +207,209 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> VerifierOutput {
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()
.unwrap();
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()
.unwrap(),
);
let mut verifier = verifier
.setup(socket.compat())
let verifier = verifier
.commit(socket.compat())
.await
.unwrap()
.accept()
.await
.unwrap()
.run()
.await
.unwrap();
let output = verifier.verify(&VerifyConfig::default()).await.unwrap();
let (output, verifier) = verifier.verify().await.unwrap().accept().await.unwrap();
verifier.close().await.unwrap();
output
}
// Predicate name for testing
const TEST_PREDICATE: &str = "test_first_byte";
/// Test that a correct predicate passes verification.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_predicate_passes() {
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
// Request is "GET / HTTP/1.1\r\n..." - index 10 is '/' (in "HTTP/1.1")
// Using index 10 to avoid overlap with revealed range (0..10)
// Verifier uses the same predicate - should pass
let prover_predicate = eq(10, b'/');
let (prover_result, verifier_result) = tokio::join!(
prover_with_predicate(socket_0, prover_predicate),
verifier_with_predicate(socket_1, || eq(10, b'/'))
);
prover_result.expect("prover should succeed");
verifier_result.expect("verifier should succeed with correct predicate");
}
/// Test that a wrong predicate is rejected by the verifier.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_wrong_predicate_rejected() {
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
// Request is "GET / HTTP/1.1\r\n..." - index 10 is '/'
// Verifier uses a DIFFERENT predicate that checks for 'X' - should fail
let prover_predicate = eq(10, b'/');
let (prover_result, verifier_result) = tokio::join!(
prover_with_predicate(socket_0, prover_predicate),
verifier_with_predicate(socket_1, || eq(10, b'X'))
);
// Prover may succeed or fail depending on when verifier rejects
let _ = prover_result;
// Verifier should fail because predicate evaluates to false
assert!(
verifier_result.is_err(),
"verifier should reject wrong predicate"
);
}
/// Test that prover can't prove a predicate their data doesn't satisfy.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_unsatisfied_predicate_rejected() {
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
// Request is "GET / HTTP/1.1\r\n..." - index 10 is '/'
// Both parties use eq(10, b'X') but prover's data has '/' at index 10
// This tests that a prover can't cheat - the predicate must actually be satisfied
let prover_predicate = eq(10, b'X');
let (prover_result, verifier_result) = tokio::join!(
prover_with_predicate(socket_0, prover_predicate),
verifier_with_predicate(socket_1, || eq(10, b'X'))
);
// Prover may succeed or fail depending on when verifier rejects
let _ = prover_result;
// Verifier should fail because prover's data doesn't satisfy the predicate
assert!(
verifier_result.is_err(),
"verifier should reject unsatisfied predicate"
);
}
#[instrument(skip(verifier_socket, predicate))]
async fn prover_with_predicate<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
predicate: Pred,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind(server_socket.compat()));
let prover = Prover::new(ProverConfig::builder().build()?)
.commit(
TlsCommitConfig::builder()
.protocol(
MpcTlsConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_sent_records(MAX_SENT_RECORDS)
.max_recv_data(MAX_RECV_DATA)
.max_recv_records_online(MAX_RECV_RECORDS)
.build()?,
)
.build()?,
verifier_socket.compat(),
)
.await?;
let (mut tls_connection, prover_fut) = prover
.connect(
TlsClientConfig::builder()
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
client_socket.compat(),
)
.await?;
let prover_task = tokio::spawn(prover_fut);
tls_connection
.write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
.await?;
tls_connection.close().await?;
let mut response = vec![0u8; 1024];
tls_connection.read_to_end(&mut response).await?;
let _ = server_task.await?;
let mut prover = prover_task.await??;
let mut builder = ProveConfig::builder(prover.transcript());
builder.server_identity();
builder.reveal_sent(&(0..10))?;
builder.reveal_recv(&(0..10))?;
builder.predicate(TEST_PREDICATE, Direction::Sent, predicate)?;
let config = builder.build()?;
prover.prove(&config).await?;
prover.close().await?;
Ok(())
}
async fn verifier_with_predicate<T, F>(
socket: T,
make_predicate: F,
) -> Result<VerifierOutput, Box<dyn std::error::Error + Send + Sync>>
where
T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static,
F: Fn() -> Pred + Send + Sync + 'static,
{
let verifier = Verifier::new(
VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.build()?,
);
let verifier = verifier
.commit(socket.compat())
.await?
.accept()
.await?
.run()
.await?;
let verifier = verifier.verify().await?;
// Resolver that builds the predicate fresh (Pred uses Rc, so can't be shared)
let predicate_resolver = move |name: &str, _indices: &RangeSet<usize>| -> Option<Pred> {
if name == TEST_PREDICATE {
Some(make_predicate())
} else {
None
}
};
let (output, verifier) = verifier
.accept_with_predicates(Some(&predicate_resolver))
.await?;
verifier.close().await?;
Ok(output)
}

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-wasm"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.14-pre"
edition = "2021"
repository = "https://github.com/tlsnotary/tlsn.git"
description = "A core WebAssembly package for TLSNotary."
@@ -21,7 +21,7 @@ no-bundler = ["web-spawn/no-bundler"]
[dependencies]
tlsn-core = { workspace = true }
tlsn = { workspace = true, features = ["web"] }
tlsn = { workspace = true, features = ["web", "mozilla-certs"] }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-client-async = { workspace = true }
tlsn-tls-core = { workspace = true }

View File

@@ -1,11 +1,6 @@
use crate::types::NetworkSetting;
use serde::Deserialize;
use tlsn::{
config::{CertificateDer, PrivateKeyDer, ProtocolConfig},
connection::ServerName,
};
use tsify_next::Tsify;
use wasm_bindgen::JsError;
#[derive(Debug, Tsify, Deserialize)]
#[tsify(from_wasm_abi)]
@@ -20,66 +15,3 @@ pub struct ProverConfig {
pub network: NetworkSetting,
pub client_auth: Option<(Vec<Vec<u8>>, Vec<u8>)>,
}
impl TryFrom<ProverConfig> for tlsn::prover::ProverConfig {
type Error = JsError;
fn try_from(value: ProverConfig) -> Result<Self, Self::Error> {
let mut builder = ProtocolConfig::builder();
builder.max_sent_data(value.max_sent_data);
builder.max_recv_data(value.max_recv_data);
if let Some(value) = value.max_recv_data_online {
builder.max_recv_data_online(value);
}
if let Some(value) = value.max_sent_records {
builder.max_sent_records(value);
}
if let Some(value) = value.max_recv_records_online {
builder.max_recv_records_online(value);
}
if let Some(value) = value.defer_decryption_from_start {
builder.defer_decryption_from_start(value);
}
builder.network(value.network.into());
let protocol_config = builder.build().unwrap();
let mut builder = tlsn::prover::TlsConfig::builder();
if let Some((certs, key)) = value.client_auth {
let certs = certs
.into_iter()
.map(|cert| {
// Try to parse as PEM-encoded, otherwise assume DER.
if let Ok(cert) = CertificateDer::from_pem_slice(&cert) {
cert
} else {
CertificateDer(cert)
}
})
.collect();
let key = PrivateKeyDer(key);
builder.client_auth((certs, key));
}
let tls_config = builder.build().unwrap();
let server_name = ServerName::Dns(
value
.server_name
.try_into()
.map_err(|_| JsError::new("invalid server name"))?,
);
let mut builder = tlsn::prover::ProverConfig::builder();
builder
.server_name(server_name)
.protocol_config(protocol_config)
.tls_config(tls_config);
Ok(builder.build().unwrap())
}
}

View File

@@ -7,7 +7,16 @@ use futures::TryFutureExt;
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use tls_client_async::TlsConnection;
use tlsn::prover::{state, ProveConfig, Prover};
use tlsn::{
config::{
prove::ProveConfig,
tls::TlsClientConfig,
tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig},
},
connection::ServerName,
prover::{state, Prover},
webpki::{CertificateDer, PrivateKeyDer, RootCertStore},
};
use tracing::info;
use wasm_bindgen::{prelude::*, JsError};
use wasm_bindgen_futures::spawn_local;
@@ -19,6 +28,7 @@ type Result<T> = std::result::Result<T, JsError>;
#[wasm_bindgen(js_name = Prover)]
pub struct JsProver {
config: ProverConfig,
state: State,
}
@@ -26,7 +36,7 @@ pub struct JsProver {
#[derive_err(Debug)]
enum State {
Initialized(Prover<state::Initialized>),
Setup(Prover<state::Setup>),
CommitAccepted(Prover<state::CommitAccepted>),
Committed(Prover<state::Committed>),
Complete,
Error,
@@ -43,7 +53,10 @@ impl JsProver {
#[wasm_bindgen(constructor)]
pub fn new(config: ProverConfig) -> Result<JsProver> {
Ok(JsProver {
state: State::Initialized(Prover::new(config.try_into()?)),
config,
state: State::Initialized(Prover::new(
tlsn::config::prover::ProverConfig::builder().build()?,
)),
})
}
@@ -54,15 +67,41 @@ impl JsProver {
pub async fn setup(&mut self, verifier_url: &str) -> Result<()> {
let prover = self.state.take().try_into_initialized()?;
let config = TlsCommitConfig::builder()
.protocol({
let mut builder = MpcTlsConfig::builder()
.max_sent_data(self.config.max_sent_data)
.max_recv_data(self.config.max_recv_data);
if let Some(value) = self.config.max_recv_data_online {
builder = builder.max_recv_data_online(value);
}
if let Some(value) = self.config.max_sent_records {
builder = builder.max_sent_records(value);
}
if let Some(value) = self.config.max_recv_records_online {
builder = builder.max_recv_records_online(value);
}
if let Some(value) = self.config.defer_decryption_from_start {
builder = builder.defer_decryption_from_start(value);
}
builder.network(self.config.network.into()).build()
}?)
.build()?;
info!("connecting to verifier");
let (_, verifier_conn) = WsMeta::connect(verifier_url, None).await?;
info!("connected to verifier");
let prover = prover.setup(verifier_conn.into_io()).await?;
let prover = prover.commit(config, verifier_conn.into_io()).await?;
self.state = State::Setup(prover);
self.state = State::CommitAccepted(prover);
Ok(())
}
@@ -73,7 +112,35 @@ impl JsProver {
ws_proxy_url: &str,
request: HttpRequest,
) -> Result<HttpResponse> {
let prover = self.state.take().try_into_setup()?;
let prover = self.state.take().try_into_commit_accepted()?;
let mut builder = TlsClientConfig::builder()
.server_name(ServerName::Dns(
self.config
.server_name
.clone()
.try_into()
.map_err(|_| JsError::new("invalid server name"))?,
))
.root_store(RootCertStore::mozilla());
if let Some((certs, key)) = self.config.client_auth.clone() {
let certs = certs
.into_iter()
.map(|cert| {
// Try to parse as PEM-encoded, otherwise assume DER.
if let Ok(cert) = CertificateDer::from_pem_slice(&cert) {
cert
} else {
CertificateDer(cert)
}
})
.collect();
let key = PrivateKeyDer(key);
builder = builder.client_auth((certs, key));
}
let config = builder.build()?;
info!("connecting to server");
@@ -81,7 +148,7 @@ impl JsProver {
info!("connected to server");
let (tls_conn, prover_fut) = prover.connect(server_conn.into_io()).await?;
let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?;
info!("sending request");
@@ -137,14 +204,6 @@ impl JsProver {
}
}
impl From<Prover<state::Initialized>> for JsProver {
fn from(value: Prover<state::Initialized>) -> Self {
JsProver {
state: State::Initialized(value),
}
}
}
async fn send_request(conn: TlsConnection, request: HttpRequest) -> Result<HttpResponse> {
let conn = FuturesIo::new(conn);
let request = hyper::Request::<Full<Bytes>>::try_from(request)?;

View File

@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
Self {
sent: value.sent_unsafe().to_vec(),
sent_authed: value.sent_authed().iter_ranges().collect(),
sent_authed: value.sent_authed().iter().collect(),
recv: value.received_unsafe().to_vec(),
recv_authed: value.received_authed().iter_ranges().collect(),
recv_authed: value.received_authed().iter().collect(),
}
}
}
@@ -181,7 +181,7 @@ pub struct VerifierOutput {
pub transcript: Option<PartialTranscript>,
}
#[derive(Debug, Tsify, Deserialize)]
#[derive(Debug, Clone, Copy, Tsify, Deserialize)]
#[tsify(from_wasm_abi)]
pub enum NetworkSetting {
/// Prefers a bandwidth-heavy protocol.
@@ -190,7 +190,7 @@ pub enum NetworkSetting {
Latency,
}
impl From<NetworkSetting> for tlsn::config::NetworkSetting {
impl From<NetworkSetting> for tlsn::config::tls_commit::mpc::NetworkSetting {
fn from(value: NetworkSetting) -> Self {
match value {
NetworkSetting::Bandwidth => Self::Bandwidth,

Some files were not shown because too many files have changed in this diff Show More