Compare commits

..

10 Commits

Author SHA1 Message Date
sinu
a56f2dd02b rename ProvePayload to ProveRequest 2025-09-11 12:51:23 -07:00
sinu
38a1ec3f72 import order 2025-09-11 09:43:38 -07:00
sinu
1501bc661f consolidate encoding stuff 2025-09-11 09:42:03 -07:00
sinu
0e2c2cb045 revert additional derives 2025-09-11 09:37:44 -07:00
sinu
a662fb7511 fix import order 2025-09-11 09:32:50 -07:00
th4s
be2e1ab95a adapt tests to new base 2025-09-11 18:15:14 +02:00
th4s
d34d135bfe adapt to new base 2025-09-11 10:55:26 +02:00
th4s
091d26bb63 fmt nightly 2025-09-11 10:13:08 +02:00
th4s
f031fe9a8d allow encoding protocol being executed only once 2025-09-11 10:13:08 +02:00
th4s
12c9a5eb34 refactor(tlsn): improve proving flow prover and verifier
- add `TlsTranscript` fixture for testing
- refactor and simplify the commit flow
- change how the TLS transcript is committed
  - tag verification is still done for the whole transcript
  - plaintext authentication is only done where needed
  - encoding adjustments are only transmitted for needed ranges
  - makes `prove` and `verify` functions more efficient and callable more than once
- decoding now supports server-write-key decoding without authentication
- add more tests
2025-09-11 10:13:06 +02:00
93 changed files with 5017 additions and 5457 deletions

View File

@@ -18,10 +18,10 @@ env:
# We need a higher number of parallel rayon tasks than the default (which is 4)
# in order to prevent a deadlock, c.f.
# - https://github.com/tlsnotary/tlsn/issues/548
# - https://github.com/privacy-ethereum/mpz/issues/178
# - https://github.com/privacy-scaling-explorations/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.90.0
RUST_VERSION: 1.89.0
jobs:
clippy:
@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@master
uses: dtolnay/rust-toolchain@stable
with:
toolchain: ${{ env.RUST_VERSION }}
components: clippy
@@ -41,7 +41,7 @@ jobs:
uses: Swatinem/rust-cache@v2.7.7
- name: Clippy
run: cargo clippy --keep-going --all-features --all-targets --locked
run: cargo clippy --keep-going --all-features --all-targets --locked -- -D warnings
fmt:
name: Check formatting

View File

@@ -6,7 +6,7 @@ on:
tag:
description: 'Tag to publish to NPM'
required: true
default: 'v0.1.0-alpha.13'
default: 'v0.1.0-alpha.13-pre'
jobs:
release:

View File

@@ -23,6 +23,7 @@ jobs:
- name: "rustdoc"
run: crates/wasm/build-docs.sh
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if: ${{ github.ref == 'refs/heads/dev' }}

2499
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,20 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", tag = "v0.1.0-alpha.4" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
@@ -101,7 +100,7 @@ bytes = { version = "1.4" }
cfg-if = { version = "1" }
chromiumoxide = { version = "0.7" }
chrono = { version = "0.4" }
cipher = { version = "0.4" }
cipher-crypto = { package = "cipher", version = "0.4" }
clap = { version = "4.5" }
criterion = { version = "0.5" }
ctr = { version = "0.9" }
@@ -111,7 +110,7 @@ elliptic-curve = { version = "0.13" }
enum-try-as-inner = { version = "0.1" }
env_logger = { version = "0.10" }
futures = { version = "0.3" }
futures-rustls = { version = "0.25" }
futures-rustls = { version = "0.26" }
generic-array = { version = "0.14" }
ghash = { version = "0.5" }
hex = { version = "0.4" }
@@ -124,6 +123,7 @@ inventory = { version = "0.3" }
itybity = { version = "0.2" }
js-sys = { version = "0.3" }
k256 = { version = "0.13" }
lipsum = { version = "0.9" }
log = { version = "0.4" }
once_cell = { version = "1.19" }
opaque-debug = { version = "0.3" }

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-attestation"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2024"
[features]
@@ -23,9 +23,9 @@ thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
[dev-dependencies]
alloy-primitives = { version = "1.3.1", default-features = false }
alloy-signer = { version = "1.0", default-features = false }
alloy-signer-local = { version = "1.0", default-features = false }
alloy-primitives = { version = "0.8.22", default-features = false }
alloy-signer = { version = "0.12", default-features = false }
alloy-signer-local = { version = "0.12", default-features = false }
rand06-compat = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }

View File

@@ -5,7 +5,7 @@ use rand::{Rng, rng};
use tlsn_core::{
connection::{ConnectionInfo, ServerEphemKey},
hash::HashAlgId,
transcript::{TranscriptCommitment, encoding::EncoderSecret},
transcript::TranscriptCommitment,
};
use crate::{
@@ -25,7 +25,6 @@ pub struct Sign {
connection_info: Option<ConnectionInfo>,
server_ephemeral_key: Option<ServerEphemKey>,
cert_commitment: ServerCertCommitment,
encoder_secret: Option<EncoderSecret>,
extensions: Vec<Extension>,
transcript_commitments: Vec<TranscriptCommitment>,
}
@@ -87,7 +86,6 @@ impl<'a> AttestationBuilder<'a, Accept> {
connection_info: None,
server_ephemeral_key: None,
cert_commitment,
encoder_secret: None,
transcript_commitments: Vec::new(),
extensions,
},
@@ -108,12 +106,6 @@ impl AttestationBuilder<'_, Sign> {
self
}
/// Sets the secret for encoding commitments.
pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
self.state.encoder_secret = Some(secret);
self
}
/// Adds an extension to the attestation.
pub fn extension(&mut self, extension: Extension) -> &mut Self {
self.state.extensions.push(extension);
@@ -137,7 +129,6 @@ impl AttestationBuilder<'_, Sign> {
connection_info,
server_ephemeral_key,
cert_commitment,
encoder_secret,
extensions,
transcript_commitments,
} = self.state;
@@ -168,7 +159,6 @@ impl AttestationBuilder<'_, Sign> {
AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
})?),
cert_commitment: field_id.next(cert_commitment),
encoder_secret: encoder_secret.map(|secret| field_id.next(secret)),
extensions: extensions
.into_iter()
.map(|extension| field_id.next(extension))

View File

@@ -219,7 +219,7 @@ use tlsn_core::{
connection::{ConnectionInfo, ServerEphemKey},
hash::{Hash, HashAlgorithm, TypedHash},
merkle::MerkleTree,
transcript::{TranscriptCommitment, encoding::EncoderSecret},
transcript::TranscriptCommitment,
};
use crate::{
@@ -327,7 +327,6 @@ pub struct Body {
connection_info: Field<ConnectionInfo>,
server_ephemeral_key: Field<ServerEphemKey>,
cert_commitment: Field<ServerCertCommitment>,
encoder_secret: Option<Field<EncoderSecret>>,
extensions: Vec<Field<Extension>>,
transcript_commitments: Vec<Field<TranscriptCommitment>>,
}
@@ -373,7 +372,6 @@ impl Body {
connection_info: conn_info,
server_ephemeral_key,
cert_commitment,
encoder_secret,
extensions,
transcript_commitments,
} = self;
@@ -391,13 +389,6 @@ impl Body {
),
];
if let Some(encoder_secret) = encoder_secret {
fields.push((
encoder_secret.id,
hasher.hash_separated(&encoder_secret.data),
));
}
for field in extensions.iter() {
fields.push((field.id, hasher.hash_separated(&field.data)));
}

View File

@@ -91,11 +91,6 @@ impl Presentation {
transcript.verify_with_provider(
&provider.hash,
&attestation.body.connection_info().transcript_length,
attestation
.body
.encoder_secret
.as_ref()
.map(|field| &field.data),
attestation.body.transcript_commitments(),
)
})

View File

@@ -49,6 +49,5 @@ impl_domain_separator!(tlsn_core::connection::ConnectionInfo);
impl_domain_separator!(tlsn_core::connection::CertBinding);
impl_domain_separator!(tlsn_core::transcript::TranscriptCommitment);
impl_domain_separator!(tlsn_core::transcript::TranscriptSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncoderSecret);
impl_domain_separator!(tlsn_core::transcript::encoding::EncodingCommitment);
impl_domain_separator!(tlsn_core::transcript::hash::PlaintextHash);

View File

@@ -64,6 +64,7 @@ fn test_api() {
let encoding_commitment = EncodingCommitment {
root: encoding_tree.root(),
secret: encoder_secret(),
};
let request_config = RequestConfig::default();
@@ -95,7 +96,6 @@ fn test_api() {
.connection_info(connection_info.clone())
// Server key Notary received during handshake
.server_ephemeral_key(server_ephemeral_key)
.encoder_secret(encoder_secret())
.transcript_commitments(vec![TranscriptCommitment::Encoding(encoding_commitment)]);
let attestation = attestation_builder.build(&provider).unwrap();

View File

@@ -5,7 +5,7 @@ description = "This crate provides implementations of ciphers for two parties"
keywords = ["tls", "mpc", "2pc", "aes"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]
@@ -31,4 +31,4 @@ mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }
cipher-crypto = { workspace = true }

View File

@@ -344,8 +344,8 @@ mod tests {
start_ctr: usize,
msg: Vec<u8>,
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let mut full_iv = [0u8; 16];
@@ -365,7 +365,7 @@ mod tests {
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
use cipher_crypto::{BlockEncrypt, KeyInit};
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-deap"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "Implementation of the 3-party key-exchange protocol"
keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "Core types for TLSNotary"
keywords = ["tls", "mpc", "2pc", "types"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -6,7 +6,10 @@ use rustls_pki_types as webpki_types;
use serde::{Deserialize, Serialize};
use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams};
use crate::webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError};
use crate::{
transcript::TlsTranscript,
webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError},
};
/// TLS version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -116,75 +119,84 @@ pub enum KeyType {
SECP256R1 = 0x0017,
}
/// Signature algorithm used on the key exchange parameters.
/// Signature scheme used on the key exchange parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(non_camel_case_types, missing_docs)]
pub enum SignatureAlgorithm {
ECDSA_NISTP256_SHA256,
ECDSA_NISTP256_SHA384,
ECDSA_NISTP384_SHA256,
ECDSA_NISTP384_SHA384,
ED25519,
RSA_PKCS1_2048_8192_SHA256,
RSA_PKCS1_2048_8192_SHA384,
RSA_PKCS1_2048_8192_SHA512,
RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
pub enum SignatureScheme {
RSA_PKCS1_SHA1 = 0x0201,
ECDSA_SHA1_Legacy = 0x0203,
RSA_PKCS1_SHA256 = 0x0401,
ECDSA_NISTP256_SHA256 = 0x0403,
RSA_PKCS1_SHA384 = 0x0501,
ECDSA_NISTP384_SHA384 = 0x0503,
RSA_PKCS1_SHA512 = 0x0601,
ECDSA_NISTP521_SHA512 = 0x0603,
RSA_PSS_SHA256 = 0x0804,
RSA_PSS_SHA384 = 0x0805,
RSA_PSS_SHA512 = 0x0806,
ED25519 = 0x0807,
}
impl fmt::Display for SignatureAlgorithm {
impl fmt::Display for SignatureScheme {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SignatureAlgorithm::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureAlgorithm::ECDSA_NISTP256_SHA384 => write!(f, "ECDSA_NISTP256_SHA384"),
SignatureAlgorithm::ECDSA_NISTP384_SHA256 => write!(f, "ECDSA_NISTP384_SHA256"),
SignatureAlgorithm::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureAlgorithm::ED25519 => write!(f, "ED25519"),
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256 => {
write!(f, "RSA_PKCS1_2048_8192_SHA256")
}
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384 => {
write!(f, "RSA_PKCS1_2048_8192_SHA384")
}
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512 => {
write!(f, "RSA_PKCS1_2048_8192_SHA512")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA256_LEGACY_KEY")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA384_LEGACY_KEY")
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
write!(f, "RSA_PSS_2048_8192_SHA512_LEGACY_KEY")
}
SignatureScheme::RSA_PKCS1_SHA1 => write!(f, "RSA_PKCS1_SHA1"),
SignatureScheme::ECDSA_SHA1_Legacy => write!(f, "ECDSA_SHA1_Legacy"),
SignatureScheme::RSA_PKCS1_SHA256 => write!(f, "RSA_PKCS1_SHA256"),
SignatureScheme::ECDSA_NISTP256_SHA256 => write!(f, "ECDSA_NISTP256_SHA256"),
SignatureScheme::RSA_PKCS1_SHA384 => write!(f, "RSA_PKCS1_SHA384"),
SignatureScheme::ECDSA_NISTP384_SHA384 => write!(f, "ECDSA_NISTP384_SHA384"),
SignatureScheme::RSA_PKCS1_SHA512 => write!(f, "RSA_PKCS1_SHA512"),
SignatureScheme::ECDSA_NISTP521_SHA512 => write!(f, "ECDSA_NISTP521_SHA512"),
SignatureScheme::RSA_PSS_SHA256 => write!(f, "RSA_PSS_SHA256"),
SignatureScheme::RSA_PSS_SHA384 => write!(f, "RSA_PSS_SHA384"),
SignatureScheme::RSA_PSS_SHA512 => write!(f, "RSA_PSS_SHA512"),
SignatureScheme::ED25519 => write!(f, "ED25519"),
}
}
}
impl From<tls_core::verify::SignatureAlgorithm> for SignatureAlgorithm {
fn from(value: tls_core::verify::SignatureAlgorithm) -> Self {
use tls_core::verify::SignatureAlgorithm as Core;
impl TryFrom<tls_core::msgs::enums::SignatureScheme> for SignatureScheme {
type Error = &'static str;
fn try_from(value: tls_core::msgs::enums::SignatureScheme) -> Result<Self, Self::Error> {
use tls_core::msgs::enums::SignatureScheme as Core;
use SignatureScheme::*;
Ok(match value {
Core::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1,
Core::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy,
Core::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256,
Core::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256,
Core::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384,
Core::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384,
Core::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512,
Core::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512,
Core::RSA_PSS_SHA256 => RSA_PSS_SHA256,
Core::RSA_PSS_SHA384 => RSA_PSS_SHA384,
Core::RSA_PSS_SHA512 => RSA_PSS_SHA512,
Core::ED25519 => ED25519,
_ => return Err("unsupported signature scheme"),
})
}
}
impl From<SignatureScheme> for tls_core::msgs::enums::SignatureScheme {
fn from(value: SignatureScheme) -> Self {
use tls_core::msgs::enums::SignatureScheme::*;
match value {
Core::ECDSA_NISTP256_SHA256 => SignatureAlgorithm::ECDSA_NISTP256_SHA256,
Core::ECDSA_NISTP256_SHA384 => SignatureAlgorithm::ECDSA_NISTP256_SHA384,
Core::ECDSA_NISTP384_SHA256 => SignatureAlgorithm::ECDSA_NISTP384_SHA256,
Core::ECDSA_NISTP384_SHA384 => SignatureAlgorithm::ECDSA_NISTP384_SHA384,
Core::ED25519 => SignatureAlgorithm::ED25519,
Core::RSA_PKCS1_2048_8192_SHA256 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256,
Core::RSA_PKCS1_2048_8192_SHA384 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384,
Core::RSA_PKCS1_2048_8192_SHA512 => SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512,
Core::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY
}
Core::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY
}
Core::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY
}
SignatureScheme::RSA_PKCS1_SHA1 => RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy => ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256 => RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256 => ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384 => ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512 => ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256 => RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384 => RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512 => RSA_PSS_SHA512,
SignatureScheme::ED25519 => ED25519,
}
}
}
@@ -192,8 +204,8 @@ impl From<tls_core::verify::SignatureAlgorithm> for SignatureAlgorithm {
/// Server's signature of the key exchange parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerSignature {
/// Signature algorithm.
pub alg: SignatureAlgorithm,
/// Signature scheme.
pub scheme: SignatureScheme,
/// Signature data.
pub sig: Vec<u8>,
}
@@ -303,6 +315,25 @@ pub struct HandshakeData {
}
impl HandshakeData {
/// Creates a new instance.
///
/// # Arguments
///
/// * `transcript` - The TLS transcript.
pub fn new(transcript: &TlsTranscript) -> Self {
Self {
certs: transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: transcript.certificate_binding().clone(),
}
}
/// Verifies the handshake data.
///
/// # Arguments
@@ -350,23 +381,20 @@ impl HandshakeData {
message.extend_from_slice(&server_ephemeral_key.kx_params());
use webpki::ring as alg;
let sig_alg = match self.sig.alg {
SignatureAlgorithm::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureAlgorithm::ECDSA_NISTP256_SHA384 => alg::ECDSA_P256_SHA384,
SignatureAlgorithm::ECDSA_NISTP384_SHA256 => alg::ECDSA_P384_SHA256,
SignatureAlgorithm::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureAlgorithm::ED25519 => alg::ED25519,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY
}
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY => {
alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY
let sig_alg = match self.sig.scheme {
SignatureScheme::RSA_PKCS1_SHA256 => alg::RSA_PKCS1_2048_8192_SHA256,
SignatureScheme::RSA_PKCS1_SHA384 => alg::RSA_PKCS1_2048_8192_SHA384,
SignatureScheme::RSA_PKCS1_SHA512 => alg::RSA_PKCS1_2048_8192_SHA512,
SignatureScheme::RSA_PSS_SHA256 => alg::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA384 => alg::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
SignatureScheme::RSA_PSS_SHA512 => alg::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
SignatureScheme::ECDSA_NISTP256_SHA256 => alg::ECDSA_P256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384 => alg::ECDSA_P384_SHA384,
SignatureScheme::ED25519 => alg::ED25519,
scheme => {
return Err(HandshakeVerificationError::UnsupportedSignatureScheme(
scheme,
))
}
};
@@ -396,6 +424,8 @@ pub enum HandshakeVerificationError {
InvalidServerEphemeralKey,
#[error("server certificate verification failed: {0}")]
ServerCert(ServerCertVerifierError),
#[error("unsupported signature scheme: {0}")]
UnsupportedSignatureScheme(SignatureScheme),
}
#[cfg(test)]

View File

@@ -10,8 +10,7 @@ use hex::FromHex;
use crate::{
connection::{
CertBinding, CertBindingV1_2, ConnectionInfo, DnsName, HandshakeData, KeyType,
ServerEphemKey, ServerName, ServerSignature, SignatureAlgorithm, TlsVersion,
TranscriptLength,
ServerEphemKey, ServerName, ServerSignature, SignatureScheme, TlsVersion, TranscriptLength,
},
transcript::{
encoding::{EncoderSecret, EncodingProvider},
@@ -48,7 +47,7 @@ impl ConnectionFixture {
CertificateDer(include_bytes!("fixtures/data/tlsnotary.org/ca.der").to_vec()),
],
sig: ServerSignature {
alg: SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256,
scheme: SignatureScheme::RSA_PKCS1_SHA256,
sig: Vec::<u8>::from_hex(include_bytes!(
"fixtures/data/tlsnotary.org/signature"
))
@@ -93,7 +92,7 @@ impl ConnectionFixture {
CertificateDer(include_bytes!("fixtures/data/appliedzkp.org/ca.der").to_vec()),
],
sig: ServerSignature {
alg: SignatureAlgorithm::ECDSA_NISTP256_SHA256,
scheme: SignatureScheme::ECDSA_NISTP256_SHA256,
sig: Vec::<u8>::from_hex(include_bytes!(
"fixtures/data/appliedzkp.org/signature"
))

View File

@@ -95,7 +95,7 @@ impl Display for HashAlgId {
}
/// A typed hash value.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TypedHash {
/// The algorithm of the hash.
pub alg: HashAlgId,
@@ -191,11 +191,6 @@ impl Hash {
len: value.len(),
}
}
/// Returns a byte slice of the hash value.
pub fn as_bytes(&self) -> &[u8] {
&self.value[..self.len]
}
}
impl rs_merkle::Hash for Hash {

View File

@@ -20,8 +20,8 @@ use serde::{Deserialize, Serialize};
use crate::{
connection::{HandshakeData, ServerName},
transcript::{
encoding::EncoderSecret, Direction, PartialTranscript, Transcript, TranscriptCommitConfig,
TranscriptCommitRequest, TranscriptCommitment, TranscriptSecret,
Direction, PartialTranscript, Transcript, TranscriptCommitConfig, TranscriptCommitRequest,
TranscriptCommitment, TranscriptSecret,
},
};
@@ -122,14 +122,6 @@ impl<'a> ProveConfigBuilder<'a> {
self.reveal(Direction::Sent, ranges)
}
/// Reveals all of the sent data transcript.
pub fn reveal_sent_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Sent);
let (sent, _) = self.reveal.get_or_insert_default();
sent.union_mut(&(0..len));
Ok(self)
}
/// Reveals the given ranges of the received data transcript.
pub fn reveal_recv(
&mut self,
@@ -138,12 +130,13 @@ impl<'a> ProveConfigBuilder<'a> {
self.reveal(Direction::Received, ranges)
}
/// Reveals all of the received data transcript.
pub fn reveal_recv_all(&mut self) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(Direction::Received);
let (_, recv) = self.reveal.get_or_insert_default();
recv.union_mut(&(0..len));
Ok(self)
/// Reveals the full transcript range for a given direction.
pub fn reveal_all(
&mut self,
direction: Direction,
) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(direction);
self.reveal(direction, &(0..len))
}
/// Builds the configuration.
@@ -218,6 +211,29 @@ pub struct ProveRequest {
pub transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
/// Creates a new prove request.
///
/// # Arguments
///
/// * `config` - The prove config.
/// * `transcript` - The partial transcript.
/// * `handshake` - The server name and handshake data.
pub fn new(
config: &ProveConfig,
transcript: Option<PartialTranscript>,
handshake: Option<(ServerName, HandshakeData)>,
) -> Self {
let transcript_commit = config.transcript_commit().map(|config| config.to_request());
Self {
handshake,
transcript,
transcript_commit,
}
}
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {
@@ -236,8 +252,6 @@ pub struct VerifierOutput {
pub server_name: Option<ServerName>,
/// Transcript data.
pub transcript: Option<PartialTranscript>,
/// Encoding commitment secret.
pub encoder_secret: Option<EncoderSecret>,
/// Transcript commitments.
pub transcript_commitments: Vec<TranscriptCommitment>,
}

View File

@@ -1,8 +1,8 @@
//! Transcript commitments.
use std::{collections::HashSet, fmt};
use std::fmt;
use rangeset::{ToRangeSet, UnionMut};
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
use crate::{
@@ -69,8 +69,6 @@ pub enum TranscriptSecret {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -85,16 +83,6 @@ impl TranscriptCommitConfig {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
@@ -114,19 +102,10 @@ impl TranscriptCommitConfig {
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding.then(|| {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (dir, idx) in self.iter_encoding() {
match dir {
Direction::Sent => sent.union_mut(idx),
Direction::Received => recv.union_mut(idx),
}
}
(sent, recv)
}),
encoding: self
.iter_encoding()
.map(|(dir, idx)| (*dir, idx.clone()))
.collect(),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -143,10 +122,8 @@ impl TranscriptCommitConfig {
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
impl<'a> TranscriptCommitConfigBuilder<'a> {
@@ -155,10 +132,8 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
Self {
transcript,
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
commits: Vec::default(),
}
}
@@ -200,14 +175,12 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
),
));
}
let value = ((direction, idx), kind);
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
if !self.commits.contains(&value) {
self.commits.push(value);
}
self.commits.insert(((direction, idx), kind));
Ok(self)
}
@@ -253,8 +226,6 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -301,30 +272,20 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: Option<(RangeSet<usize>, RangeSet<usize>)>,
encoding: Vec<(Direction, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn has_encoding(&self) -> bool {
self.encoding.is_some()
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()
/// Returns an iterator over the encoding commitments.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.encoding.iter()
}
/// Returns an iterator over the hash commitments.
pub fn iter_hash(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>, HashAlgId)> {
self.hash.iter()
}
/// Returns the ranges of the encoding commitments.
pub fn encoding(&self) -> Option<&(RangeSet<usize>, RangeSet<usize>)> {
self.encoding.as_ref()
}
}
#[cfg(test)]

View File

@@ -19,4 +19,6 @@ use crate::hash::TypedHash;
pub struct EncodingCommitment {
/// Merkle root of the encoding commitments.
pub root: TypedHash,
/// Seed used to generate the encodings.
pub secret: EncoderSecret,
}

View File

@@ -12,7 +12,7 @@ const BIT_ENCODING_SIZE: usize = 16;
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],

View File

@@ -8,7 +8,7 @@ use crate::{
merkle::{MerkleError, MerkleProof},
transcript::{
commit::MAX_TOTAL_COMMITTED_DATA,
encoding::{new_encoder, Encoder, EncoderSecret, EncodingCommitment},
encoding::{new_encoder, Encoder, EncodingCommitment},
Direction,
},
};
@@ -48,14 +48,13 @@ impl EncodingProof {
pub fn verify_with_provider(
&self,
provider: &HashProvider,
secret: &EncoderSecret,
commitment: &EncodingCommitment,
sent: &[u8],
recv: &[u8],
) -> Result<(RangeSet<usize>, RangeSet<usize>), EncodingProofError> {
let hasher = provider.get(&commitment.root.alg)?;
let encoder = new_encoder(secret);
let encoder = new_encoder(&commitment.secret);
let Self {
inclusion_proof,
openings,
@@ -233,7 +232,10 @@ mod test {
use crate::{
fixtures::{encoder_secret, encoder_secret_tampered_seed, encoding_provider},
hash::Blake3,
transcript::{encoding::EncodingTree, Transcript},
transcript::{
encoding::{EncoderSecret, EncodingTree},
Transcript,
},
};
use super::*;
@@ -244,7 +246,7 @@ mod test {
commitment: EncodingCommitment,
}
fn new_encoding_fixture() -> EncodingFixture {
fn new_encoding_fixture(secret: EncoderSecret) -> EncodingFixture {
let transcript = Transcript::new(POST_JSON, OK_JSON);
let idx_0 = (Direction::Sent, RangeSet::from(0..POST_JSON.len()));
@@ -255,7 +257,10 @@ mod test {
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
let commitment = EncodingCommitment { root: tree.root() };
let commitment = EncodingCommitment {
root: tree.root(),
secret,
};
EncodingFixture {
transcript,
@@ -270,12 +275,11 @@ mod test {
transcript,
proof,
commitment,
} = new_encoding_fixture();
} = new_encoding_fixture(encoder_secret_tampered_seed());
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret_tampered_seed(),
&commitment,
transcript.sent(),
transcript.received(),
@@ -291,19 +295,13 @@ mod test {
transcript,
proof,
commitment,
} = new_encoding_fixture();
} = new_encoding_fixture(encoder_secret());
let sent = &transcript.sent()[transcript.sent().len() - 1..];
let recv = &transcript.received()[transcript.received().len() - 2..];
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
sent,
recv,
)
.verify_with_provider(&HashProvider::default(), &commitment, sent, recv)
.unwrap_err();
assert!(matches!(err.kind, ErrorKind::Proof));
@@ -315,7 +313,7 @@ mod test {
transcript,
mut proof,
commitment,
} = new_encoding_fixture();
} = new_encoding_fixture(encoder_secret());
let Opening { idx, .. } = proof.openings.values_mut().next().unwrap();
@@ -324,7 +322,6 @@ mod test {
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
@@ -340,7 +337,7 @@ mod test {
transcript,
mut proof,
commitment,
} = new_encoding_fixture();
} = new_encoding_fixture(encoder_secret());
let Opening { blinder, .. } = proof.openings.values_mut().next().unwrap();
@@ -349,7 +346,6 @@ mod test {
let err = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),

View File

@@ -222,12 +222,14 @@ mod tests {
let proof = tree.proof([&idx_0, &idx_1].into_iter()).unwrap();
let commitment = EncodingCommitment { root: tree.root() };
let commitment = EncodingCommitment {
root: tree.root(),
secret: encoder_secret(),
};
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),
@@ -258,12 +260,14 @@ mod tests {
.proof([&idx_0, &idx_1, &idx_2, &idx_3].into_iter())
.unwrap();
let commitment = EncodingCommitment { root: tree.root() };
let commitment = EncodingCommitment {
root: tree.root(),
secret: encoder_secret(),
};
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&encoder_secret(),
&commitment,
transcript.sent(),
transcript.received(),

View File

@@ -10,7 +10,7 @@ use crate::{
hash::{HashAlgId, HashProvider},
transcript::{
commit::{TranscriptCommitment, TranscriptCommitmentKind},
encoding::{EncoderSecret, EncodingProof, EncodingProofError, EncodingTree},
encoding::{EncodingProof, EncodingProofError, EncodingTree},
hash::{hash_plaintext, PlaintextHash, PlaintextHashSecret},
Direction, PartialTranscript, RangeSet, Transcript, TranscriptSecret,
},
@@ -22,9 +22,6 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
TranscriptCommitmentKind::Encoding,
];
@@ -51,7 +48,6 @@ impl TranscriptProof {
self,
provider: &HashProvider,
length: &TranscriptLength,
encoder_secret: Option<&EncoderSecret>,
commitments: impl IntoIterator<Item = &'a TranscriptCommitment>,
) -> Result<PartialTranscript, TranscriptProofError> {
let mut encoding_commitment = None;
@@ -87,13 +83,6 @@ impl TranscriptProof {
// Verify encoding proof.
if let Some(proof) = self.encoding_proof {
let secret = encoder_secret.ok_or_else(|| {
TranscriptProofError::new(
ErrorKind::Encoding,
"contains an encoding proof but missing encoder secret",
)
})?;
let commitment = encoding_commitment.ok_or_else(|| {
TranscriptProofError::new(
ErrorKind::Encoding,
@@ -103,7 +92,6 @@ impl TranscriptProof {
let (auth_sent, auth_recv) = proof.verify_with_provider(
provider,
secret,
commitment,
self.transcript.sent_unsafe(),
self.transcript.received_unsafe(),
@@ -584,7 +572,7 @@ mod tests {
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
use crate::{
fixtures::{encoder_secret, encoding_provider},
fixtures::encoding_provider,
hash::{Blake3, Blinder, HashAlgId},
transcript::TranscriptCommitConfigBuilder,
};
@@ -611,12 +599,7 @@ mod tests {
let provider = HashProvider::default();
let err = transcript_proof
.verify_with_provider(
&provider,
&transcript.length(),
Some(&encoder_secret()),
&[],
)
.verify_with_provider(&provider, &transcript.length(), &[])
.err()
.unwrap();
@@ -654,9 +637,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -664,6 +645,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {
@@ -690,7 +672,6 @@ mod tests {
.verify_with_provider(
&provider,
&transcript.length(),
None,
&[TranscriptCommitment::Hash(commitment)],
)
.unwrap();
@@ -702,9 +683,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_inconsistent_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -712,6 +691,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {
@@ -739,7 +719,6 @@ mod tests {
.verify_with_provider(
&provider,
&transcript.length(),
None,
&[TranscriptCommitment::Hash(commitment)],
)
.unwrap_err();

View File

@@ -24,7 +24,6 @@ hex = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
k256 = { workspace = true, features = ["ecdsa"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = [
"rt",
@@ -37,16 +36,11 @@ tokio = { workspace = true, features = [
tokio-util = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = ["barretenberg"] }
[[example]]
name = "interactive"
path = "interactive/interactive.rs"
[[example]]
name = "interactive_zk"
path = "interactive_zk/interactive_zk.rs"
[[example]]
name = "attestation_prove"
path = "attestation/prove.rs"

View File

@@ -175,7 +175,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
assert!(response.status() == StatusCode::OK);
// The prover task should be done now, so we can await it.
let prover = prover_task.await??;
let mut prover = prover_task.await??;
// Parse the HTTP transcript.
let transcript = HttpTranscript::parse(prover.transcript())?;
@@ -217,7 +217,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let request_config = builder.build()?;
let (attestation, secrets) = notarize(prover, &request_config, req_tx, resp_rx).await?;
let (attestation, secrets) = notarize(&mut prover, &request_config, req_tx, resp_rx).await?;
// Write the attestation to disk.
let attestation_path = tlsn_examples::get_file_path(example_type, "attestation");
@@ -238,7 +238,7 @@ async fn prover<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
}
async fn notarize(
mut prover: Prover<Committed>,
prover: &mut Prover<Committed>,
config: &RequestConfig,
request_tx: Sender<AttestationRequest>,
attestation_rx: Receiver<Attestation>,
@@ -255,11 +255,7 @@ async fn notarize(
transcript_commitments,
transcript_secrets,
..
} = prover.prove(&disclosure_config).await?;
let transcript = prover.transcript().clone();
let tls_transcript = prover.tls_transcript().clone();
prover.close().await?;
} = prover.prove(disclosure_config).await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);
@@ -267,17 +263,19 @@ async fn notarize(
builder
.server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap()))
.handshake_data(HandshakeData {
certs: tls_transcript
certs: prover
.tls_transcript()
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
sig: prover
.tls_transcript()
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
binding: prover.tls_transcript().certificate_binding().clone(),
})
.transcript(transcript)
.transcript(prover.transcript().clone())
.transcript_commitments(transcript_secrets, transcript_commitments);
let (request, secrets) = builder.build(&CryptoProvider::default())?;
@@ -329,7 +327,6 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
let VerifierOutput {
transcript_commitments,
encoder_secret,
..
} = verifier.verify(&VerifyConfig::default()).await?;
@@ -388,10 +385,6 @@ async fn notary<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.server_ephemeral_key(tls_transcript.server_ephemeral_key().clone())
.transcript_commitments(transcript_commitments);
if let Some(encoder_secret) = encoder_secret {
builder.encoder_secret(encoder_secret);
}
let attestation = builder.build(&provider)?;
// Send attestation to prover.

View File

@@ -174,7 +174,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -1,5 +0,0 @@
!noir/target/
# Ignore everything inside noir/target
noir/target/*
# Except noir.json
!noir/target/noir.json

View File

@@ -1,167 +0,0 @@
# Interactive Zero-Knowledge Age Verification with TLSNotary
This example demonstrates **privacy-preserving age verification** using TLSNotary and zero-knowledge proofs. It allows a prover to demonstrate they are 18+ years old without revealing their actual birth date or any other personal information.
## 🔍 How It Works (simplified overview)
```mermaid
sequenceDiagram
participant S as Tax Server<br/>(fixture)
participant P as Prover
participant V as Verifier
P->>S: Request tax data (with auth token) (MPC-TLS)
S->>P: Tax data including `date_of_birth` (MPC-TLS)
P->>V: Share transcript with redactions
P->>V: Commit to blinded hash of birth date
P->>P: Generate ZK proof of age ≥ 18
P->>V: Send ZK proof
V->>V: Verify transcript & ZK proof
V->>V: ✅ Confirm: Prover is 18+ (no birth date revealed)
```
### The Process
1. **MPC-TLS Session**: The Prover fetches tax information containing their birth date, while the Verifier jointly verifies the TLS session to ensure the data comes from the authentic server.
2. **Selective Disclosure**:
* The authorization token is **redacted**: the Verifier sees the plaintext request but not the token.
* The birth date is **committed** as a blinded hash: the Verifier cannot see the date, but the Prover is cryptographically bound to it.
(Depending on the use case more data can be redacted or revealed)
3. **Zero-Knowledge Proof**: The Prover generates a ZK proof that the committed birth date corresponds to an age ≥ 18.
4. **Verification**: The Verifier checks both the TLS transcript and the ZK proof, confirming age ≥ 18 without learning the actual date of birth.
### Example Data
The tax server returns data like this:
```json
{
"tax_year": 2024,
"taxpayer": {
"idnr": "12345678901",
"first_name": "Max",
"last_name": "Mustermann",
"date_of_birth": "1985-03-12",
// ...
}
}
```
## 🔐 Zero-Knowledge Proof Details
The ZK circuit proves: **"I know a birth date that hashes to the committed value AND indicates I am 18+ years old"**
**Public Inputs:**
- ✅ Verification date
- ✅ Committed blinded hash of birth date
**Private Inputs (Hidden):**
- 🔒 Actual birth date plaintext
- 🔒 Random blinder used in hash commitment
**What the Verifier Learns:**
- ✅ The prover is 18+ years old
- ✅ The birth date is authentic (from the MPC-TLS session)
Everything else remains private.
## 🏃 Run the Example
1. **Start the test server** (from repository root):
```bash
RUST_LOG=info PORT=4000 cargo run --bin tlsn-server-fixture
```
2. **Run the age verification** (in a new terminal):
```bash
SERVER_PORT=4000 cargo run --release --example interactive_zk
```
3. **For detailed logs**:
```bash
RUST_LOG=debug,yamux=info,uid_mux=info SERVER_PORT=4000 cargo run --release --example interactive_zk
```
### Expected Output
```
Successfully verified https://test-server.io:4000/elster
Age verified in ZK: 18+ ✅
Verified sent data:
GET https://test-server.io:4000/elster HTTP/1.1
host: test-server.io
connection: close
authorization: 🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈
Verified received data:
🙈🙈🙈🙈🙈🙈🙈🙈[truncated for brevity]...🙈🙈🙈🙈🙈"tax_year":2024🙈🙈🙈🙈🙈...
```
> 💡 **Note**: In this demo, both Prover and Verifier run on the same machine. In production, they would operate on separate systems.
> 💡 **Note**: This demo assumes that the tax server serves correct data, and that only the submitter of the tax data has access to the specified page.
## 🛠 Development
### Project Structure
```
interactive_zk/
├── prover.rs # Prover implementation
├── verifier.rs # Verifier implementation
├── types.rs # Shared types
└── interactive_zk.rs # Main example runner
├── noir/ # Zero-knowledge circuit
│ ├── src/main.n # Noir circuit code
│ ├── target/ # Compiled circuit artifacts
│ └── Nargo.toml # Noir project config
│ └── Prover.toml # Example input for `nargo execute`
│ └── generate_test_data.rs # Rust script to generate Noir test data
└── README.md
```
### Noir Circuit Commands
We use [Mopro's `noir_rs`](https://zkmopro.org/docs/crates/noir-rs/) for ZK proof generation. The **circuit is pre-compiled and ready to use**. You don't need to install Noir tools to run the example. But if you want to change or test the circuit in isolation, you can use the following instructions.
Before you proceed, we recommend to double check that your Noir tooling matches the versions used in Mopro's `noir_rs`:
```sh
# Install correct Noir and BB versions (important for compatibility!)
noirup --version 1.0.0-beta.8
bbup -v 1.0.0-nightly.20250723
```
If you don't have `noirup` and `bbup` installed yet, check [Noir's Quick Start](https://noir-lang.org/docs/getting_started/quick_start).
To compile the circuit, go to the `noir` folder and run `nargo compile`.
To check and experiment with the Noir circuit, you can use these commands:
* Execute Circuit: Compile the circuit and run it with sample data from `Prover.toml`:
```sh
nargo execute
```
* Generate Verification Key: Create the verification key needed to verify proofs
```sh
bb write_vk -b ./target/noir.json -o ./target
```
* Generate Proof: Create a zero-knowledge proof using the circuit and witness data.
```sh
bb prove --bytecode_path ./target/noir.json --witness_path ./target/noir.gz -o ./target
```
* Verify Proof: Verify that a proof is valid using the verification key.
```sh
bb verify -k ./target/vk -p ./target/proof
```
* Run the Noir tests:
```sh
nargo test --show-output
```
To create extra tests, you can use `./generate_test_data.rs` to help with generating correct blinders and hashes.
## 📚 Learn More
- [TLSNotary Documentation](https://docs.tlsnotary.org/)
- [Noir Language Guide](https://noir-lang.org/)
- [Zero-Knowledge Proofs Explained](https://ethereum.org/en/zero-knowledge-proofs/)
- [Mopro ZK Toolkit](https://zkmopro.org/)

View File

@@ -1,59 +0,0 @@
mod prover;
mod types;
mod verifier;
use prover::prover;
use std::{
env,
net::{IpAddr, SocketAddr},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::SERVER_DOMAIN;
use verifier::verifier;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// We use SERVER_DOMAIN here to make sure it matches the domain in the test
// server's certificate.
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/elster");
let server_ip: IpAddr = server_host
.parse()
.map_err(|e| format!("Invalid IP address '{}': {}", server_host, e))?;
let server_addr = SocketAddr::from((server_ip, server_port));
// Connect prover and verifier.
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
let (_, transcript) = tokio::try_join!(
prover(prover_socket, prover_extra_socket, &server_addr, &uri),
verifier(verifier_socket, verifier_extra_socket)
)?;
println!("---");
println!("Successfully verified {}", &uri);
println!("Age verified in ZK: 18+ ✅\n");
println!(
"Verified sent data:\n{}",
bytes_to_redacted_string(transcript.sent_unsafe())
);
println!(
"Verified received data:\n{}",
bytes_to_redacted_string(transcript.received_unsafe())
);
Ok(())
}
/// Render redacted bytes as `🙈`.
pub fn bytes_to_redacted_string(bytes: &[u8]) -> String {
String::from_utf8_lossy(bytes).replace('\0', "🙈")
}

View File

@@ -1,8 +0,0 @@
[package]
name = "noir"
type = "bin"
authors = [""]
[dependencies]
sha256 = { tag = "v0.1.5", git = "https://github.com/noir-lang/sha256" }
date = { tag = "v0.5.4", git = "https://github.com/madztheo/noir-date.git" }

View File

@@ -1,8 +0,0 @@
blinder = [108, 93, 120, 205, 15, 35, 159, 124, 243, 96, 22, 128, 16, 149, 219, 216]
committed_hash = [186, 158, 101, 39, 49, 48, 26, 83, 242, 96, 10, 221, 121, 174, 62, 50, 136, 132, 232, 58, 25, 32, 66, 196, 99, 85, 66, 85, 255, 1, 202, 254]
date_of_birth = "1985-03-12"
[proof_date]
day = "29"
month = "08"
year = "2025"

View File

@@ -1,64 +0,0 @@
#!/usr/bin/env -S cargo +nightly -Zscript
---
[package]
name = "generate_test_data"
version = "0.0.0"
edition = "2021"
publish = false
[dependencies]
sha2 = "0.10"
rand = "0.8"
chrono = "0.4"
---
use chrono::Datelike;
use chrono::Local;
use rand::RngCore;
use sha2::{Digest, Sha256};
fn main() {
// 1. Birthdate string (fixed)
let dob_str = "1985-03-12"; // 10 bytes long
let proof_date = Local::now().date_naive();
let proof_year = proof_date.year();
let proof_month = proof_date.month();
let proof_day = proof_date.day();
// 2. Generate random 16-byte blinder
let mut blinder = [0u8; 16];
rand::thread_rng().fill_bytes(&mut blinder);
// 3. Concatenate blinder + dob string bytes
let mut preimage = Vec::with_capacity(26);
preimage.extend_from_slice(dob_str.as_bytes());
preimage.extend_from_slice(&blinder);
// 4. Hash it
let hash = Sha256::digest(&preimage);
let blinder = blinder
.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join(", ");
let committed_hash = hash
.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join(", ");
println!(
"
// Private input
let date_of_birth = \"{dob_str}\";
let blinder = [{blinder}];
// Public input
let proof_date = date::Date {{ year: {proof_year}, month: {proof_month}, day: {proof_day} }};
let committed_hash = [{committed_hash}];
main(proof_date, committed_hash, date_of_birth, blinder);
"
);
}

View File

@@ -1,82 +0,0 @@
use dep::date::Date;
fn main(
// Public inputs
proof_date: pub date::Date, // "2025-08-29"
committed_hash: pub [u8; 32], // Hash of (blinder || dob string)
// Private inputs
date_of_birth: str<10>, // "1985-03-12"
blinder: [u8; 16], // Random 16-byte blinder
) {
let is_18 = check_18(date_of_birth, proof_date);
let correct_hash = check_hash(date_of_birth, blinder, committed_hash);
assert(correct_hash);
assert(is_18);
}
fn check_18(date_of_birth: str<10>, proof_date: date::Date) -> bool {
let dob = parse_birth_date(date_of_birth);
let is_18 = dob.add_years(18).lt(proof_date);
println(f"Is 18? {is_18}");
is_18
}
fn check_hash(date_of_birth: str<10>, blinder: [u8; 16], committed_hash: [u8; 32]) -> bool {
let hash_input: [u8; 26] = make_hash_input(date_of_birth, blinder);
let computed_hash = sha256::sha256_var(hash_input, 26);
let correct_hash = computed_hash == committed_hash;
println(f"Correct hash? {correct_hash}");
correct_hash
}
fn make_hash_input(dob: str<10>, blinder: [u8; 16]) -> [u8; 26] {
let mut input: [u8; 26] = [0; 26];
for i in 0..10 {
input[i] = dob.as_bytes()[i];
}
for i in 0..16 {
input[10 + i] = blinder[i];
}
input
}
pub fn parse_birth_date(birth_date: str<10>) -> date::Date {
let date: [u8; 10] = birth_date.as_bytes();
let date_str: str<8> =
[date[0], date[1], date[2], date[3], date[5], date[6], date[8], date[9]].as_str_unchecked();
Date::from_str_long_year(date_str)
}
#[test]
fn test_max_is_over_18() {
// Private input
let date_of_birth = "1985-03-12";
let blinder = [120, 80, 62, 10, 76, 60, 130, 98, 147, 161, 139, 126, 27, 236, 36, 56];
// Public input
let proof_date = date::Date { year: 2025, month: 9, day: 2 };
let committed_hash = [
229, 118, 202, 216, 213, 230, 125, 163, 48, 178, 118, 225, 84, 7, 140, 63, 173, 255, 163,
208, 163, 3, 63, 204, 37, 120, 254, 246, 202, 116, 122, 145,
];
main(proof_date, committed_hash, date_of_birth, blinder);
}
#[test(should_fail)]
fn test_under_18() {
// Private input
let date_of_birth = "2010-08-01";
let blinder = [160, 23, 57, 158, 141, 195, 155, 132, 109, 242, 48, 220, 70, 217, 229, 189];
// Public input
let proof_date = date::Date { year: 2025, month: 8, day: 29 };
let committed_hash = [
16, 132, 194, 62, 232, 90, 157, 153, 4, 231, 1, 54, 226, 3, 87, 174, 129, 177, 80, 69, 37,
222, 209, 91, 168, 156, 9, 109, 108, 144, 168, 109,
];
main(proof_date, committed_hash, date_of_birth, blinder);
}

File diff suppressed because one or more lines are too long

View File

@@ -1,371 +0,0 @@
use std::net::SocketAddr;
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use chrono::{Datelike, Local, NaiveDate};
use http_body_util::Empty;
use hyper::{body::Bytes, header, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use k256::sha2::{Digest, Sha256};
use noir::{
barretenberg::{
prove::prove_ultra_honk, srs::setup_srs_from_bytecode,
verify::get_ultra_honk_verification_key,
},
witness::from_vec_str_to_witness_map,
};
use serde_json::Value;
use spansy::{
http::{BodyContent, Requests, Responses},
Spanned,
};
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig},
transcript::{
hash::{PlaintextHash, PlaintextHashSecret},
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
TranscriptSecret,
},
};
use tlsn_examples::MAX_RECV_DATA;
use tokio::io::AsyncWriteExt;
use tlsn_examples::MAX_SENT_DATA;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
#[instrument(skip(verifier_socket, verifier_extra_socket))]
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
mut verifier_extra_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err("URI must use HTTPS scheme".into());
}
let server_domain = uri.authority().ok_or("URI must have authority")?.host();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into()?))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await?;
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?;
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
// MPC-TLS: Send Request and wait for Response.
let request = Request::builder()
.uri(uri.clone())
.header("Host", server_domain)
.header("Connection", "close")
.header(header::AUTHORIZATION, "Bearer random_auth_token")
.method("GET")
.body(Empty::<Bytes>::new())?;
let response = request_sender.send_request(request).await?;
if response.status() != StatusCode::OK {
return Err(format!("MPC-TLS request failed with status {}", response.status()).into());
}
// Create proof for the Verifier.
let mut prover = prover_task.await??;
let transcript = prover.transcript().clone();
let mut prove_config_builder = ProveConfig::builder(&transcript);
// Reveal the DNS name.
prove_config_builder.server_identity();
let sent: &[u8] = transcript.sent();
let received: &[u8] = transcript.received();
let sent_len = sent.len();
let recv_len = received.len();
tracing::info!("Sent length: {}, Received length: {}", sent_len, recv_len);
// Reveal the entire HTTP request except for the authorization bearer token
reveal_request(sent, &mut prove_config_builder)?;
// Create hash commitment for the date of birth field from the response
let mut transcript_commitment_builder = TranscriptCommitConfig::builder(&transcript);
transcript_commitment_builder.default_kind(TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
});
reveal_received(
received,
&mut prove_config_builder,
&mut transcript_commitment_builder,
)?;
let transcripts_commitment_config = transcript_commitment_builder.build()?;
prove_config_builder.transcript_commit(transcripts_commitment_config);
let prove_config = prove_config_builder.build()?;
// MPC-TLS prove
let prover_output = prover.prove(&prove_config).await?;
prover.close().await?;
// Prove birthdate is more than 18 years ago.
let received_commitments = received_commitments(&prover_output.transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("No received commitments found")?; // committed hash (of date of birth string)
let received_secrets = received_secrets(&prover_output.transcript_secrets);
let received_secret = received_secrets
.first()
.ok_or("No received secrets found")?; // hash blinder
let proof_input = prepare_zk_proof_input(received, received_commitment, received_secret)?;
let proof_bundle = generate_zk_proof(&proof_input)?;
// Sent zk proof bundle to verifier
let serialized_proof = bincode::serialize(&proof_bundle)?;
verifier_extra_socket.write_all(&serialized_proof).await?;
verifier_extra_socket.shutdown().await?;
Ok(())
}
// Reveal everything from the request, except for the authorization token.
fn reveal_request(
request: &[u8],
builder: &mut ProveConfigBuilder<'_>,
) -> Result<(), Box<dyn std::error::Error>> {
let reqs = Requests::new_from_slice(request).collect::<Result<Vec<_>, _>>()?;
let req = reqs.first().ok_or("No requests found")?;
if req.request.method.as_str() != "GET" {
return Err(format!("Expected GET method, found {}", req.request.method.as_str()).into());
}
let authorization_header = req
.headers_with_name(header::AUTHORIZATION.as_str())
.next()
.ok_or("Authorization header not found")?;
let start_pos = authorization_header
.span()
.indices()
.min()
.ok_or("Could not find authorization header start position")?
+ header::AUTHORIZATION.as_str().len()
+ 2;
let end_pos =
start_pos + authorization_header.span().len() - header::AUTHORIZATION.as_str().len() - 2;
builder.reveal_sent(&(0..start_pos))?;
builder.reveal_sent(&(end_pos..request.len()))?;
Ok(())
}
fn reveal_received(
received: &[u8],
builder: &mut ProveConfigBuilder<'_>,
transcript_commitment_builder: &mut TranscriptCommitConfigBuilder,
) -> Result<(), Box<dyn std::error::Error>> {
let resp = Responses::new_from_slice(received).collect::<Result<Vec<_>, _>>()?;
let response = resp.first().ok_or("No responses found")?;
let body = response.body.as_ref().ok_or("Response body not found")?;
let BodyContent::Json(json) = &body.content else {
return Err("Expected JSON body content".into());
};
// reveal tax year
let tax_year = json
.get("tax_year")
.ok_or("tax_year field not found in JSON")?;
let start_pos = tax_year
.span()
.indices()
.min()
.ok_or("Could not find tax_year start position")?
- 11;
let end_pos = tax_year
.span()
.indices()
.max()
.ok_or("Could not find tax_year end position")?
+ 1;
builder.reveal_recv(&(start_pos..end_pos))?;
// commit to hash of date of birth
let dob = json
.get("taxpayer.date_of_birth")
.ok_or("taxpayer.date_of_birth field not found in JSON")?;
transcript_commitment_builder.commit_recv(dob.span())?;
Ok(())
}
// extract secret from prover output
fn received_secrets(transcript_secrets: &[TranscriptSecret]) -> Vec<&PlaintextHashSecret> {
transcript_secrets
.iter()
.filter_map(|secret| match secret {
TranscriptSecret::Hash(hash) if hash.direction == Direction::Received => Some(hash),
_ => None,
})
.collect()
}
#[derive(Debug)]
pub struct ZKProofInput {
dob: Vec<u8>,
proof_date: NaiveDate,
blinder: Vec<u8>,
committed_hash: Vec<u8>,
}
// Verify that the blinded, committed hash is correct
fn prepare_zk_proof_input(
received: &[u8],
received_commitment: &PlaintextHash,
received_secret: &PlaintextHashSecret,
) -> Result<ZKProofInput, Box<dyn std::error::Error>> {
assert_eq!(received_commitment.direction, Direction::Received);
assert_eq!(received_commitment.hash.alg, HashAlgId::SHA256);
let hash = &received_commitment.hash;
let dob_start = received_commitment
.idx
.min()
.ok_or("No start index for DOB")?;
let dob_end = received_commitment
.idx
.end()
.ok_or("No end index for DOB")?;
let dob = received[dob_start..dob_end].to_vec();
let blinder = received_secret.blinder.as_bytes().to_vec();
let committed_hash = hash.value.as_bytes().to_vec();
let proof_date = Local::now().date_naive();
assert_eq!(received_secret.direction, Direction::Received);
assert_eq!(received_secret.alg, HashAlgId::SHA256);
let mut hasher = Sha256::new();
hasher.update(&dob);
hasher.update(&blinder);
let computed_hash = hasher.finalize();
if committed_hash != computed_hash.as_slice() {
return Err("Computed hash does not match committed hash".into());
}
Ok(ZKProofInput {
dob,
proof_date,
committed_hash,
blinder,
})
}
/// Generates the age-verification ZK proof with Noir.
///
/// # Arguments
///
/// * `proof_input` - The prepared public and private proof inputs.
///
/// # Errors
///
/// Returns an error if the circuit bytecode cannot be loaded or if any of the
/// Noir proving steps fail.
fn generate_zk_proof(
    proof_input: &ZKProofInput,
) -> Result<ZKProofBundle, Box<dyn std::error::Error>> {
    tracing::info!("🔒 Generating ZK proof with Noir...");

    const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");

    // 1. Load bytecode from program.json
    let json: Value = serde_json::from_str(PROGRAM_JSON)?;
    let bytecode = json["bytecode"]
        .as_str()
        .ok_or("bytecode field not found in program.json")?;

    // Witness layout: proof date (day, month, year), then the committed hash,
    // the date-of-birth bytes and the blinder — all as decimal strings.
    let date = &proof_input.proof_date;
    let inputs: Vec<String> = [
        date.day().to_string(),
        date.month().to_string(),
        date.year().to_string(),
    ]
    .into_iter()
    .chain(proof_input.committed_hash.iter().map(|b| b.to_string()))
    .chain(proof_input.dob.iter().map(|b| b.to_string()))
    .chain(proof_input.blinder.iter().map(|b| b.to_string()))
    .collect();

    let proof_date = proof_input.proof_date.to_string();
    tracing::info!(
        "Public inputs : Proof date ({}) and committed hash ({})",
        proof_date,
        hex::encode(&proof_input.committed_hash)
    );
    tracing::info!(
        "Private inputs: Blinder ({}) and Date of Birth ({})",
        hex::encode(&proof_input.blinder),
        String::from_utf8_lossy(&proof_input.dob)
    );
    tracing::debug!("Witness inputs {:?}", inputs);

    let input_refs: Vec<&str> = inputs.iter().map(String::as_str).collect();
    let witness = from_vec_str_to_witness_map(input_refs)?;

    // Setup SRS
    setup_srs_from_bytecode(bytecode, None, false)?;
    // Verification key
    let vk = get_ultra_honk_verification_key(bytecode, false)?;
    // Generate proof
    let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)?;
    tracing::info!("✅ Proof generated ({} bytes)", proof.len());

    Ok(ZKProofBundle { vk, proof })
}

View File

@@ -1,21 +0,0 @@
use serde::{Deserialize, Serialize};
use tlsn::transcript::{hash::PlaintextHash, Direction, TranscriptCommitment};
/// Proof bundle sent from the prover to the verifier over the extra socket.
#[derive(Serialize, Deserialize, Debug)]
pub struct ZKProofBundle {
    /// UltraHonk verification key bytes.
    pub vk: Vec<u8>,
    /// UltraHonk proof bytes (public inputs followed by the proof body).
    pub proof: Vec<u8>,
}
/// Extracts the hash commitments for the received direction from the prover
/// output, ignoring all other kinds of transcript commitments.
pub fn received_commitments(
    transcript_commitments: &[TranscriptCommitment],
) -> Vec<&PlaintextHash> {
    transcript_commitments
        .iter()
        .filter_map(|commitment| {
            if let TranscriptCommitment::Hash(hash) = commitment {
                (hash.direction == Direction::Received).then_some(hash)
            } else {
                None
            }
        })
        .collect()
}

View File

@@ -1,184 +0,0 @@
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use chrono::{Local, NaiveDate};
use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_honk};
use serde_json::Value;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::HashAlgId,
transcript::{Direction, PartialTranscript},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tlsn_server_fixture_certs::SERVER_DOMAIN;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::instrument;
#[instrument(skip(socket, extra_socket))]
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript, Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?;
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()?;
let verifier = Verifier::new(verifier_config);
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await?;
let server_name = server_name.ok_or("Prover should have revealed server name")?;
let transcript = transcript.ok_or("Prover should have revealed transcript data")?;
// Create hash commitment for the date of birth field from the response
let sent = transcript.sent_unsafe().to_vec();
let sent_data = String::from_utf8(sent.clone())
.map_err(|e| format!("Verifier expected valid UTF-8 sent data: {}", e))?;
if !sent_data.contains(SERVER_DOMAIN) {
return Err(format!(
"Verification failed: Expected host {} not found in sent data",
SERVER_DOMAIN
)
.into());
}
// Check received data.
let received_commitments = received_commitments(&transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("Missing received hash commitment")?;
assert!(received_commitment.direction == Direction::Received);
assert!(received_commitment.hash.alg == HashAlgId::SHA256);
let committed_hash = &received_commitment.hash;
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
if server_name.as_str() != SERVER_DOMAIN {
return Err(format!(
"Server name mismatch: expected {}, got {}",
SERVER_DOMAIN,
server_name.as_str()
)
.into());
}
// Receive ZKProof information from prover
let mut buf = Vec::new();
extra_socket.read_to_end(&mut buf).await?;
if buf.is_empty() {
return Err("No ZK proof data received from prover".into());
}
let msg: ZKProofBundle = bincode::deserialize(&buf)
.map_err(|e| format!("Failed to deserialize ZK proof bundle: {}", e))?;
// Verify zk proof
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
let json: Value = serde_json::from_str(PROGRAM_JSON)
.map_err(|e| format!("Failed to parse Noir circuit: {}", e))?;
let bytecode = json["bytecode"]
.as_str()
.ok_or("Bytecode field missing in noir.json")?;
let vk = get_ultra_honk_verification_key(bytecode, false)
.map_err(|e| format!("Failed to get verification key: {}", e))?;
if vk != msg.vk {
return Err("Verification key mismatch between computed and provided by prover".into());
}
let proof = msg.proof.clone();
// Validate proof has enough data.
// The proof should start with the public inputs:
// * We expect at least 3 * 32 bytes for the three date fields (day, month,
// year)
// * and 32*32 bytes for the hash
let min_bytes = (32 + 3) * 32;
if proof.len() < min_bytes {
return Err(format!(
"Proof too short: expected at least {} bytes, got {}",
min_bytes,
proof.len()
)
.into());
}
// Check that the proof date is correctly included in the proof
let proof_date_day: u32 = u32::from_be_bytes(proof[28..32].try_into()?);
let proof_date_month: u32 = u32::from_be_bytes(proof[60..64].try_into()?);
let proof_date_year: i32 = i32::from_be_bytes(proof[92..96].try_into()?);
let proof_date_from_proof =
NaiveDate::from_ymd_opt(proof_date_year, proof_date_month, proof_date_day)
.ok_or("Invalid proof date in proof")?;
let today = Local::now().date_naive();
if (today - proof_date_from_proof).num_days() < 0 {
return Err(format!(
"The proof date can only be today or in the past: provided {}, today {}",
proof_date_from_proof, today
)
.into());
}
// Check that the committed hash in the proof matches the hash from the
// commitment
let committed_hash_in_proof: Vec<u8> = proof
.chunks(32)
.skip(3) // skip the first 3 chunks
.take(32)
.map(|chunk| *chunk.last().unwrap_or(&0))
.collect();
let expected_hash = committed_hash.value.as_bytes().to_vec();
if committed_hash_in_proof != expected_hash {
tracing::error!(
"❌ The hash in the proof does not match the committed hash in MPC-TLS: {} != {}",
hex::encode(&committed_hash_in_proof),
hex::encode(&expected_hash)
);
return Err("Hash in proof does not match committed hash in MPC-TLS".into());
}
tracing::info!(
"✅ The hash in the proof matches the committed hash in MPC-TLS ({})",
hex::encode(&expected_hash)
);
// Finally verify the proof
let is_valid = verify_ultra_honk(msg.proof, msg.vk)
.map_err(|e| format!("ZKProof Verification failed: {}", e))?;
if !is_valid {
tracing::error!("❌ Age verification ZKProof failed to verify");
return Err("Age verification ZKProof failed to verify".into());
}
tracing::info!("✅ Age verification ZKProof successfully verified");
Ok(transcript)
}

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-formats"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -3,19 +3,7 @@
# Ensure the script runs in the folder that contains this script
cd "$(dirname "$0")"
RUNNER_FEATURES=""
EXECUTOR_FEATURES=""
if [ "$1" = "debug" ]; then
RUNNER_FEATURES="--features debug"
EXECUTOR_FEATURES="--no-default-features --features debug"
fi
cargo build --release \
--package tlsn-harness-runner $RUNNER_FEATURES \
--package tlsn-harness-executor $EXECUTOR_FEATURES \
--package tlsn-server-fixture \
--package tlsn-harness-plot
cargo build --release --package tlsn-harness-runner --package tlsn-harness-executor --package tlsn-server-fixture --package tlsn-harness-plot
mkdir -p bin

View File

@@ -1,14 +1,10 @@
[target.wasm32-unknown-unknown]
rustflags = [
"-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-Clink-arg=--shared-memory",
"-C",
"target-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-C",
# 4GB
"-Clink-arg=--max-memory=4294967296",
"-Clink-arg=--import-memory",
"-Clink-arg=--export=__wasm_init_tls",
"-Clink-arg=--export=__tls_size",
"-Clink-arg=--export=__tls_align",
"-Clink-arg=--export=__tls_base",
"link-arg=--max-memory=4294967296",
"--cfg",
'getrandom_backend="wasm_js"',
]

View File

@@ -4,12 +4,6 @@ version = "0.1.0"
edition = "2024"
publish = false
[features]
# Disable tracing events as a workaround for issue 959.
default = ["tracing/release_max_level_off"]
# Used to debug the executor itself.
debug = []
[lib]
name = "harness_executor"
crate-type = ["cdylib", "rlib"]
@@ -34,7 +28,8 @@ tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true, features = ["compat"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
tracing = { workspace = true }
# Disable tracing events as a workaround for issue 959.
tracing = { workspace = true, features = ["release_max_level_off"] }
wasm-bindgen = { workspace = true }
tlsn-wasm = { workspace = true }
js-sys = { workspace = true }

View File

@@ -93,7 +93,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -107,7 +107,7 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -1,8 +1,6 @@
FROM rust AS builder
WORKDIR /usr/src/tlsn
ARG DEBUG=0
RUN \
rustup update; \
apt update && apt install -y clang; \
@@ -12,12 +10,7 @@ RUN \
COPY . .
RUN \
cd crates/harness; \
# Pass `--build-arg DEBUG=1` to `docker build` if you need to debug the harness.
if [ "$DEBUG" = "1" ]; then \
./build.sh debug; \
else \
./build.sh; \
fi
./build.sh;
FROM debian:latest

View File

@@ -7,10 +7,6 @@ publish = false
[lib]
name = "harness_runner"
[features]
# Used to debug the runner itself.
debug = []
[dependencies]
tlsn-harness-core = { workspace = true }
tlsn-server-fixture = { workspace = true }

View File

@@ -1,17 +0,0 @@
#![allow(unused_imports)]
pub use futures::FutureExt;
pub use tracing::{debug, error};
pub use chromiumoxide::{
Browser, Page,
cdp::{
browser_protocol::{
log::{EventEntryAdded, LogEntryLevel},
network::{EnableParams, SetCacheDisabledParams},
page::ReloadParams,
},
js_protocol::runtime::EventExceptionThrown,
},
handler::HandlerConfig,
};

View File

@@ -21,9 +21,6 @@ use harness_core::{
use crate::{Target, network::Namespace, rpc::Rpc};
#[cfg(feature = "debug")]
use crate::debug_prelude::*;
pub struct Executor {
ns: Namespace,
config: ExecutorConfig,
@@ -69,34 +66,20 @@ impl Executor {
Id::One => self.config.network().rpc_1,
};
let mut args = vec![
"ip".into(),
"netns".into(),
"exec".into(),
self.ns.name().into(),
"env".into(),
let process = duct::cmd!(
"sudo",
"ip",
"netns",
"exec",
self.ns.name(),
"env",
format!("CONFIG={}", serde_json::to_string(&self.config)?),
];
if cfg!(feature = "debug") {
let level = &std::env::var("RUST_LOG").unwrap_or("debug".to_string());
args.push("env".into());
args.push(format!("RUST_LOG={}", level));
};
args.push(executor_path.to_str().expect("valid path").into());
let process = duct::cmd("sudo", args);
let process = if !cfg!(feature = "debug") {
process
.stdout_capture()
.stderr_capture()
.unchecked()
.start()?
} else {
process.unchecked().start()?
};
executor_path
)
.stdout_capture()
.stderr_capture()
.unchecked()
.start()?;
let rpc = Rpc::new_native(rpc_addr).await?;
@@ -136,13 +119,10 @@ impl Executor {
"--no-sandbox",
format!("--user-data-dir={tmp}"),
format!("--allowed-ips=10.250.0.1"),
);
let process = if !cfg!(feature = "debug") {
process.stderr_capture().stdout_capture().start()?
} else {
process.start()?
};
)
.stderr_capture()
.stdout_capture()
.start()?;
const TIMEOUT: usize = 10000;
const DELAY: usize = 100;
@@ -191,38 +171,6 @@ impl Executor {
.new_page(&format!("http://{wasm_addr}:{wasm_port}/index.html"))
.await?;
#[cfg(feature = "debug")]
tokio::spawn(register_listeners(page.clone()).await?);
#[cfg(feature = "debug")]
async fn register_listeners(page: Page) -> Result<impl Future<Output = ()>> {
let mut logs = page.event_listener::<EventEntryAdded>().await?.fuse();
let mut exceptions =
page.event_listener::<EventExceptionThrown>().await?.fuse();
Ok(futures::future::join(
async move {
while let Some(event) = logs.next().await {
let entry = &event.entry;
match entry.level {
LogEntryLevel::Error => {
error!("{:?}", entry);
}
_ => {
debug!("{:?}: {}", entry.timestamp, entry.text);
}
}
}
},
async move {
while let Some(event) = exceptions.next().await {
error!("{:?}", event);
}
},
)
.map(|_| ()))
}
page.execute(EnableParams::builder().build()).await?;
page.execute(SetCacheDisabledParams {
cache_disabled: true,

View File

@@ -6,9 +6,6 @@ mod server_fixture;
pub mod wasm_server;
mod ws_proxy;
#[cfg(feature = "debug")]
mod debug_prelude;
use std::time::Duration;
use anyhow::Result;
@@ -27,9 +24,6 @@ use cli::{Cli, Command};
use executor::Executor;
use server_fixture::ServerFixture;
#[cfg(feature = "debug")]
use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
@@ -119,9 +113,6 @@ impl Runner {
}
pub async fn main() -> Result<()> {
#[cfg(feature = "debug")]
tracing_subscriber::fmt::init();
let cli = Cli::parse();
let mut runner = Runner::new(&cli)?;
@@ -236,9 +227,6 @@ pub async fn main() -> Result<()> {
// Wait for the network to stabilize
tokio::time::sleep(Duration::from_millis(100)).await;
#[cfg(feature = "debug")]
debug!("Starting bench in group {:?}", config.group);
let (output, _) = tokio::try_join!(
runner.exec_p.bench(BenchCmd {
config: config.clone(),

View File

@@ -5,7 +5,7 @@ description = "TLSNotary MPC-TLS protocol"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -41,7 +41,6 @@ use tls_core::{
message::{OpaqueMessage, PlainMessage},
},
suites::SupportedCipherSuite,
verify::verify_sig_determine_alg,
};
use tlsn_core::{
connection::{CertBinding, CertBindingV1_2, ServerSignature, TlsVersion, VerifyData},
@@ -328,20 +327,12 @@ impl MpcTlsLeader {
.map(|cert| CertificateDer(cert.0.clone()))
.collect();
let mut sig_msg = Vec::new();
sig_msg.extend_from_slice(&client_random.0);
sig_msg.extend_from_slice(&server_random.0);
sig_msg.extend_from_slice(server_kx_details.kx_params());
let server_signature_alg = verify_sig_determine_alg(
&server_cert_details.cert_chain()[0],
&sig_msg,
server_kx_details.kx_sig(),
)
.expect("only supported signature should have been accepted");
let server_signature = ServerSignature {
alg: server_signature_alg.into(),
scheme: server_kx_details
.kx_sig()
.scheme
.try_into()
.expect("only supported signature scheme should have been accepted"),
sig: server_kx_details.kx_sig().sig.0.clone(),
};

View File

@@ -72,5 +72,4 @@ pub(crate) struct ServerFinishedVd {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct CloseConnection;

View File

@@ -193,7 +193,7 @@ where
};
// Divide by block length and round up.
let block_count = input.len() / 16 + !input.len().is_multiple_of(16) as usize;
let block_count = input.len() / 16 + (input.len() % 16 != 0) as usize;
if block_count > MAX_POWER {
return Err(ErrorRepr::InputLength {
@@ -282,11 +282,11 @@ fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + !aad.len().is_multiple_of(16) as usize;
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + !ciphertext.len().is_multiple_of(16) as usize;
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);

View File

@@ -1,37 +0,0 @@
{
"tax_year": 2024,
"taxpayer": {
"idnr": "12345678901",
"first_name": "Max",
"last_name": "Mustermann",
"date_of_birth": "1985-03-12",
"address": {
"street": "Musterstraße 1",
"postal_code": "10115",
"city": "Berlin"
}
},
"income": {
"employment_income": 54200.00,
"other_income": 1200.00,
"capital_gains": 350.00
},
"deductions": {
"pension_insurance": 4200.00,
"health_insurance": 3600.00,
"donations": 500.00,
"work_related_expenses": 1100.00
},
"assessment": {
"taxable_income": 49200.00,
"income_tax": 9156.00,
"solidarity_surcharge": 503.58,
"total_tax": 9659.58,
"prepaid_tax": 9500.00,
"refund": 159.58
},
"submission": {
"submitted_at": "2025-03-01T14:22:30Z",
"submitted_by": "ElsterOnline-Portal"
}
}

View File

@@ -47,7 +47,6 @@ fn app(state: AppState) -> Router {
.route("/formats/json", get(json))
.route("/formats/html", get(html))
.route("/protected", get(protected_route))
.route("/elster", get(elster_route))
.layer(TraceLayer::new_for_http())
.with_state(Arc::new(Mutex::new(state)))
}
@@ -197,12 +196,6 @@ async fn protected_route(_: AuthenticatedUser) -> Result<Json<Value>, StatusCode
get_json_value(include_str!("data/protected_data.json"))
}
async fn elster_route(_: AuthenticatedUser) -> Result<Json<Value>, StatusCode> {
info!("Handling /elster");
get_json_value(include_str!("data/elster.json"))
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -5,7 +5,7 @@ description = "A TLS backend trait for TLSNotary"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "An async TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "async"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -5,7 +5,7 @@ description = "A TLS client for TLSNotary"
keywords = ["tls", "mpc", "2pc", "client", "sync"]
categories = ["cryptography"]
license = "Apache-2.0 OR ISC OR MIT"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
autobenches = false

View File

@@ -5,7 +5,7 @@ description = "Cryptographic operations for the TLSNotary TLS client"
keywords = ["tls", "mpc", "2pc"]
categories = ["cryptography"]
license = "Apache-2.0 OR ISC OR MIT"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
[lints]

View File

@@ -465,81 +465,19 @@ fn convert_scheme(scheme: SignatureScheme) -> Result<SignatureAlgorithms, Error>
}
}
/// Signature algorithm.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(non_camel_case_types)]
pub enum SignatureAlgorithm {
ECDSA_NISTP256_SHA256,
ECDSA_NISTP256_SHA384,
ECDSA_NISTP384_SHA256,
ECDSA_NISTP384_SHA384,
ED25519,
RSA_PKCS1_2048_8192_SHA256,
RSA_PKCS1_2048_8192_SHA384,
RSA_PKCS1_2048_8192_SHA512,
RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
}
impl SignatureAlgorithm {
pub fn from_alg(alg: &dyn pki_types::SignatureVerificationAlgorithm) -> Self {
let id = alg.signature_alg_id();
if id == webpki::ring::ECDSA_P256_SHA256.signature_alg_id() {
SignatureAlgorithm::ECDSA_NISTP256_SHA256
} else if id == webpki::ring::ECDSA_P256_SHA384.signature_alg_id() {
SignatureAlgorithm::ECDSA_NISTP256_SHA384
} else if id == webpki::ring::ECDSA_P384_SHA256.signature_alg_id() {
SignatureAlgorithm::ECDSA_NISTP384_SHA256
} else if id == webpki::ring::ECDSA_P384_SHA384.signature_alg_id() {
SignatureAlgorithm::ECDSA_NISTP384_SHA384
} else if id == webpki::ring::ED25519.signature_alg_id() {
SignatureAlgorithm::ED25519
} else if id == webpki::ring::RSA_PKCS1_2048_8192_SHA256.signature_alg_id() {
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA256
} else if id == webpki::ring::RSA_PKCS1_2048_8192_SHA384.signature_alg_id() {
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA384
} else if id == webpki::ring::RSA_PKCS1_2048_8192_SHA512.signature_alg_id() {
SignatureAlgorithm::RSA_PKCS1_2048_8192_SHA512
} else if id == webpki::ring::RSA_PSS_2048_8192_SHA256_LEGACY_KEY.signature_alg_id() {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA256_LEGACY_KEY
} else if id == webpki::ring::RSA_PSS_2048_8192_SHA384_LEGACY_KEY.signature_alg_id() {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA384_LEGACY_KEY
} else if id == webpki::ring::RSA_PSS_2048_8192_SHA512_LEGACY_KEY.signature_alg_id() {
SignatureAlgorithm::RSA_PSS_2048_8192_SHA512_LEGACY_KEY
} else {
unreachable!()
}
}
}
/// Verify the signature and return the algorithm which passed verification.
pub fn verify_sig_determine_alg(
cert: &Certificate,
message: &[u8],
dss: &DigitallySignedStruct,
) -> Result<SignatureAlgorithm, Error> {
let cert = pki_types::CertificateDer::from(cert.0.as_slice());
let cert = webpki::EndEntityCert::try_from(&cert).map_err(pki_error)?;
verify_sig_using_any_alg(&cert, convert_scheme(dss.scheme)?, message, &dss.sig.0)
.map_err(pki_error)
}
fn verify_sig_using_any_alg(
cert: &webpki::EndEntityCert,
algs: SignatureAlgorithms,
message: &[u8],
sig: &[u8],
) -> Result<SignatureAlgorithm, webpki::Error> {
) -> Result<(), webpki::Error> {
// TLS doesn't itself give us enough info to map to a single
// webpki::SignatureAlgorithm. Therefore, convert_algs maps to several and
// we try them all.
for alg in algs {
match cert.verify_signature(*alg, message, sig) {
Ok(_) => return Ok(SignatureAlgorithm::from_alg(*alg)),
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKeyContext(_)) => continue,
Err(e) => return Err(e),
res => return res,
}
}

View File

@@ -4,7 +4,7 @@ authors = ["TLSNotary Team"]
keywords = ["tls", "mpc", "2pc", "prover"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2024"
[lints]
@@ -31,7 +31,6 @@ web-spawn = { workspace = true, optional = true }
mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-garble = { workspace = true }
mpz-garble-core = { workspace = true }
mpz-hash = { workspace = true }
@@ -42,6 +41,7 @@ mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
aes = { workspace = true }
cipher-crypto = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
@@ -60,9 +60,9 @@ rangeset = { workspace = true }
webpki-roots = { workspace = true }
[dev-dependencies]
mpz-ideal-vm = { workspace = true }
lipsum = { workspace = true }
sha2 = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
tokio = { workspace = true, features = ["full"] }
@@ -70,3 +70,5 @@ tokio-util = { workspace = true, features = ["compat"] }
hyper = { workspace = true, features = ["client"] }
http-body-util = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tlsn-core = { workspace = true, features = ["fixtures"] }
mpz-ot = { workspace = true, features = ["ideal"] }

941
crates/tlsn/src/commit.rs Normal file
View File

@@ -0,0 +1,941 @@
//! Proving flow for prover and verifier. Handles necessary parts for commitment
//! creation and selective disclosure.
//!
//! - transcript reference storage in [`transcript`]
//! - transcript's plaintext authentication in [`auth`]
//! - decoding of transcript in [`decode`]
//! - encoding commitments in [`encoding`]
//! - hash commitments in [`hash`]
mod auth;
mod decode;
mod encoding;
mod hash;
mod transcript;
pub(crate) use encoding::ENCODING_SIZE;
pub(crate) use transcript::TranscriptRefs;
use encoding::{EncodingError, Encodings};
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_garble_core::Delta;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::VmError;
use rand::Rng;
use rangeset::RangeSet;
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
ProveConfig, ProveRequest, ProverOutput, VerifierOutput,
connection::{HandshakeData, HandshakeVerificationError, ServerName},
hash::{HashAlgId, TypedHash},
transcript::{
Direction, PartialTranscript, TlsTranscript, Transcript, TranscriptCommitment,
TranscriptSecret,
encoding::{EncoderSecret, EncodingCommitment, EncodingTree},
},
webpki::{RootCertStore, ServerCertVerifier, ServerCertVerifierError},
};
use crate::{
commit::{
auth::{AuthError, Authenticator},
decode::{DecodeError, check_transcript_length, decode_transcript, verify_transcript},
encoding::{EncodingCreator, EncodingMemory, EncodingVm},
hash::{HashCommitError, PlaintextHasher},
},
zk_aes_ctr::ZkAesCtr,
};
/// Internal proving state used by [`Prover`](crate::prover::Prover) and
/// [`Verifier`](crate::verifier::Verifier).
///
/// Manages the prover and verifier flow. Bundles plaintext authentication,
/// creation of commitments, selective disclosure and verification of
/// servername.
pub(crate) struct ProvingState<'a> {
    /// Transcript data to be revealed, if any.
    partial: Option<PartialTranscript>,
    /// Server name and handshake data supplied by the prover (verifier side
    /// only; `None` on the prover side).
    verified_server_name: Option<ServerName>,
    /// Authenticates the needed plaintext ranges of the transcript.
    authenticator: Authenticator,
    /// Creates encoding commitments.
    encoding: EncodingCreator,
    /// Whether the encoding protocol has already been executed.
    encodings_transferred: bool,
    /// Creates plaintext hash commitments.
    hasher: PlaintextHasher,
    /// The TLS transcript.
    tls_transcript: &'a TlsTranscript,
    /// The transcript references.
    transcript_refs: &'a mut TranscriptRefs,
}
impl<'a> ProvingState<'a> {
/// Creates a new proving state for the prover.
///
/// # Arguments
///
/// * `config` - The config for proving.
/// * `tls_transcript` - The TLS transcript.
/// * `transcript` - The transcript.
/// * `transcript_refs` - The transcript references.
/// * `encodings_transferred` - If the encoding protocol has already been
///   executed.
pub(crate) fn for_prover(
    config: ProveConfig,
    tls_transcript: &'a TlsTranscript,
    transcript: &Transcript,
    transcript_refs: &'a mut TranscriptRefs,
    encodings_transferred: bool,
) -> Self {
    // Pull the requested commitment ranges out of the config, if present.
    let (encoding_hash_id, encoding_idxs, hash_idxs) = match config.transcript_commit() {
        Some(commit_config) => (
            Some(*commit_config.encoding_hash_alg()),
            commit_config
                .iter_encoding()
                .map(|(dir, idx)| (*dir, idx.clone()))
                .collect::<Vec<(Direction, RangeSet<usize>)>>(),
            commit_config
                .iter_hash()
                .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
                .collect::<Vec<(Direction, RangeSet<usize>, HashAlgId)>>(),
        ),
        None => (None, Vec::new(), Vec::new()),
    };

    // Build the partial transcript covering the ranges the prover reveals.
    let partial = config.reveal().map(|(sent_reveal, recv_reveal)| {
        transcript.to_partial(sent_reveal.clone(), recv_reveal.clone())
    });

    let authenticator =
        Authenticator::new(encoding_idxs.iter(), hash_idxs.iter(), partial.as_ref());

    Self {
        partial,
        server_identity: None,
        verified_server_name: None,
        authenticator,
        encoding: EncodingCreator::new(encoding_hash_id, encoding_idxs),
        encodings_transferred,
        hasher: PlaintextHasher::new(hash_idxs.iter()),
        tls_transcript,
        transcript_refs,
    }
}
/// Creates a new proving state for the verifier.
///
/// # Arguments
///
/// * `payload` - The prove payload.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
/// * `verified_server_name` - The verified server name.
/// * `encodings_transferred` - If the encoding protocol has already been
///   executed.
pub(crate) fn for_verifier(
    payload: ProveRequest,
    transcript: &'a TlsTranscript,
    transcript_refs: &'a mut TranscriptRefs,
    verified_server_name: Option<ServerName>,
    encodings_transferred: bool,
) -> Self {
    // Collect the commitment ranges requested by the prover, if any.
    let (encoding_idxs, hash_idxs): (
        Vec<(Direction, RangeSet<usize>)>,
        Vec<(Direction, RangeSet<usize>, HashAlgId)>,
    ) = match payload.transcript_commit.as_ref() {
        Some(commit_config) => (
            commit_config.iter_encoding().cloned().collect(),
            commit_config.iter_hash().cloned().collect(),
        ),
        None => (Vec::new(), Vec::new()),
    };

    let authenticator = Authenticator::new(
        encoding_idxs.iter(),
        hash_idxs.iter(),
        payload.transcript.as_ref(),
    );

    Self {
        partial: payload.transcript,
        server_identity: payload.handshake,
        verified_server_name,
        authenticator,
        encoding: EncodingCreator::new(None, encoding_idxs),
        encodings_transferred,
        hasher: PlaintextHasher::new(hash_idxs.iter()),
        tls_transcript: transcript,
        transcript_refs,
    }
}
/// Proves the transcript and generates the prover output.
///
/// Returns the output for the prover and whether the encoding protocol has
/// been executed.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `ctx` - The thread context.
/// * `zk_aes_sent` - ZkAes for the sent traffic.
/// * `zk_aes_recv` - ZkAes for the received traffic.
/// * `keys` - The TLS session keys.
///
/// # Errors
///
/// Returns a [`CommitError`] if plaintext authentication, decoding, or
/// commitment creation fails.
pub(crate) async fn prove(
    mut self,
    vm: &mut (impl EncodingVm<Binary> + Send),
    ctx: &mut Context,
    zk_aes_sent: &mut ZkAesCtr,
    zk_aes_recv: &mut ZkAesCtr,
    keys: SessionKeys,
) -> Result<(ProverOutput, bool), CommitError> {
    // Authenticates only necessary parts of the transcript. Proof is not needed on
    // the prover side, so the returned proofs are discarded.
    let _ = self.authenticator.auth_sent(
        vm,
        zk_aes_sent,
        self.tls_transcript,
        self.transcript_refs,
    )?;
    let _ = self.authenticator.auth_recv(
        vm,
        zk_aes_recv,
        self.tls_transcript,
        self.transcript_refs,
    )?;
    // First VM execution: runs the authentication operations queued above.
    vm.execute_all(ctx).await?;
    // Decodes the transcript parts that should be disclosed.
    if self.has_decoding_ranges() {
        decode_transcript(
            vm,
            keys.server_write_key,
            keys.server_write_iv,
            self.authenticator.decoding(),
            self.transcript_refs,
        )?;
    }
    let mut output = ProverOutput {
        transcript_commitments: Vec::new(),
        transcript_secrets: Vec::new(),
    };
    // Creates encoding commitments if necessary.
    if self.has_encoding_ranges() {
        let (commitment, secret) = self.receive_encodings(vm, ctx).await?;
        output
            .transcript_commitments
            .push(TranscriptCommitment::Encoding(commitment));
        output
            .transcript_secrets
            .push(TranscriptSecret::Encoding(secret));
    }
    // Creates hash commitments if necessary.
    let hash_output = if self.has_hash_ranges() {
        Some(self.hasher.prove(vm, self.transcript_refs)?)
    } else {
        None
    };
    // Second VM execution: runs decoding and hash-commitment operations.
    vm.execute_all(ctx).await?;
    // Collect the hash commitments once the VM has executed and pair each
    // with its secret in the output.
    if let Some((commitments, secrets)) = hash_output {
        let commitments = commitments.try_recv()?;
        for (hash, secret) in commitments.into_iter().zip(secrets) {
            output
                .transcript_commitments
                .push(TranscriptCommitment::Hash(hash));
            output
                .transcript_secrets
                .push(TranscriptSecret::Hash(secret));
        }
    }
    Ok((output, self.encodings_transferred))
}
/// Verifies the transcript and generates the verifier output.
///
/// Returns the output for the verifier and if the encoding protocol has
/// been executed.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `ctx` - The thread context.
/// * `zk_aes_sent` - ZkAes for the sent traffic.
/// * `zk_aes_recv` - ZkAes for the received traffic.
/// * `keys` - The TLS session keys.
/// * `delta` - The delta.
/// * `certs` - The certificate chain.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify(
    mut self,
    vm: &mut (impl EncodingVm<Binary> + Send),
    ctx: &mut Context,
    zk_aes_sent: &mut ZkAesCtr,
    zk_aes_recv: &mut ZkAesCtr,
    keys: SessionKeys,
    delta: Delta,
    certs: Option<&RootCertStore>,
) -> Result<(VerifierOutput, bool), CommitError> {
    // Check certificate chain and server name first, so we fail fast
    // before doing any expensive ZK work.
    self.verify_server_identity(certs)?;
    // Authenticate only necessary parts of the transcript.
    let sent_proof = self.authenticator.auth_sent(
        vm,
        zk_aes_sent,
        self.tls_transcript,
        self.transcript_refs,
    )?;
    let recv_proof = self.authenticator.auth_recv(
        vm,
        zk_aes_recv,
        self.tls_transcript,
        self.transcript_refs,
    )?;
    // First VM execution makes the ciphertext decodings backing the
    // proofs available.
    vm.execute_all(ctx).await?;
    // Verify the plaintext proofs.
    sent_proof.verify()?;
    recv_proof.verify()?;
    // Decodes the transcript parts that should be disclosed and checks the
    // transcript length.
    if self.has_decoding_ranges() {
        check_transcript_length(self.partial.as_ref(), self.tls_transcript)?;
        decode_transcript(
            vm,
            keys.server_write_key,
            keys.server_write_iv,
            self.authenticator.decoding(),
            self.transcript_refs,
        )?;
    }
    let mut output = VerifierOutput {
        server_name: None,
        transcript: None,
        transcript_commitments: Vec::new(),
    };
    // Creates encoding commitments if necessary.
    // NOTE: this is a one-shot operation (see `transfer_encodings`).
    if self.has_encoding_ranges() {
        let commitment = self.transfer_encodings(vm, ctx, delta).await?;
        output
            .transcript_commitments
            .push(TranscriptCommitment::Encoding(commitment));
    }
    // Create hash commitments if necessary.
    let hash_output = if self.has_hash_ranges() {
        Some(self.hasher.verify(vm, self.transcript_refs)?)
    } else {
        None
    };
    // Second VM execution flushes the decodings and hash commitments
    // requested above.
    vm.execute_all(ctx).await?;
    if let Some(commitments) = hash_output {
        let commitments = commitments.try_recv()?;
        for hash in commitments.into_iter() {
            output
                .transcript_commitments
                .push(TranscriptCommitment::Hash(hash));
        }
    }
    // Verify revealed data.
    if self.has_decoding_ranges() {
        verify_transcript(
            vm,
            keys.server_write_key,
            keys.server_write_iv,
            self.authenticator.decoding(),
            self.partial.as_ref(),
            self.transcript_refs,
            self.tls_transcript,
        )?;
    }
    output.transcript = self.partial;
    output.server_name = self.verified_server_name;
    // The boolean signals whether the one-shot encoding transfer has been
    // executed during this session.
    Ok((output, self.encodings_transferred))
}
/// Checks the server identity.
///
/// Skipped when no identity proof was requested or the name has already
/// been verified.
///
/// # Arguments
///
/// * `root_store` - Contains root certificates.
fn verify_server_identity(
    &mut self,
    root_store: Option<&RootCertStore>,
) -> Result<(), CommitError> {
    if self.verified_server_name.is_some() || !self.has_server_identity() {
        return Ok(());
    }
    let (server_name, handshake_data) = self
        .server_identity
        .as_ref()
        .ok_or(CommitError(ErrorRepr::MissingCertChain))?;
    // Fall back to the Mozilla root store when none was supplied.
    let verifier = match root_store {
        Some(store) => ServerCertVerifier::new(store)?,
        None => ServerCertVerifier::mozilla(),
    };
    handshake_data.verify(
        &verifier,
        self.tls_transcript.time(),
        self.tls_transcript.server_ephemeral_key(),
        server_name,
    )?;
    self.verified_server_name = Some(server_name.clone());
    Ok(())
}
/// Compute the encoding adjustments to send to the prover.
///
/// This is a one-shot operation and fails on a second invocation.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `ctx` - The thread context.
/// * `delta` - The Delta.
async fn transfer_encodings(
    &mut self,
    vm: &mut (dyn EncodingMemory<Binary> + Send),
    ctx: &mut Context,
    delta: Delta,
) -> Result<EncodingCommitment, CommitError> {
    // The encoding protocol may only run once per session.
    if self.encodings_transferred {
        return Err(CommitError(ErrorRepr::EncodingOnlyOnce));
    }
    self.encodings_transferred = true;
    // Fresh random seed, bound to the global delta.
    let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
    let encodings = self.encoding.transfer(vm, secret, self.transcript_refs)?;
    // Raise the IO frame limit so the (potentially large) encodings fit.
    let frame_limit = ctx.io().limit().saturating_add(self.encoding_size());
    ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
    // The prover commits to the tree root before learning the secret.
    let root: TypedHash = ctx.io_mut().expect_next().await?;
    ctx.io_mut().send(secret).await?;
    Ok(EncodingCommitment { root, secret })
}
/// Receive the encoding adjustments from the verifier and adjust the prover
/// encodings.
///
/// This is a one-shot operation and fails on a second invocation.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `ctx` - The thread context.
async fn receive_encodings(
    &mut self,
    vm: &mut (dyn EncodingMemory<Binary> + Send),
    ctx: &mut Context,
) -> Result<(EncodingCommitment, EncodingTree), CommitError> {
    // The encoding protocol may only run once per session.
    if self.encodings_transferred {
        return Err(CommitError(ErrorRepr::EncodingOnlyOnce));
    }
    self.encodings_transferred = true;
    // Raise the IO frame limit so the (potentially large) encodings fit.
    let frame_limit = ctx.io().limit().saturating_add(self.encoding_size());
    let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
    let (root, tree) = self.encoding.receive(vm, encodings, self.transcript_refs)?;
    // Commit to the tree root before learning the verifier's secret.
    ctx.io_mut().send(root).await?;
    let secret: EncoderSecret = ctx.io_mut().expect_next().await?;
    Ok((EncodingCommitment { root, secret }, tree))
}
/// Returns the size of the encodings in bytes.
fn encoding_size(&self) -> usize {
    let (sent, recv) = self.authenticator.encoding();
    (sent.len() + recv.len()) * ENCODING_SIZE
}
/// Returns if there are encoding ranges present.
fn has_encoding_ranges(&self) -> bool {
    let (sent, recv) = self.authenticator.encoding();
    !(sent.is_empty() && recv.is_empty())
}
/// Returns if there are hash ranges present.
fn has_hash_ranges(&self) -> bool {
    let (sent, recv) = self.authenticator.hash();
    !(sent.is_empty() && recv.is_empty())
}
/// Returns if there are decoding ranges present.
fn has_decoding_ranges(&self) -> bool {
    let (sent, recv) = self.authenticator.decoding();
    !(sent.is_empty() && recv.is_empty())
}
/// Returns if there is a server identity present.
///
/// This only checks for presence; the identity may not have been verified
/// yet (see `verify_server_identity`).
fn has_server_identity(&self) -> bool {
    self.server_identity.is_some()
}
}
/// Error for commitments.
///
/// Wraps the internal [`ErrorRepr`] transparently, so the inner error's
/// message is displayed directly.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct CommitError(#[from] ErrorRepr);
/// Internal error representation for [`CommitError`].
//
// No enum-level `#[error]` attribute: every variant carries its own message,
// so a top-level format would be dead (thiserror only uses it as a fallback
// for variants without their own), and omitting it matches the sibling
// `ErrorRepr` in the auth module.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    #[error("VM error: {0}")]
    Vm(VmError),
    #[error("IO error: {0}")]
    Io(std::io::Error),
    #[error("hash commit error: {0}")]
    Hash(HashCommitError),
    #[error("encoding error: {0}")]
    Encoding(EncodingError),
    #[error("encoding commitments can be created only once")]
    EncodingOnlyOnce,
    #[error("decode error: {0}")]
    Decode(DecodeError),
    #[error("authentication error: {0}")]
    Auth(AuthError),
    #[error("cert chain missing for verifying server identity")]
    MissingCertChain,
    #[error("failed to verify server name")]
    VerifyServerName(HandshakeVerificationError),
    #[error("cert verifier error: {0}")]
    CertVerifier(ServerCertVerifierError),
}
// Conversions into `CommitError`. All delegate to the internal
// representation; uses `Self` and a uniform parameter name throughout
// (previously mixed `Self`/`CommitError` and `err`/`value`).
impl From<VmError> for CommitError {
    fn from(value: VmError) -> Self {
        Self(ErrorRepr::Vm(value))
    }
}

impl From<std::io::Error> for CommitError {
    fn from(value: std::io::Error) -> Self {
        Self(ErrorRepr::Io(value))
    }
}

impl From<AuthError> for CommitError {
    fn from(value: AuthError) -> Self {
        Self(ErrorRepr::Auth(value))
    }
}

impl From<EncodingError> for CommitError {
    fn from(value: EncodingError) -> Self {
        Self(ErrorRepr::Encoding(value))
    }
}

impl From<DecodeError> for CommitError {
    fn from(value: DecodeError) -> Self {
        Self(ErrorRepr::Decode(value))
    }
}

impl From<HashCommitError> for CommitError {
    fn from(value: HashCommitError) -> Self {
        Self(ErrorRepr::Hash(value))
    }
}

impl From<ServerCertVerifierError> for CommitError {
    fn from(value: ServerCertVerifierError) -> Self {
        Self(ErrorRepr::CertVerifier(value))
    }
}

impl From<HandshakeVerificationError> for CommitError {
    fn from(value: HandshakeVerificationError) -> Self {
        Self(ErrorRepr::VerifyServerName(value))
    }
}
#[cfg(test)]
mod tests {
    use lipsum::{LIBER_PRIMUS, lipsum};
    use mpc_tls::SessionKeys;
    use mpz_common::context::test_st_context;
    use mpz_garble_core::Delta;
    use mpz_memory_core::{
        Array, MemoryExt, ViewExt,
        binary::{Binary, U8},
    };
    use mpz_ot::ideal::rcot::ideal_rcot;
    use mpz_vm_core::Execute;
    use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use rangeset::{RangeSet, UnionMut};
    use rstest::{fixture, rstest};
    use tlsn_core::{
        ProveConfig, ProveRequest,
        connection::{HandshakeData, ServerName},
        fixtures::transcript::{IV, KEY},
        hash::HashAlgId,
        transcript::{
            ContentType, Direction, TlsTranscript, TranscriptCommitConfig, TranscriptCommitment,
            TranscriptCommitmentKind, TranscriptSecret,
        },
    };

    use crate::{
        Role,
        commit::{ProvingState, encoding::EncodingVm, transcript::TranscriptRefs},
        zk_aes_ctr::ZkAesCtr,
    };

    /// Runs the full prove/verify flow and checks that both parties produce
    /// the expected commitments, secrets and disclosed data.
    //
    // Uses the multi-threaded flavor, matching the runtime `#[tokio::main]`
    // (previously used here) defaults to.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_commit(
        tls_transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
        prove_config: ProveConfig,
        prove_request: ProveRequest,
    ) {
        let (mut ctx_p, mut ctx_v) = test_st_context(8);
        let mut rng = StdRng::seed_from_u64(0);
        let delta = Delta::random(&mut rng);
        let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());

        let mut prover = Prover::new(ProverConfig::default(), ot_recv);
        let mut verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);

        let mut refs_prover = transcript_refs.clone();
        let mut refs_verifier = transcript_refs;

        let keys_prover = set_keys(&mut prover, KEY, IV, Role::Prover);
        // The MAC key is not needed for this test.
        let mac_key_prover = prover.alloc().unwrap();
        prover.mark_public(mac_key_prover).unwrap();
        prover.assign(mac_key_prover, [0_u8; 16]).unwrap();
        prover.commit(mac_key_prover).unwrap();

        let session_keys_prover = SessionKeys {
            client_write_key: keys_prover.0,
            client_write_iv: keys_prover.1,
            server_write_key: keys_prover.0,
            server_write_iv: keys_prover.1,
            server_write_mac_key: mac_key_prover,
        };

        let keys_verifier = set_keys(&mut verifier, KEY, IV, Role::Verifier);
        // The MAC key is not needed for this test.
        let mac_key_verifier = verifier.alloc().unwrap();
        verifier.mark_public(mac_key_verifier).unwrap();
        verifier.assign(mac_key_verifier, [0_u8; 16]).unwrap();
        verifier.commit(mac_key_verifier).unwrap();

        let session_keys_verifier = SessionKeys {
            client_write_key: keys_verifier.0,
            client_write_iv: keys_verifier.1,
            server_write_key: keys_verifier.0,
            server_write_iv: keys_verifier.1,
            server_write_mac_key: mac_key_verifier,
        };

        let transcript = tls_transcript.to_transcript().unwrap();
        let prover_state = ProvingState::for_prover(
            prove_config,
            &tls_transcript,
            &transcript,
            &mut refs_prover,
            false,
        );

        let mut zk_prover_sent = ZkAesCtr::new(Role::Prover);
        zk_prover_sent.set_key(keys_prover.0, keys_prover.1);
        zk_prover_sent.alloc(&mut prover, SENT_LEN).unwrap();

        let mut zk_prover_recv = ZkAesCtr::new(Role::Prover);
        zk_prover_recv.set_key(keys_prover.0, keys_prover.1);
        zk_prover_recv.alloc(&mut prover, RECV_LEN).unwrap();

        let verifier_state = ProvingState::for_verifier(
            prove_request,
            &tls_transcript,
            &mut refs_verifier,
            None,
            false,
        );

        let mut zk_verifier_sent = ZkAesCtr::new(Role::Verifier);
        zk_verifier_sent.set_key(keys_verifier.0, keys_verifier.1);
        zk_verifier_sent.alloc(&mut verifier, SENT_LEN).unwrap();

        let mut zk_verifier_recv = ZkAesCtr::new(Role::Verifier);
        zk_verifier_recv.set_key(keys_verifier.0, keys_verifier.1);
        zk_verifier_recv.alloc(&mut verifier, RECV_LEN).unwrap();

        tokio::try_join!(
            prover.execute_all(&mut ctx_p),
            verifier.execute_all(&mut ctx_v)
        )
        .unwrap();

        let ((prover_output, _), (verifier_output, _)) = tokio::try_join!(
            prover_state.prove(
                &mut prover,
                &mut ctx_p,
                &mut zk_prover_sent,
                &mut zk_prover_recv,
                session_keys_prover
            ),
            verifier_state.verify(
                &mut verifier,
                &mut ctx_v,
                &mut zk_verifier_sent,
                &mut zk_verifier_recv,
                session_keys_verifier,
                delta,
                None
            )
        )
        .unwrap();

        let prover_commitments = prover_output.transcript_commitments;
        let prover_secrets = prover_output.transcript_secrets;
        let verifier_commitments = verifier_output.transcript_commitments;
        let verifier_server = verifier_output.server_name;
        let partial = verifier_output.transcript;

        // These `.any(..)` results were previously computed and silently
        // discarded, so the test asserted nothing about the outputs.
        assert!(
            prover_commitments
                .iter()
                .any(|commitment| matches!(commitment, TranscriptCommitment::Encoding(_)))
        );
        assert!(
            prover_commitments
                .iter()
                .any(|commitment| matches!(commitment, TranscriptCommitment::Hash(_)))
        );
        assert!(
            prover_secrets
                .iter()
                .any(|secret| matches!(secret, TranscriptSecret::Encoding(_)))
        );
        assert!(
            prover_secrets
                .iter()
                .any(|secret| matches!(secret, TranscriptSecret::Hash(_)))
        );
        assert!(
            verifier_commitments
                .iter()
                .any(|commitment| matches!(commitment, TranscriptCommitment::Encoding(_)))
        );
        assert!(
            verifier_commitments
                .iter()
                .any(|commitment| matches!(commitment, TranscriptCommitment::Hash(_)))
        );
        assert!(verifier_server.is_some());
        assert!(partial.is_some());
    }

    /// Builds a `ProveConfig` revealing the `decoding` ranges and committing
    /// to the `encoding` and `hash` ranges.
    #[fixture]
    fn prove_config(
        decoding: (RangeSet<usize>, RangeSet<usize>),
        encoding: (RangeSet<usize>, RangeSet<usize>),
        hash: (RangeSet<usize>, RangeSet<usize>),
        tls_transcript: TlsTranscript,
    ) -> ProveConfig {
        let transcript = tls_transcript.to_transcript().unwrap();
        let mut builder = ProveConfig::builder(&transcript);
        builder.reveal_sent(&decoding.0).unwrap();
        builder.reveal_recv(&decoding.1).unwrap();
        builder.server_identity();

        let mut transcript_commit = TranscriptCommitConfig::builder(&transcript);
        transcript_commit.encoding_hash_alg(HashAlgId::SHA256);
        transcript_commit
            .commit_with_kind(
                &encoding.0,
                Direction::Sent,
                TranscriptCommitmentKind::Encoding,
            )
            .unwrap();
        transcript_commit
            .commit_with_kind(
                &encoding.1,
                Direction::Received,
                TranscriptCommitmentKind::Encoding,
            )
            .unwrap();
        transcript_commit
            .commit_with_kind(
                &hash.0,
                Direction::Sent,
                TranscriptCommitmentKind::Hash {
                    alg: HashAlgId::SHA256,
                },
            )
            .unwrap();
        transcript_commit
            .commit_with_kind(
                &hash.1,
                Direction::Received,
                TranscriptCommitmentKind::Hash {
                    alg: HashAlgId::SHA256,
                },
            )
            .unwrap();

        let transcript_commit = transcript_commit.build().unwrap();
        builder.transcript_commit(transcript_commit);
        builder.build().unwrap()
    }

    /// Builds the verifier-side `ProveRequest` matching `prove_config`.
    //
    // Renamed from `prove_payload` to follow the `ProvePayload` ->
    // `ProveRequest` rename.
    #[fixture]
    fn prove_request(
        prove_config: ProveConfig,
        tls_transcript: TlsTranscript,
        decoding: (RangeSet<usize>, RangeSet<usize>),
    ) -> ProveRequest {
        let (sent, recv) = decoding;
        let handshake = HandshakeData::new(&tls_transcript);
        let server_name = ServerName::Dns("tlsnotary.org".try_into().unwrap());
        let partial = tls_transcript
            .to_transcript()
            .unwrap()
            .to_partial(sent, recv);
        ProveRequest::new(&prove_config, Some(partial), Some((server_name, handshake)))
    }

    /// Allocates and assigns the AES key and IV in the VM for the given role.
    fn set_keys(
        vm: &mut dyn EncodingVm<Binary>,
        key_value: [u8; 16],
        iv_value: [u8; 4],
        role: Role,
    ) -> (Array<U8, 16>, Array<U8, 4>) {
        let key: Array<U8, 16> = vm.alloc().unwrap();
        let iv: Array<U8, 4> = vm.alloc().unwrap();
        if let Role::Prover = role {
            vm.mark_private(key).unwrap();
            vm.mark_private(iv).unwrap();
            vm.assign(key, key_value).unwrap();
            vm.assign(iv, iv_value).unwrap();
        } else {
            vm.mark_blind(key).unwrap();
            vm.mark_blind(iv).unwrap();
        }
        vm.commit(key).unwrap();
        vm.commit(iv).unwrap();
        (key, iv)
    }

    /// Ranges which get revealed to the verifier.
    #[fixture]
    fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
        let mut sent = RangeSet::default();
        let mut recv = RangeSet::default();
        sent.union_mut(&(600..1100));
        sent.union_mut(&(3450..3451));
        recv.union_mut(&(200..405));
        recv.union_mut(&(3182..4190));
        (sent, recv)
    }

    /// Ranges for encoding commitments.
    #[fixture]
    fn encoding() -> (RangeSet<usize>, RangeSet<usize>) {
        let mut sent = RangeSet::default();
        let mut recv = RangeSet::default();
        sent.union_mut(&(804..2100));
        sent.union_mut(&(3000..3910));
        recv.union_mut(&(0..1432));
        recv.union_mut(&(2000..2100));
        (sent, recv)
    }

    /// Ranges for hash commitments.
    #[fixture]
    fn hash() -> (RangeSet<usize>, RangeSet<usize>) {
        let mut sent = RangeSet::default();
        let mut recv = RangeSet::default();
        sent.union_mut(&(100..2100));
        recv.union_mut(&(720..930));
        (sent, recv)
    }

    /// TLS transcript fixture with deterministic content.
    #[fixture]
    fn tls_transcript() -> TlsTranscript {
        let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
        let mut recv = lipsum(RECV_LEN).into_bytes();
        recv.truncate(RECV_LEN);
        tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
    }

    /// Empty transcript references sized to the application data.
    #[fixture]
    fn transcript_refs(tls_transcript: TlsTranscript) -> TranscriptRefs {
        let sent_len = tls_transcript
            .sent()
            .iter()
            .filter_map(|record| {
                if matches!(record.typ, ContentType::ApplicationData) {
                    Some(record.ciphertext.len())
                } else {
                    None
                }
            })
            .sum();
        let recv_len = tls_transcript
            .recv()
            .iter()
            .filter_map(|record| {
                if matches!(record.typ, ContentType::ApplicationData) {
                    Some(record.ciphertext.len())
                } else {
                    None
                }
            })
            .sum();
        TranscriptRefs::new(sent_len, recv_len)
    }

    const SENT_LEN: usize = 4096;
    const RECV_LEN: usize = 8192;
}

View File

@@ -0,0 +1,708 @@
//! Authentication of the transcript plaintext and creation of the transcript
//! references.
use std::ops::Range;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeError, DecodeFutureTyped, MemoryExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Disjoint, RangeSet, Union, UnionMut};
use tlsn_core::{
hash::HashAlgId,
transcript::{ContentType, Direction, PartialTranscript, Record, TlsTranscript},
};
use crate::{
Role,
commit::transcript::TranscriptRefs,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Transcript Authenticator.
///
/// Tracks which byte ranges of the transcript are required by the different
/// commitment and disclosure mechanisms.
pub(crate) struct Authenticator {
    // Ranges committed to with encoding commitments.
    encoding: Index,
    // Ranges committed to with hash commitments.
    hash: Index,
    // Ranges which get decoded (disclosed to the verifier).
    decoding: Index,
    // Union of the above, i.e. everything which must be authenticated.
    proving: Index,
}
impl Authenticator {
    /// Creates a new authenticator.
    ///
    /// # Arguments
    ///
    /// * `encoding` - Ranges for encoding commitments.
    /// * `hash` - Ranges for hash commitments.
    /// * `partial` - The partial transcript.
    pub(crate) fn new<'a>(
        encoding: impl Iterator<Item = &'a (Direction, RangeSet<usize>)>,
        hash: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
        partial: Option<&PartialTranscript>,
    ) -> Self {
        // Compute encoding index.
        let mut encoding_sent = RangeSet::default();
        let mut encoding_recv = RangeSet::default();
        for (d, idx) in encoding {
            match d {
                Direction::Sent => encoding_sent.union_mut(idx),
                Direction::Received => encoding_recv.union_mut(idx),
            }
        }
        let encoding = Index::new(encoding_sent, encoding_recv);
        // Compute hash index.
        let mut hash_sent = RangeSet::default();
        let mut hash_recv = RangeSet::default();
        for (d, idx, _) in hash {
            match d {
                Direction::Sent => hash_sent.union_mut(idx),
                Direction::Received => hash_recv.union_mut(idx),
            }
        }
        // Use `Index::new` for consistency with the other indices.
        let hash = Index::new(hash_sent, hash_recv);
        // Compute decoding index.
        let mut decoding_sent = RangeSet::default();
        let mut decoding_recv = RangeSet::default();
        if let Some(partial) = partial {
            decoding_sent.union_mut(partial.sent_authed());
            decoding_recv.union_mut(partial.received_authed());
        }
        let decoding = Index::new(decoding_sent, decoding_recv);
        // Compute proving index: the union of everything above.
        let mut proving_sent = RangeSet::default();
        let mut proving_recv = RangeSet::default();
        proving_sent.union_mut(decoding.sent());
        proving_sent.union_mut(encoding.sent());
        proving_sent.union_mut(hash.sent());
        proving_recv.union_mut(decoding.recv());
        proving_recv.union_mut(encoding.recv());
        proving_recv.union_mut(hash.recv());
        let proving = Index::new(proving_sent, proving_recv);
        Self {
            encoding,
            hash,
            decoding,
            proving,
        }
    }

    /// Authenticates the sent plaintext, returning a proof of encryption and
    /// writes the plaintext VM references to the transcript references.
    ///
    /// # Arguments
    ///
    /// * `vm` - The virtual machine.
    /// * `zk_aes_sent` - ZK AES Cipher for sent traffic.
    /// * `transcript` - The TLS transcript.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn auth_sent(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        zk_aes_sent: &mut ZkAesCtr,
        transcript: &TlsTranscript,
        transcript_refs: &mut TranscriptRefs,
    ) -> Result<RecordProof, AuthError> {
        let missing_index = transcript_refs.compute_missing(Direction::Sent, self.proving.sent());
        // If there is nothing new to prove, return early.
        if missing_index.is_empty() {
            return Ok(RecordProof::default());
        }
        let sent = transcript
            .sent()
            .iter()
            .filter(|record| record.typ == ContentType::ApplicationData);
        authenticate(
            vm,
            zk_aes_sent,
            Direction::Sent,
            sent,
            transcript_refs,
            missing_index,
        )
    }

    /// Authenticates the received plaintext, returning a proof of encryption
    /// and writes the plaintext VM references to the transcript references.
    ///
    /// # Arguments
    ///
    /// * `vm` - The virtual machine.
    /// * `zk_aes_recv` - ZK AES Cipher for received traffic.
    /// * `transcript` - The TLS transcript.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn auth_recv(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        zk_aes_recv: &mut ZkAesCtr,
        transcript: &TlsTranscript,
        transcript_refs: &mut TranscriptRefs,
    ) -> Result<RecordProof, AuthError> {
        let decoding_recv = self.decoding.recv();
        let fully_decoded = decoding_recv.union(&transcript_refs.decoded(Direction::Received));
        let full_range = 0..transcript_refs.max_len(Direction::Received);
        // If we only have decoding ranges, and the parts we are going to decode will
        // complete to the full received transcript, then we do not need to
        // authenticate, because this will be done by
        // `crate::commit::decode::verify_transcript`, as it uses the server write
        // key and iv for verification.
        if decoding_recv == self.proving.recv() && fully_decoded == full_range {
            return Ok(RecordProof::default());
        }
        let missing_index =
            transcript_refs.compute_missing(Direction::Received, self.proving.recv());
        // If there is nothing new to prove, return early.
        if missing_index.is_empty() {
            return Ok(RecordProof::default());
        }
        let recv = transcript
            .recv()
            .iter()
            .filter(|record| record.typ == ContentType::ApplicationData);
        authenticate(
            vm,
            zk_aes_recv,
            Direction::Received,
            recv,
            transcript_refs,
            missing_index,
        )
    }

    /// Returns the sent and received encoding ranges.
    pub(crate) fn encoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
        (self.encoding.sent(), self.encoding.recv())
    }

    /// Returns the sent and received hash ranges.
    pub(crate) fn hash(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
        (self.hash.sent(), self.hash.recv())
    }

    /// Returns the sent and received decoding ranges.
    pub(crate) fn decoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
        (self.decoding.sent(), self.decoding.recv())
    }
}
/// Authenticates parts of the transcript in zk.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes` - ZK AES Cipher.
/// * `direction` - The direction of the application data.
/// * `app_data` - The application data.
/// * `transcript_refs` - The transcript references.
/// * `missing_index` - The index which needs to be proven.
fn authenticate<'a>(
    vm: &mut dyn Vm<Binary>,
    zk_aes: &mut ZkAesCtr,
    direction: Direction,
    app_data: impl Iterator<Item = &'a Record>,
    transcript_refs: &mut TranscriptRefs,
    missing_index: RangeSet<usize>,
) -> Result<RecordProof, AuthError> {
    // Sliding window tracking the current record's byte range within the
    // application-data stream; authentication is per whole record.
    let mut record_idx = Range::default();
    let mut ciphertexts = Vec::new();
    for record in app_data {
        let record_len = record.ciphertext.len();
        record_idx.end += record_len;
        // Skip records which do not overlap any range to be proven.
        if missing_index.is_disjoint(&record_idx) {
            record_idx.start += record_len;
            continue;
        }
        let (plaintext_ref, ciphertext_ref) =
            zk_aes.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())?;
        // Only the prover knows the plaintext and assigns it.
        if let Role::Prover = zk_aes.role() {
            let Some(plaintext) = record.plaintext.clone() else {
                return Err(AuthError(ErrorRepr::MissingPlainText));
            };
            vm.assign(plaintext_ref, plaintext).map_err(AuthError::vm)?;
        }
        vm.commit(plaintext_ref).map_err(AuthError::vm)?;
        // Decode the computed ciphertext so it can later be compared
        // against the actual TLS record ciphertext (see `RecordProof`).
        let ciphertext = vm.decode(ciphertext_ref).map_err(AuthError::vm)?;
        transcript_refs.add(direction, &record_idx, plaintext_ref);
        ciphertexts.push((ciphertext, record.ciphertext.clone()));
        record_idx.start += record_len;
    }
    let proof = RecordProof { ciphertexts };
    Ok(proof)
}
/// A pair of byte-range sets, one per transcript direction.
#[derive(Debug, Clone, Default)]
struct Index {
    // Ranges in the sent transcript.
    sent: RangeSet<usize>,
    // Ranges in the received transcript.
    recv: RangeSet<usize>,
}
impl Index {
fn new(sent: RangeSet<usize>, recv: RangeSet<usize>) -> Self {
Self { sent, recv }
}
fn sent(&self) -> &RangeSet<usize> {
&self.sent
}
fn recv(&self) -> &RangeSet<usize> {
&self.recv
}
}
/// Proof of encryption.
///
/// Holds, for each authenticated record, the pending decoding of the
/// VM-computed ciphertext together with the expected ciphertext taken from
/// the TLS transcript.
#[derive(Debug, Default)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
    ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
    /// Verifies the proof.
    ///
    /// Checks that every decoded ciphertext matches the expected ciphertext
    /// from the TLS transcript.
    pub(crate) fn verify(self) -> Result<(), AuthError> {
        for (mut decoded, expected) in self.ciphertexts {
            let actual = decoded
                .try_recv()
                .map_err(AuthError::vm)?
                .ok_or(AuthError(ErrorRepr::MissingDecoding))?;
            if actual != expected {
                return Err(AuthError(ErrorRepr::InvalidCiphertext));
            }
        }
        Ok(())
    }
}
/// Error for [`Authenticator`].
///
/// Prefixes the internal [`ErrorRepr`] message, which is also kept as the
/// error source.
#[derive(Debug, thiserror::Error)]
#[error("transcript authentication error: {0}")]
pub(crate) struct AuthError(#[source] ErrorRepr);
impl AuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
/// Internal error representation for [`AuthError`].
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    #[error("vm error: {0}")]
    Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
    #[error("zk-aes error: {0}")]
    ZkAes(ZkAesCtrError),
    #[error("decode error: {0}")]
    Decode(DecodeError),
    #[error("plaintext is missing in record")]
    MissingPlainText,
    #[error("decoded value is missing")]
    MissingDecoding,
    #[error("invalid ciphertext")]
    InvalidCiphertext,
}
// Conversions into `AuthError` from the cipher and decoding layers.
impl From<ZkAesCtrError> for AuthError {
    fn from(err: ZkAesCtrError) -> Self {
        Self(ErrorRepr::ZkAes(err))
    }
}

impl From<DecodeError> for AuthError {
    fn from(err: DecodeError) -> Self {
        Self(ErrorRepr::Decode(err))
    }
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
auth::{Authenticator, ErrorRepr},
transcript::TranscriptRefs,
},
zk_aes_ctr::ZkAesCtr,
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY, RECORD_SIZE},
hash::HashAlgId,
transcript::{ContentType, Direction, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_authenticator_sent(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_sent(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(600..1600));
prove_range.union_mut(&(800..2000));
prove_range.union_mut(&(2600..3700));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Sent), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Sent), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_recv(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, RECV_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, RECV_LEN).unwrap();
let _ = auth_prover
.auth_recv(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_recv(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(4000..4200));
prove_range.union_mut(&(5000..5800));
prove_range.union_mut(&(6800..RECV_LEN));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Received), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Received), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_sent_verify_fail(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
// Forge verifier transcript to check if verify fails.
// Use an index which is part of the proving range.
let forged = forged();
let proof = auth_verifier
.auth_sent(&mut verifier, &mut zk_verifier, &forged, &mut refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let err = proof.verify().unwrap_err();
assert!(matches!(err.0, ErrorRepr::InvalidCiphertext));
}
fn keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let sent = 600..1600;
let recv = 4000..4200;
(sent.into(), recv.into())
}
#[fixture]
fn encoding() -> Vec<(Direction, RangeSet<usize>)> {
let sent = 800..2000;
let recv = 5000..5800;
let encoding = vec![
(Direction::Sent, sent.into()),
(Direction::Received, recv.into()),
];
encoding
}
#[fixture]
fn hashes() -> Vec<(Direction, RangeSet<usize>, HashAlgId)> {
let sent = 2600..3700;
let recv = 6800..RECV_LEN;
let alg = HashAlgId::SHA256;
let hashes = vec![
(Direction::Sent, sent.into(), alg),
(Direction::Received, recv.into(), alg),
];
hashes
}
/// Creates a connected prover/verifier VM pair backed by ideal correlated OT.
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
    let mut rng = StdRng::seed_from_u64(0);
    let delta = Delta::random(&mut rng);
    let (sender, receiver) = ideal_rcot(rng.random(), delta.into_inner());

    (
        Prover::new(ProverConfig::default(), receiver),
        Verifier::new(VerifierConfig::default(), delta, sender),
    )
}
/// Deterministic transcript fixture: Latin text sent, lorem ipsum received.
#[fixture]
fn transcript() -> TlsTranscript {
    let sent = &LIBER_PRIMUS.as_bytes()[..SENT_LEN];
    let mut recv = lipsum(RECV_LEN).into_bytes();
    recv.truncate(RECV_LEN);

    tlsn_core::fixtures::transcript::transcript_fixture(sent, &recv)
}
/// Like `transcript`, but with a single sent byte flipped.
#[fixture]
fn forged() -> TlsTranscript {
    // The flipped byte lies inside the proving range used by the tests.
    const WRONG_BYTE_INDEX: usize = 610;

    let mut sent = Vec::from(&LIBER_PRIMUS.as_bytes()[..SENT_LEN]);
    sent[WRONG_BYTE_INDEX] = sent[WRONG_BYTE_INDEX].wrapping_add(1);

    let mut recv = lipsum(RECV_LEN).into_bytes();
    recv.truncate(RECV_LEN);

    tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
/// Empty transcript references sized to the application-data lengths.
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
    // Only application data records contribute to the transcript length.
    let sent_len: usize = transcript
        .sent()
        .iter()
        .filter(|record| matches!(record.typ, ContentType::ApplicationData))
        .map(|record| record.ciphertext.len())
        .sum();
    let recv_len: usize = transcript
        .recv()
        .iter()
        .filter(|record| matches!(record.typ, ContentType::ApplicationData))
        .map(|record| record.ciphertext.len())
        .sum();

    TranscriptRefs::new(sent_len, recv_len)
}
/// Length in bytes of the sent application data used by the fixtures.
const SENT_LEN: usize = 4096;
/// Length in bytes of the received application data used by the fixtures.
const RECV_LEN: usize = 8192;
}

View File

@@ -0,0 +1,615 @@
//! Selective disclosure.
use mpz_memory_core::{
Array, MemoryExt,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{Intersection, RangeSet, Subset, Union};
use tlsn_core::transcript::{ContentType, Direction, PartialTranscript, TlsTranscript};
use crate::commit::TranscriptRefs;
/// Decodes parts of the transcript.
///
/// If the requested received ranges, together with what was already decoded,
/// cover the whole received transcript, the server write key and iv are
/// decoded instead of the individual plaintext ranges.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `transcript_refs` - The transcript references.
pub(crate) fn decode_transcript(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
    transcript_refs: &mut TranscriptRefs,
) -> Result<(), DecodeError> {
    let (sent, recv) = decoding_ranges;

    for slice in transcript_refs.get(Direction::Sent, sent) {
        // Drop the future, we don't need it, but propagate decode errors
        // (previously they were silently dropped, unlike the key/iv path).
        drop(vm.decode(slice).map_err(DecodeError::vm)?);
    }
    transcript_refs.mark_decoded(Direction::Sent, sent);

    // If possible use server write key for decoding.
    let fully_decoded = recv.union(&transcript_refs.decoded(Direction::Received));
    let full_range = 0..transcript_refs.max_len(Direction::Received);
    if fully_decoded == full_range {
        // Drop the futures, we don't need them.
        drop(vm.decode(key).map_err(DecodeError::vm)?);
        drop(vm.decode(iv).map_err(DecodeError::vm)?);
        transcript_refs.mark_decoded(Direction::Received, &full_range.into());
    } else {
        for slice in transcript_refs.get(Direction::Received, recv) {
            // Drop the future, we don't need it, but propagate decode errors.
            drop(vm.decode(slice).map_err(DecodeError::vm)?);
        }
        transcript_refs.mark_decoded(Direction::Received, recv);
    }

    Ok(())
}
/// Verifies parts of the transcript.
///
/// Compares the authenticated plaintext (read from the VM memory, or
/// re-derived from the server write key when it is available) against the
/// purported plaintext in `partial` over the given ranges.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `partial` - The partial transcript.
/// * `transcript_refs` - The transcript references.
/// * `transcript` - The TLS transcript.
pub(crate) fn verify_transcript(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
    partial: Option<&PartialTranscript>,
    transcript_refs: &mut TranscriptRefs,
    transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
    let Some(partial) = partial else {
        return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
    };
    let (sent, recv) = decoding_ranges;

    // Authenticated bytes are collected in a fixed order (sent first, then
    // received); the purported bytes below must be collected in the same
    // order for the final comparison to be meaningful.
    let mut authenticated_data = Vec::new();

    // Add sent transcript parts.
    let sent_refs = transcript_refs.get(Direction::Sent, sent);
    for data in sent_refs.into_iter() {
        let plaintext = vm
            .get(data)
            .map_err(DecodeError::vm)?
            .ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
        authenticated_data.extend_from_slice(&plaintext);
    }

    // Add received transcript parts, if possible using key and iv.
    if let (Some(key), Some(iv)) = (
        vm.get(key).map_err(DecodeError::vm)?,
        vm.get(iv).map_err(DecodeError::vm)?,
    ) {
        // Key and iv were decoded, so the plaintext can be re-derived
        // directly from the ciphertext records.
        let plaintext = verify_with_keys(key, iv, recv, transcript)?;
        authenticated_data.extend_from_slice(&plaintext);
    } else {
        let recv_refs = transcript_refs.get(Direction::Received, recv);
        for data in recv_refs {
            let plaintext = vm
                .get(data)
                .map_err(DecodeError::vm)?
                .ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
            authenticated_data.extend_from_slice(&plaintext);
        }
    }

    // Collect the purported plaintext from the partial transcript over the
    // same ranges, in the same order.
    let mut purported_data = Vec::with_capacity(authenticated_data.len());
    for range in sent.iter_ranges() {
        purported_data.extend_from_slice(&partial.sent_unsafe()[range]);
    }
    for range in recv.iter_ranges() {
        purported_data.extend_from_slice(&partial.received_unsafe()[range]);
    }

    if purported_data != authenticated_data {
        return Err(DecodeError(ErrorRepr::InconsistentTranscript));
    }

    Ok(())
}
/// Checks the transcript length.
///
/// Verifies that the partial transcript's lengths match the total
/// application-data lengths of the TLS transcript.
///
/// # Arguments
///
/// * `partial` - The partial transcript.
/// * `transcript` - The TLS transcript.
pub(crate) fn check_transcript_length(
    partial: Option<&PartialTranscript>,
    transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
    let Some(partial) = partial else {
        return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
    };

    // Only application data records count towards the transcript length.
    let sent_len: usize = transcript
        .sent()
        .iter()
        .filter(|record| matches!(record.typ, ContentType::ApplicationData))
        .map(|record| record.ciphertext.len())
        .sum();
    let recv_len: usize = transcript
        .recv()
        .iter()
        .filter(|record| matches!(record.typ, ContentType::ApplicationData))
        .map(|record| record.ciphertext.len())
        .sum();

    // Check ranges.
    if partial.len_sent() != sent_len || partial.len_received() != recv_len {
        return Err(DecodeError(ErrorRepr::VerifyTranscriptLength));
    }

    Ok(())
}
/// Re-derives the plaintext for `decoding_ranges` from the received
/// ciphertext records using the server write key and iv.
///
/// NOTE(review): records only partially covered by `decoding_ranges` are
/// skipped entirely by the subset check below — presumably this path is only
/// taken when the whole received transcript is covered (see
/// `decode_transcript`); confirm.
fn verify_with_keys(
    key: [u8; 16],
    iv: [u8; 4],
    decoding_ranges: &RangeSet<usize>,
    transcript: &TlsTranscript,
) -> Result<Vec<u8>, DecodeError> {
    let mut plaintexts = Vec::with_capacity(decoding_ranges.len());
    // Byte offset of the current record within the received transcript.
    let mut position = 0_usize;

    let recv_data = transcript
        .recv()
        .iter()
        .filter(|record| record.typ == ContentType::ApplicationData);

    for record in recv_data {
        let current = position..position + record.ciphertext.len();
        if !current.is_subset(decoding_ranges) {
            position += record.ciphertext.len();
            continue;
        }

        let nonce = record
            .explicit_nonce
            .clone()
            .try_into()
            .expect("explicit nonce should be 8 bytes");
        let plaintext = aes_apply_keystream(key, iv, nonce, &record.ciphertext);

        // Copy the requested sub-ranges, shifted to record-local offsets.
        let record_decoding_range = decoding_ranges.intersection(&current);
        for r in record_decoding_range.iter_ranges() {
            let shifted = r.start - position..r.end - position;
            plaintexts.extend_from_slice(&plaintext[shifted]);
        }
        position += record.ciphertext.len()
    }

    Ok(plaintexts)
}
/// Applies the AES-128-CTR keystream for one TLS record to `msg`.
fn aes_apply_keystream(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], msg: &[u8]) -> Vec<u8> {
    use aes::Aes128;
    use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
    use ctr::Ctr32BE;

    // Nonce layout: 4-byte implicit iv || 8-byte explicit nonce || 4-byte
    // counter (managed by the cipher).
    let mut nonce = [0u8; 16];
    nonce[..4].copy_from_slice(&iv);
    nonce[4..12].copy_from_slice(&explicit_nonce);

    let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &nonce.into());
    // The record payload keystream starts at block counter 2.
    let start_ctr = 2;
    cipher
        .try_seek(start_ctr * 16)
        .expect("start counter is less than keystream length");

    let mut out = msg.to_vec();
    cipher.apply_keystream(&mut out);
    out
}
/// A decoding error.
#[derive(Debug, thiserror::Error)]
#[error("decode error: {0}")]
pub(crate) struct DecodeError(#[source] ErrorRepr);

impl DecodeError {
    /// Wraps an arbitrary error from the virtual machine.
    fn vm<E>(err: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    {
        Self(ErrorRepr::Vm(err.into()))
    }
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("missing partial transcript")]
MissingPartialTranscript,
#[error("length of partial transcript does not match expected length")]
VerifyTranscriptLength,
#[error("provided transcript does not match exptected")]
InconsistentTranscript,
#[error("trying to get plaintext, but it is missing")]
MissingPlaintext,
}
#[cfg(test)]
mod tests {
    use crate::{
        Role,
        commit::{
            TranscriptRefs,
            decode::{DecodeError, ErrorRepr, decode_transcript, verify_transcript},
        },
    };
    use lipsum::{LIBER_PRIMUS, lipsum};
    use mpz_common::context::test_st_context;
    use mpz_garble_core::Delta;
    use mpz_memory_core::{
        Array, MemoryExt, Vector, ViewExt,
        binary::{Binary, U8},
    };
    use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
    use mpz_vm_core::{Execute, Vm};
    use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use rangeset::{RangeSet, UnionMut};
    use rstest::{fixture, rstest};
    use tlsn_core::{
        fixtures::transcript::{IV, KEY},
        transcript::{ContentType, Direction, PartialTranscript, TlsTranscript},
    };

    /// Decoding an honest transcript verifies successfully.
    #[rstest]
    #[tokio::test]
    async fn test_decode(
        decoding: (RangeSet<usize>, RangeSet<usize>),
        transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
    ) {
        let partial = partial(&transcript, decoding.clone());
        decode(decoding, partial, transcript, transcript_refs)
            .await
            .unwrap();
    }

    /// A forged partial transcript must be rejected.
    #[rstest]
    #[tokio::test]
    async fn test_decode_fail(
        decoding: (RangeSet<usize>, RangeSet<usize>),
        forged: TlsTranscript,
        transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
    ) {
        let partial = partial(&forged, decoding.clone());
        let err = decode(decoding, partial, transcript, transcript_refs)
            .await
            .unwrap_err();
        assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
    }

    /// Decoding the full received range exercises the server-write-key path.
    #[rstest]
    #[tokio::test]
    async fn test_decode_all(
        decoding_full: (RangeSet<usize>, RangeSet<usize>),
        transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
    ) {
        let partial = partial(&transcript, decoding_full.clone());
        decode(decoding_full, partial, transcript, transcript_refs)
            .await
            .unwrap();
    }

    /// The key-based verification path must also reject a forged transcript.
    #[rstest]
    #[tokio::test]
    async fn test_decode_all_fail(
        decoding_full: (RangeSet<usize>, RangeSet<usize>),
        forged: TlsTranscript,
        transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
    ) {
        let partial = partial(&forged, decoding_full.clone());
        let err = decode(decoding_full, partial, transcript, transcript_refs)
            .await
            .unwrap_err();
        assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
    }

    /// Runs the decode protocol between prover and verifier and returns the
    /// verifier-side verification result.
    async fn decode(
        decoding: (RangeSet<usize>, RangeSet<usize>),
        partial: PartialTranscript,
        transcript: TlsTranscript,
        transcript_refs: TranscriptRefs,
    ) -> Result<(), DecodeError> {
        let (sent, recv) = decoding;
        let (mut ctx_p, mut ctx_v) = test_st_context(8);
        let (mut prover, mut verifier) = vms();
        let mut transcript_refs_verifier = transcript_refs.clone();
        let mut transcript_refs_prover = transcript_refs;

        let key: [u8; 16] = KEY;
        let iv: [u8; 4] = IV;
        let (key_prover, iv_prover) = assign_keys(&mut prover, key, iv, Role::Prover);
        let (key_verifier, iv_verifier) = assign_keys(&mut verifier, key, iv, Role::Verifier);

        assign_transcript(
            &mut prover,
            Role::Prover,
            &transcript,
            &mut transcript_refs_prover,
        );
        assign_transcript(
            &mut verifier,
            Role::Verifier,
            &transcript,
            &mut transcript_refs_verifier,
        );

        decode_transcript(
            &mut prover,
            key_prover,
            iv_prover,
            (&sent, &recv),
            &mut transcript_refs_prover,
        )
        .unwrap();
        decode_transcript(
            &mut verifier,
            key_verifier,
            iv_verifier,
            (&sent, &recv),
            &mut transcript_refs_verifier,
        )
        .unwrap();

        tokio::try_join!(
            prover.execute_all(&mut ctx_p),
            verifier.execute_all(&mut ctx_v),
        )
        .unwrap();

        verify_transcript(
            &mut verifier,
            key_verifier,
            iv_verifier,
            (&sent, &recv),
            Some(&partial),
            &mut transcript_refs_verifier,
            &transcript,
        )
    }

    /// Allocates and commits the AES key and iv in the VM for `role`.
    fn assign_keys(
        vm: &mut dyn Vm<Binary>,
        key_value: [u8; 16],
        iv_value: [u8; 4],
        role: Role,
    ) -> (Array<U8, 16>, Array<U8, 4>) {
        let key: Array<U8, 16> = vm.alloc().unwrap();
        let iv: Array<U8, 4> = vm.alloc().unwrap();
        if let Role::Prover = role {
            vm.mark_private(key).unwrap();
            vm.mark_private(iv).unwrap();
            vm.assign(key, key_value).unwrap();
            vm.assign(iv, iv_value).unwrap();
        } else {
            vm.mark_blind(key).unwrap();
            vm.mark_blind(iv).unwrap();
        }
        vm.commit(key).unwrap();
        vm.commit(iv).unwrap();
        (key, iv)
    }

    /// Allocates, assigns and commits ciphertext/plaintext references for all
    /// application data records and registers them in `transcript_refs`.
    fn assign_transcript(
        vm: &mut dyn Vm<Binary>,
        role: Role,
        transcript: &TlsTranscript,
        transcript_refs: &mut TranscriptRefs,
    ) {
        let mut pos = 0_usize;
        let sent = transcript
            .sent()
            .iter()
            .filter(|record| record.typ == ContentType::ApplicationData);
        for record in sent {
            let len = record.ciphertext.len();
            let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
            vm.mark_public(cipher_ref).unwrap();
            vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
            vm.commit(cipher_ref).unwrap();
            // Only the prover knows the plaintext; the verifier sees it blind.
            let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
            if let Role::Prover = role {
                vm.mark_private(plaintext_ref).unwrap();
                vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
                    .unwrap();
            } else {
                vm.mark_blind(plaintext_ref).unwrap();
            }
            vm.commit(plaintext_ref).unwrap();
            let index = pos..pos + record.ciphertext.len();
            transcript_refs.add(Direction::Sent, &index, plaintext_ref);
            pos += record.ciphertext.len();
        }
        pos = 0;
        let recv = transcript
            .recv()
            .iter()
            .filter(|record| record.typ == ContentType::ApplicationData);
        for record in recv {
            let len = record.ciphertext.len();
            let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
            vm.mark_public(cipher_ref).unwrap();
            vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
            vm.commit(cipher_ref).unwrap();
            let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
            if let Role::Prover = role {
                vm.mark_private(plaintext_ref).unwrap();
                vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
                    .unwrap();
            } else {
                vm.mark_blind(plaintext_ref).unwrap();
            }
            vm.commit(plaintext_ref).unwrap();
            let index = pos..pos + record.ciphertext.len();
            transcript_refs.add(Direction::Received, &index, plaintext_ref);
            pos += record.ciphertext.len();
        }
    }

    /// Builds the partial transcript revealing exactly the decoding ranges.
    fn partial(
        transcript: &TlsTranscript,
        decoding: (RangeSet<usize>, RangeSet<usize>),
    ) -> PartialTranscript {
        let (sent, recv) = decoding;
        transcript.to_transcript().unwrap().to_partial(sent, recv)
    }

    /// Sparse decoding ranges for both directions.
    #[fixture]
    fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
        let mut sent = RangeSet::default();
        let mut recv = RangeSet::default();
        sent.union_mut(&(600..1100));
        sent.union_mut(&(3450..4000));
        recv.union_mut(&(2000..3000));
        recv.union_mut(&(4800..4900));
        recv.union_mut(&(6000..7000));
        (sent, recv)
    }

    /// Decoding ranges covering both transcripts entirely.
    #[fixture]
    fn decoding_full(transcript: TlsTranscript) -> (RangeSet<usize>, RangeSet<usize>) {
        let transcript = transcript.to_transcript().unwrap();
        let (len_sent, len_recv) = transcript.len();
        let sent = (0..len_sent).into();
        let recv = (0..len_recv).into();
        (sent, recv)
    }

    /// Deterministic TLS transcript fixture.
    #[fixture]
    fn transcript() -> TlsTranscript {
        let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
        let mut recv = lipsum(RECV_LEN).into_bytes();
        recv.truncate(RECV_LEN);
        tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
    }

    /// Like `transcript`, but with one received byte flipped inside the
    /// decoding range.
    #[fixture]
    fn forged() -> TlsTranscript {
        const WRONG_BYTE_INDEX: usize = 2200;
        let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
        let mut recv = lipsum(RECV_LEN).into_bytes();
        recv.truncate(RECV_LEN);
        recv[WRONG_BYTE_INDEX] = recv[WRONG_BYTE_INDEX].wrapping_add(1);
        tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
    }

    /// Empty transcript references sized to the application-data lengths.
    #[fixture]
    fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
        let sent_len = transcript
            .sent()
            .iter()
            .filter_map(|record| {
                if let ContentType::ApplicationData = record.typ {
                    Some(record.ciphertext.len())
                } else {
                    None
                }
            })
            .sum::<usize>();
        let recv_len = transcript
            .recv()
            .iter()
            .filter_map(|record| {
                if let ContentType::ApplicationData = record.typ {
                    Some(record.ciphertext.len())
                } else {
                    None
                }
            })
            .sum::<usize>();
        TranscriptRefs::new(sent_len, recv_len)
    }

    /// Creates a connected prover/verifier VM pair with ideal correlated OT.
    fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
        let mut rng = StdRng::seed_from_u64(0);
        let delta = Delta::random(&mut rng);
        let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
        let prover = Prover::new(ProverConfig::default(), ot_recv);
        let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
        (prover, verifier)
    }

    // Application-data lengths used by all fixtures above.
    const SENT_LEN: usize = 4096;
    const RECV_LEN: usize = 8192;
}

View File

@@ -0,0 +1,530 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_memory_core::{
MemoryType, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{RangeSet, Subset, UnionMut};
use serde::{Deserialize, Serialize};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256, TypedHash},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingProvider, EncodingProviderError, EncodingTree,
EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Number of encoding bytes per plaintext byte.
pub(crate) const ENCODING_SIZE: usize = 128;
/// A VM which also exposes its encoding memory.
pub(crate) trait EncodingVm<T: MemoryType>: EncodingMemory<T> + Vm<T> {}

// Blanket impl: anything that is both a VM and an encoding memory qualifies.
impl<T: MemoryType, U> EncodingVm<T> for U where U: EncodingMemory<T> + Vm<T> {}

/// Access to the raw encoding bytes backing VM memory values.
pub(crate) trait EncodingMemory<T: MemoryType> {
    /// Returns the concatenated encoding bytes for `values`.
    fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8>;
}
impl<T> EncodingMemory<Binary> for mpz_zk::Prover<T> {
    /// Concatenates the MAC bytes of every bit of `values`.
    fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
        let total_bytes: usize = values.iter().map(|v| v.len()).sum();
        let mut encodings = Vec::with_capacity(total_bytes * ENCODING_SIZE);
        for value in values.iter().copied() {
            let macs = self.get_macs(value).expect("macs should be available");
            for mac in macs.iter() {
                encodings.extend(mac.as_bytes());
            }
        }
        encodings
    }
}
impl<T> EncodingMemory<Binary> for mpz_zk::Verifier<T> {
    /// Concatenates the key-block bytes of every bit of `values`.
    fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
        let total_bytes: usize = values.iter().map(|v| v.len()).sum();
        let mut encodings = Vec::with_capacity(total_bytes * ENCODING_SIZE);
        for value in values.iter().copied() {
            let keys = self.get_keys(value).expect("keys should be available");
            for key in keys.iter() {
                encodings.extend(key.as_block().as_bytes());
            }
        }
        encodings
    }
}
/// The encoding adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Encodings {
    /// Adjustment bytes for the sent direction.
    pub(crate) sent: Vec<u8>,
    /// Adjustment bytes for the received direction.
    pub(crate) recv: Vec<u8>,
}
/// Creates encoding commitments.
#[derive(Debug)]
pub(crate) struct EncodingCreator {
    /// Hash algorithm for the commitment tree, if configured.
    hash_id: Option<HashAlgId>,
    /// Union of all sent ranges in `idxs`.
    sent: RangeSet<usize>,
    /// Union of all received ranges in `idxs`.
    recv: RangeSet<usize>,
    /// The per-commitment indices, in commitment order.
    idxs: Vec<(Direction, RangeSet<usize>)>,
}
impl EncodingCreator {
    /// Creates a new encoding creator.
    ///
    /// # Arguments
    ///
    /// * `hash_id` - The id of the hash algorithm.
    /// * `idxs` - The indices for encoding commitments.
    pub(crate) fn new(hash_id: Option<HashAlgId>, idxs: Vec<(Direction, RangeSet<usize>)>) -> Self {
        // Flatten the per-commitment indices into one range set per
        // direction.
        let mut sent = RangeSet::default();
        let mut recv = RangeSet::default();
        for (direction, idx) in idxs.iter() {
            for range in idx.iter_ranges() {
                match direction {
                    Direction::Sent => sent.union_mut(&range),
                    Direction::Received => recv.union_mut(&range),
                }
            }
        }
        Self {
            hash_id,
            sent,
            recv,
            idxs,
        }
    }

    /// Receives the encodings using the provided MACs from the encoding memory.
    ///
    /// The MACs must be consistent with the global delta used in the encodings.
    ///
    /// # Arguments
    ///
    /// * `encoding_mem` - The encoding memory.
    /// * `encodings` - The encoding adjustments.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn receive(
        &self,
        encoding_mem: &dyn EncodingMemory<Binary>,
        encodings: Encodings,
        transcript_refs: &TranscriptRefs,
    ) -> Result<(TypedHash, EncodingTree), EncodingError> {
        let Some(id) = self.hash_id else {
            return Err(EncodingError(ErrorRepr::MissingHashId));
        };
        let hasher: &(dyn HashAlgorithm + Send + Sync) = match id {
            HashAlgId::SHA256 => &Sha256::default(),
            HashAlgId::KECCAK256 => &Keccak256::default(),
            HashAlgId::BLAKE3 => &Blake3::default(),
            alg => {
                return Err(EncodingError(ErrorRepr::UnsupportedHashAlg(alg)));
            }
        };
        let Encodings {
            sent: mut sent_adjust,
            recv: mut recv_adjust,
        } = encodings;
        // XOR the local MAC bytes into the received adjustments to obtain
        // the full encodings, then build the commitment tree over them.
        let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
        let sent = encoding_mem.get_encodings(&sent_refs);
        let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
        let recv = encoding_mem.get_encodings(&recv_refs);
        adjust(&sent, &recv, &mut sent_adjust, &mut recv_adjust)?;
        let provider = Provider::new(sent_adjust, &self.sent, recv_adjust, &self.recv);
        let tree = EncodingTree::new(hasher, self.idxs.iter(), &provider)?;
        let root = tree.root();
        Ok((root, tree))
    }

    /// Transfers the encodings using the provided secret and keys from the
    /// encoding memory.
    ///
    /// The keys must be consistent with the global delta used in the encodings.
    ///
    /// # Arguments
    ///
    /// * `encoding_mem` - The encoding memory.
    /// * `secret` - The encoder secret.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn transfer(
        &self,
        encoding_mem: &dyn EncodingMemory<Binary>,
        secret: EncoderSecret,
        transcript_refs: &TranscriptRefs,
    ) -> Result<Encodings, EncodingError> {
        let encoder = new_encoder(&secret);
        // Generate the zero encodings for all committed ranges.
        let mut sent_zero = Vec::with_capacity(self.sent.len() * ENCODING_SIZE);
        let mut recv_zero = Vec::with_capacity(self.recv.len() * ENCODING_SIZE);
        for range in self.sent.iter_ranges() {
            encoder.encode_range(Direction::Sent, range, &mut sent_zero);
        }
        for range in self.recv.iter_ranges() {
            encoder.encode_range(Direction::Received, range, &mut recv_zero);
        }
        // XOR in the local key bytes to produce the adjustments.
        let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
        let sent = encoding_mem.get_encodings(&sent_refs);
        let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
        let recv = encoding_mem.get_encodings(&recv_refs);
        adjust(&sent, &recv, &mut sent_zero, &mut recv_zero)?;
        let encodings = Encodings {
            sent: sent_zero,
            recv: recv_zero,
        };
        Ok(encodings)
    }
}
/// XORs the local encodings into the adjustment buffers in place.
///
/// # Arguments
///
/// * `sent` - The encodings for the sent bytes.
/// * `recv` - The encodings for the received bytes.
/// * `sent_adjust` - The adjustment bytes for the encodings of the sent bytes.
/// * `recv_adjust` - The adjustment bytes for the encodings of the received
///   bytes.
fn adjust(
    sent: &[u8],
    recv: &[u8],
    sent_adjust: &mut [u8],
    recv_adjust: &mut [u8],
) -> Result<(), EncodingError> {
    assert_eq!(sent.len() % ENCODING_SIZE, 0);
    assert_eq!(recv.len() % ENCODING_SIZE, 0);

    // Both adjustment buffers must match their encoding buffers exactly.
    for (direction, expected, got) in [
        (Direction::Sent, sent.len(), sent_adjust.len()),
        (Direction::Received, recv.len(), recv_adjust.len()),
    ] {
        if expected != got {
            return Err(ErrorRepr::IncorrectAdjustCount {
                direction,
                expected,
                got,
            }
            .into());
        }
    }

    for (adjust, enc) in sent_adjust.iter_mut().zip(sent) {
        *adjust ^= enc;
    }
    for (adjust, enc) in recv_adjust.iter_mut().zip(recv) {
        *adjust ^= enc;
    }

    Ok(())
}
// Serves encodings for sparse transcript ranges out of flat byte buffers.
#[derive(Debug)]
struct Provider {
    // `ENCODING_SIZE` bytes per transcript byte covered by `sent_range`.
    sent: Vec<u8>,
    sent_range: RangeSet<usize>,
    // `ENCODING_SIZE` bytes per transcript byte covered by `recv_range`.
    recv: Vec<u8>,
    recv_range: RangeSet<usize>,
}

impl Provider {
    /// Creates a new provider.
    ///
    /// # Panics
    ///
    /// Panics if an encoding buffer's length is inconsistent with its range.
    fn new(
        sent: Vec<u8>,
        sent_range: &RangeSet<usize>,
        recv: Vec<u8>,
        recv_range: &RangeSet<usize>,
    ) -> Self {
        assert_eq!(
            sent.len(),
            sent_range.len() * ENCODING_SIZE,
            "length of sent encodings and their index length do not match"
        );
        assert_eq!(
            recv.len(),
            recv_range.len() * ENCODING_SIZE,
            "length of received encodings and their index length do not match"
        );
        Self {
            sent,
            sent_range: sent_range.clone(),
            recv,
            recv_range: recv_range.clone(),
        }
    }

    /// Translates a transcript `range` into a position inside the flat
    /// encoding buffer for `direction`.
    ///
    /// Fails if `range` is not fully covered by the stored range set.
    fn adjust(
        &self,
        direction: Direction,
        range: &Range<usize>,
    ) -> Result<Range<usize>, EncodingProviderError> {
        let internal_range = match direction {
            Direction::Sent => &self.sent_range,
            Direction::Received => &self.recv_range,
        };
        if !range.is_subset(internal_range) {
            return Err(EncodingProviderError);
        }
        // Number of stored elements preceding `range.start` is the offset
        // into the flat buffer. Note: this iterates the set element-wise.
        let shift = internal_range
            .iter()
            .take_while(|&el| el < range.start)
            .count();
        let translated = Range {
            start: shift,
            end: shift + range.len(),
        };
        Ok(translated)
    }
}
impl EncodingProvider for Provider {
    /// Copies the encoding bytes covering `range` into `dest`.
    fn provide_encoding(
        &self,
        direction: Direction,
        range: Range<usize>,
        dest: &mut Vec<u8>,
    ) -> Result<(), EncodingProviderError> {
        let source = match direction {
            Direction::Sent => &self.sent,
            Direction::Received => &self.recv,
        };
        // Translate the transcript range into an index of the flat buffer.
        let translated = self.adjust(direction, &range)?;
        dest.extend_from_slice(
            &source[translated.start * ENCODING_SIZE..translated.end * ENCODING_SIZE],
        );
        Ok(())
    }
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct EncodingError(#[from] ErrorRepr);

// NOTE(review): the enum-level `#[error]` attribute alongside per-variant
// messages is unusual for `thiserror` — confirm it renders as intended.
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
    #[error("incorrect adjustment count for {direction}: expected {expected}, got {got}")]
    IncorrectAdjustCount {
        direction: Direction,
        expected: usize,
        got: usize,
    },
    #[error("encoding tree error: {0}")]
    EncodingTree(EncodingTreeError),
    #[error("missing hash id")]
    MissingHashId,
    #[error("unsupported hash algorithm for encoding commitment: {0}")]
    UnsupportedHashAlg(HashAlgId),
}

impl From<EncodingTreeError> for EncodingError {
    fn from(value: EncodingTreeError) -> Self {
        Self(ErrorRepr::EncodingTree(value))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::Range;

    use crate::commit::{
        encoding::{ENCODING_SIZE, EncodingCreator, Encodings, Provider},
        transcript::TranscriptRefs,
    };
    use mpz_core::Block;
    use mpz_garble_core::Delta;
    use mpz_memory_core::{
        FromRaw, Slice, ToRaw, Vector,
        binary::{Binary, U8},
    };
    use rangeset::{RangeSet, UnionMut};
    use rstest::{fixture, rstest};
    use tlsn_core::{
        hash::{HashAlgId, HashProvider},
        transcript::{
            Direction,
            encoding::{EncoderSecret, EncodingCommitment, EncodingProvider},
        },
    };

    /// End-to-end: transfer + receive produce a commitment whose proofs
    /// verify against the (zero) plaintext.
    #[rstest]
    fn test_encoding_adjust(
        index: (RangeSet<usize>, RangeSet<usize>),
        transcript_refs: TranscriptRefs,
        encoding_idxs: Vec<(Direction, RangeSet<usize>)>,
    ) {
        let creator = EncodingCreator::new(Some(HashAlgId::SHA256), encoding_idxs);
        let mock_memory = MockEncodingMemory;
        let delta = Delta::new(Block::ONES);
        let seed = [1_u8; 32];
        let secret = EncoderSecret::new(seed, delta.as_block().to_bytes());
        let adjustments = creator
            .transfer(&mock_memory, secret, &transcript_refs)
            .unwrap();
        let (root, tree) = creator
            .receive(&mock_memory, adjustments, &transcript_refs)
            .unwrap();
        // Check correctness of encoding protocol.
        let mut idxs = Vec::new();
        let (sent_range, recv_range) = index;
        idxs.push((Direction::Sent, sent_range.clone()));
        idxs.push((Direction::Received, recv_range.clone()));
        let commitment = EncodingCommitment { root, secret };
        let proof = tree.proof(idxs.iter()).unwrap();
        // Here we set the transcript plaintext to just be zeroes, which is
        // possible because it is not determined by the encodings so far.
        let sent = vec![0_u8; transcript_refs.max_len(Direction::Sent)];
        let recv = vec![0_u8; transcript_refs.max_len(Direction::Received)];
        let (idx_sent, idx_recv) = proof
            .verify_with_provider(&HashProvider::default(), &commitment, &sent, &recv)
            .unwrap();
        assert_eq!(idx_sent, idxs[0].1);
        assert_eq!(idx_recv, idxs[1].1);
    }

    /// The provider must translate sparse transcript ranges into the correct
    /// slices of its flat encoding buffers.
    #[rstest]
    fn test_encoding_provider(index: (RangeSet<usize>, RangeSet<usize>), encodings: Encodings) {
        let (sent_range, recv_range) = index;
        let Encodings { sent, recv } = encodings;
        let provider = Provider::new(sent, &sent_range, recv, &recv_range);
        let mut encodings_sent = Vec::new();
        let mut encodings_recv = Vec::new();
        provider
            .provide_encoding(Direction::Sent, 16..24, &mut encodings_sent)
            .unwrap();
        provider
            .provide_encoding(Direction::Received, 56..64, &mut encodings_recv)
            .unwrap();
        let expected_sent = generate_encodings((16..24).into());
        let expected_recv = generate_encodings((56..64).into());
        assert_eq!(expected_sent, encodings_sent);
        assert_eq!(expected_recv, encodings_recv);
    }

    /// Transcript references whose memory slices mirror the index ranges.
    #[fixture]
    fn transcript_refs(index: (RangeSet<usize>, RangeSet<usize>)) -> TranscriptRefs {
        let mut transcript_refs = TranscriptRefs::new(40, 64);
        // The factor of 8 matches the division in `MockEncodingMemory`.
        let dummy = |range: Range<usize>| {
            Vector::<U8>::from_raw(Slice::from_range_unchecked(8 * range.start..8 * range.end))
        };
        for range in index.0.iter_ranges() {
            transcript_refs.add(Direction::Sent, &range, dummy(range.clone()));
        }
        for range in index.1.iter_ranges() {
            transcript_refs.add(Direction::Received, &range, dummy(range.clone()));
        }
        transcript_refs
    }

    /// Deterministic per-byte encodings for the index ranges.
    #[fixture]
    fn encodings(index: (RangeSet<usize>, RangeSet<usize>)) -> Encodings {
        let sent = generate_encodings(index.0);
        let recv = generate_encodings(index.1);
        Encodings { sent, recv }
    }

    #[fixture]
    fn encoding_idxs(
        index: (RangeSet<usize>, RangeSet<usize>),
    ) -> Vec<(Direction, RangeSet<usize>)> {
        let (sent, recv) = index;
        vec![(Direction::Sent, sent), (Direction::Received, recv)]
    }

    /// Sparse sent/received index ranges used by all fixtures.
    #[fixture]
    fn index() -> (RangeSet<usize>, RangeSet<usize>) {
        let mut sent = RangeSet::default();
        sent.union_mut(&(0..8));
        sent.union_mut(&(16..24));
        sent.union_mut(&(32..40));
        let mut recv = RangeSet::default();
        recv.union_mut(&(40..48));
        recv.union_mut(&(56..64));
        (sent, recv)
    }

    /// Encoding memory stub deriving encodings from the slice addresses.
    #[derive(Clone, Copy)]
    struct MockEncodingMemory;

    impl EncodingMemory<Binary> for MockEncodingMemory {
        fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
            let ranges: Vec<Range<usize>> = values
                .iter()
                .map(|r| {
                    let range = r.to_raw().to_range();
                    // Convert raw (bit) addresses back to byte positions.
                    range.start / 8..range.end / 8
                })
                .collect();
            let ranges: RangeSet<usize> = ranges.into();
            generate_encodings(ranges)
        }
    }

    /// Emits `ENCODING_SIZE` bytes per element, each byte equal to the
    /// element's (truncated) position.
    fn generate_encodings(index: RangeSet<usize>) -> Vec<u8> {
        let mut out = Vec::new();
        for el in index.iter() {
            out.extend_from_slice(&[el as u8; ENCODING_SIZE]);
        }
        out
    }
}

View File

@@ -0,0 +1,404 @@
//! Plaintext hash commitments.
use std::collections::HashMap;
use mpz_core::bitvec::BitVec;
use mpz_hash::sha256::Sha256;
use mpz_memory_core::{
DecodeFutureTyped, MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
Direction,
hash::{PlaintextHash, PlaintextHashSecret},
},
};
use crate::{Role, commit::transcript::TranscriptRefs};
/// Creates plaintext hashes.
#[derive(Debug)]
pub(crate) struct PlaintextHasher {
    // One entry per requested hash commitment: direction, byte ranges and
    // hash algorithm.
    ranges: Vec<HashRange>,
}
impl PlaintextHasher {
    /// Creates a new instance.
    ///
    /// # Arguments
    ///
    /// * `indices` - The hash indices.
    pub(crate) fn new<'a>(
        indices: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
    ) -> Self {
        let mut ranges = Vec::new();
        for (direction, index, id) in indices {
            let hash_range = HashRange::new(*direction, index.clone(), *id);
            ranges.push(hash_range);
        }
        Self { ranges }
    }

    /// Prove plaintext hash commitments.
    ///
    /// Returns a future resolving to the committed hashes together with the
    /// blinder secrets, one per committed range.
    ///
    /// # Arguments
    ///
    /// * `vm` - The virtual machine.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn prove(
        &self,
        vm: &mut dyn Vm<Binary>,
        transcript_refs: &TranscriptRefs,
    ) -> Result<(HashFuture, Vec<PlaintextHashSecret>), HashCommitError> {
        let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Prover, transcript_refs)?;
        let mut futures = Vec::new();
        let mut secrets = Vec::new();
        for ((range, hash_ref), blinder_ref) in self.ranges.iter().zip(hash_refs).zip(blinders) {
            // Each commitment gets a fresh random blinder, assigned privately
            // by the prover.
            let blinder: Blinder = rand::random();
            vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
            vm.commit(blinder_ref)?;
            let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
            futures.push((range.clone(), hash_fut));
            secrets.push(PlaintextHashSecret {
                direction: range.direction,
                idx: range.range.clone(),
                blinder,
                alg: range.id,
            });
        }
        let hashes = HashFuture { futures };
        Ok((hashes, secrets))
    }

    /// Verify plaintext hash commitments.
    ///
    /// The verifier commits the blinders blind and decodes the resulting
    /// hash values.
    ///
    /// # Arguments
    ///
    /// * `vm` - The virtual machine.
    /// * `transcript_refs` - The transcript references.
    pub(crate) fn verify(
        &self,
        vm: &mut dyn Vm<Binary>,
        transcript_refs: &TranscriptRefs,
    ) -> Result<HashFuture, HashCommitError> {
        let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Verifier, transcript_refs)?;
        let mut futures = Vec::new();
        for ((range, hash_ref), blinder) in self.ranges.iter().zip(hash_refs).zip(blinders) {
            vm.commit(blinder)?;
            let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
            futures.push((range.clone(), hash_fut))
        }
        let hashes = HashFuture { futures };
        Ok(hashes)
    }
}
/// Commit plaintext hashes of the transcript.
///
/// For each range, allocates a 16-byte blinder (private to the prover, blind
/// to the verifier) and builds the circuit computing
/// `H(plaintext || blinder)` in the VM. Returns the hash output references
/// and the blinder references, both in the same order as `ranges`.
#[allow(clippy::type_complexity)]
fn commit(
    vm: &mut dyn Vm<Binary>,
    ranges: &[HashRange],
    role: Role,
    refs: &TranscriptRefs,
) -> Result<(Vec<Array<U8, 32>>, Vec<Vector<U8>>), HashCommitError> {
    // Cache one freshly-initialized hasher state per algorithm id so the
    // (input-independent) init circuit is built only once; every commitment
    // then works on a clone of the cached state.
    let mut hashers = HashMap::new();
    let mut hash_refs = Vec::new();
    let mut blinders = Vec::new();
    for HashRange {
        direction,
        range,
        id,
    } in ranges.iter()
    {
        // Blinder is appended to the plaintext before hashing, hiding the
        // committed data from brute-force attacks on the hash.
        let blinder = vm.alloc_vec::<U8>(16)?;
        match role {
            Role::Prover => vm.mark_private(blinder)?,
            Role::Verifier => vm.mark_blind(blinder)?,
        }
        let hash = match *id {
            HashAlgId::SHA256 => {
                let mut hasher = if let Some(hasher) = hashers.get(id).cloned() {
                    hasher
                } else {
                    let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
                    hashers.insert(id, hasher.clone());
                    hasher
                };
                // Feed the (possibly non-contiguous) plaintext references,
                // then the blinder, and finalize in the VM.
                for plaintext in refs.get(*direction, range) {
                    hasher.update(&plaintext);
                }
                hasher.update(&blinder);
                hasher.finalize(vm).map_err(HashCommitError::hasher)?
            }
            id => {
                return Err(HashCommitError::unsupported_alg(id));
            }
        };
        hash_refs.push(hash);
        blinders.push(blinder);
    }
    Ok((hash_refs, blinders))
}
/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub(crate) struct HashFuture {
    // Decode futures paired with the range descriptors they belong to.
    futures: Vec<(HashRange, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}

impl HashFuture {
    /// Tries to receive the value, returning an error if the value is not
    /// ready.
    pub(crate) fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
        self.futures
            .into_iter()
            .map(|(hash_range, mut fut)| {
                let bytes = fut
                    .try_recv()
                    .map_err(|_| HashCommitError::decode())?
                    .ok_or_else(HashCommitError::decode)?;
                Ok(PlaintextHash {
                    direction: hash_range.direction,
                    idx: hash_range.range,
                    hash: TypedHash {
                        alg: hash_range.id,
                        value: Hash::try_from(bytes).map_err(HashCommitError::convert)?,
                    },
                })
            })
            .collect()
    }
}
/// A single hash-commitment descriptor: which transcript bytes are hashed
/// and with which algorithm.
#[derive(Debug, Clone)]
struct HashRange {
    // Direction of the transcript the range refers to.
    direction: Direction,
    // Byte ranges of the transcript to hash.
    range: RangeSet<usize>,
    // Hash algorithm identifier.
    id: HashAlgId,
}
impl HashRange {
    /// Creates a new hash range.
    fn new(direction: Direction, range: RangeSet<usize>, id: HashAlgId) -> Self {
        Self {
            direction,
            range,
            id,
        }
    }
}
/// Error type for hash commitments.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct HashCommitError(#[from] ErrorRepr);
impl HashCommitError {
    /// Hash value could not be decoded from the VM.
    fn decode() -> Self {
        Self(ErrorRepr::Decode)
    }
    /// Decoded bytes could not be converted into a [`Hash`].
    fn convert(e: &'static str) -> Self {
        Self(ErrorRepr::Convert(e))
    }
    /// Underlying hasher circuit failed.
    fn hasher<E>(e: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        Self(ErrorRepr::Hasher(e.into()))
    }
    /// Requested hash algorithm is not supported.
    fn unsupported_alg(alg: HashAlgId) -> Self {
        Self(ErrorRepr::UnsupportedAlg { alg })
    }
}
// Internal error representation for `HashCommitError`.
//
// NOTE(review): the enum-level `#[error("hash commit error: {0}")]` sits
// alongside per-variant `#[error(...)]` attributes — thiserror normally
// expects the attribute on the variants only for enums; confirm this
// combination is accepted/intended.
#[derive(Debug, thiserror::Error)]
#[error("hash commit error: {0}")]
enum ErrorRepr {
    #[error("VM error: {0}")]
    Vm(VmError),
    #[error("failed to decode hash")]
    Decode,
    #[error("failed to convert hash: {0}")]
    Convert(&'static str),
    #[error("unsupported hash algorithm: {alg}")]
    UnsupportedAlg { alg: HashAlgId },
    #[error("hasher error: {0}")]
    Hasher(Box<dyn std::error::Error + Send + Sync>),
}
// VM errors convert directly into the public error type.
impl From<VmError> for HashCommitError {
    fn from(value: VmError) -> Self {
        Self(ErrorRepr::Vm(value))
    }
}
#[cfg(test)]
mod test {
    use crate::{
        Role,
        commit::{hash::PlaintextHasher, transcript::TranscriptRefs},
    };
    use mpz_common::context::test_st_context;
    use mpz_garble_core::Delta;
    use mpz_memory_core::{
        MemoryExt, Vector, ViewExt,
        binary::{Binary, U8},
    };
    use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
    use mpz_vm_core::{Execute, Vm};
    use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use rangeset::{RangeSet, UnionMut};
    use rstest::{fixture, rstest};
    use sha2::Digest;
    use tlsn_core::{hash::HashAlgId, transcript::Direction};
    // End-to-end check: prover and verifier commit to the same transcript
    // ranges, both sides decode identical hashes, and each hash equals a
    // locally computed SHA-256 over plaintext || blinder.
    #[rstest]
    #[tokio::test]
    async fn test_hasher() {
        // Three commitments: two over sent data (one with a split range),
        // one over received data.
        let mut sent1 = RangeSet::default();
        sent1.union_mut(&(1..6));
        sent1.union_mut(&(11..16));
        let mut sent2 = RangeSet::default();
        sent2.union_mut(&(22..25));
        let mut recv = RangeSet::default();
        recv.union_mut(&(20..25));
        let hash_ranges = [
            (Direction::Sent, sent1, HashAlgId::SHA256),
            (Direction::Sent, sent2, HashAlgId::SHA256),
            (Direction::Received, recv, HashAlgId::SHA256),
        ];
        let mut refs_prover = TranscriptRefs::new(1000, 1000);
        let mut refs_verifier = TranscriptRefs::new(1000, 1000);
        // One plaintext value per contiguous range above.
        let values = [
            b"abcde".to_vec(),
            b"vwxyz".to_vec(),
            b"xxx".to_vec(),
            b"12345".to_vec(),
        ];
        let (mut ctx_p, mut ctx_v) = test_st_context(8);
        let (mut prover, mut verifier) = vms();
        // Populate transcript references on both sides; the prover assigns
        // the private plaintext, the verifier only sees blind references.
        let mut values_iter = values.iter();
        for (direction, idx, _) in hash_ranges.iter() {
            for range in idx.iter_ranges() {
                let value = values_iter.next().unwrap();
                let ref_prover = assign(Role::Prover, &mut prover, value.clone());
                refs_prover.add(*direction, &range, ref_prover);
                let ref_verifier = assign(Role::Verifier, &mut verifier, value.clone());
                refs_verifier.add(*direction, &range, ref_verifier);
            }
        }
        let hasher_prover = PlaintextHasher::new(hash_ranges.iter());
        let hasher_verifier = PlaintextHasher::new(hash_ranges.iter());
        // First round: flush the plaintext commitments.
        tokio::try_join!(
            prover.execute_all(&mut ctx_p),
            verifier.execute_all(&mut ctx_v)
        )
        .unwrap();
        let (prover_hashes, prover_secrets) =
            hasher_prover.prove(&mut prover, &refs_prover).unwrap();
        let verifier_hashes = hasher_verifier
            .verify(&mut verifier, &refs_verifier)
            .unwrap();
        // Second round: execute the hash circuits and decode the outputs.
        tokio::try_join!(
            prover.execute_all(&mut ctx_p),
            verifier.execute_all(&mut ctx_v)
        )
        .unwrap();
        let prover_hashes = prover_hashes.try_recv().unwrap();
        let verifier_hashes = verifier_hashes.try_recv().unwrap();
        assert_eq!(prover_hashes, verifier_hashes);
        // The first commitment spans two ranges ("abcde" + "vwxyz"), so the
        // expected preimages are per-commitment concatenations.
        let values_per_commitment = [b"abcdevwxyz".to_vec(), b"xxx".to_vec(), b"12345".to_vec()];
        for ((value, hash), secret) in values_per_commitment
            .iter()
            .zip(prover_hashes)
            .zip(prover_secrets)
        {
            // Recompute SHA-256(plaintext || blinder) locally and compare.
            let blinder = secret.blinder.as_bytes();
            let mut blinded_value = value.clone();
            blinded_value.extend_from_slice(blinder);
            let expected_hash = sha256(&blinded_value);
            let hash: Vec<u8> = hash.hash.value.into();
            assert_eq!(expected_hash, hash);
        }
    }
    // Allocates a VM vector for `value`: the prover assigns it privately,
    // the verifier marks it blind; both commit.
    fn assign(role: Role, vm: &mut dyn Vm<Binary>, value: Vec<u8>) -> Vector<U8> {
        let reference: Vector<U8> = vm.alloc_vec(value.len()).unwrap();
        if let Role::Prover = role {
            vm.mark_private(reference).unwrap();
            vm.assign(reference, value).unwrap();
        } else {
            vm.mark_blind(reference).unwrap();
        }
        vm.commit(reference).unwrap();
        reference
    }
    // Reference SHA-256 used to check the circuit output.
    fn sha256(data: &[u8]) -> Vec<u8> {
        let mut hasher = sha2::Sha256::default();
        hasher.update(data);
        hasher.finalize().as_slice().to_vec()
    }
    // Builds a ZK prover/verifier pair backed by ideal random-correlated OT.
    #[fixture]
    fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
        let mut rng = StdRng::seed_from_u64(0);
        let delta = Delta::random(&mut rng);
        let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
        let prover = Prover::new(ProverConfig::default(), ot_recv);
        let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
        (prover, verifier)
    }
}

View File

@@ -0,0 +1,473 @@
//! Transcript reference storage.
use std::ops::Range;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{Difference, Disjoint, RangeSet, Subset, UnionMut};
use tlsn_core::transcript::Direction;
/// References to the application plaintext in the transcript.
#[derive(Debug, Clone)]
pub(crate) struct TranscriptRefs {
    // Storage for references into the sent transcript.
    sent: RefStorage,
    // Storage for references into the received transcript.
    recv: RefStorage,
}
impl TranscriptRefs {
/// Creates a new instance.
///
/// # Arguments
///
/// `sent_max_len` - The maximum length of the sent transcript in bytes.
/// `recv_max_len` - The maximum length of the received transcript in bytes.
pub(crate) fn new(sent_max_len: usize, recv_max_len: usize) -> Self {
let sent = RefStorage::new(sent_max_len);
let recv = RefStorage::new(recv_max_len);
Self { sent, recv }
}
/// Adds new references to the transcript refs.
///
/// New transcript references are only added if none of them are already
/// present.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
/// * `refs` - The new transcript refs.
pub(crate) fn add(&mut self, direction: Direction, index: &Range<usize>, refs: Vector<U8>) {
match direction {
Direction::Sent => self.sent.add(index, refs),
Direction::Received => self.recv.add(index, refs),
}
}
/// Marks references of the transcript as decoded.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn mark_decoded(&mut self, direction: Direction, index: &RangeSet<usize>) {
match direction {
Direction::Sent => self.sent.mark_decoded(index),
Direction::Received => self.recv.mark_decoded(index),
}
}
/// Returns plaintext references for some index.
///
/// Queries that cannot or only partially be satisfied will return an empty
/// vector.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn get(&self, direction: Direction, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
match direction {
Direction::Sent => self.sent.get(index),
Direction::Received => self.recv.get(index),
}
}
/// Computes the subset of `index` which is missing.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn compute_missing(
&self,
direction: Direction,
index: &RangeSet<usize>,
) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.compute_missing(index),
Direction::Received => self.recv.compute_missing(index),
}
}
/// Returns the maximum length of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn max_len(&self, direction: Direction) -> usize {
match direction {
Direction::Sent => self.sent.max_len(),
Direction::Received => self.recv.max_len(),
}
}
/// Returns the decoded ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn decoded(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.decoded(),
Direction::Received => self.recv.decoded(),
}
}
/// Returns the set ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
#[cfg(test)]
pub(crate) fn index(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.index(),
Direction::Received => self.recv.index(),
}
}
}
/// Inner storage for transcript references.
///
/// Saves transcript references by maintaining an `index` and an `offset`. The
/// offset translates from `index` to some memory location and contains
/// information about possibly non-contiguous memory locations. The storage is
/// bit-addressed but the API works with ranges over bytes.
#[derive(Debug, Clone)]
struct RefStorage {
    // Bit ranges for which references have been added.
    index: RangeSet<usize>,
    // Bit ranges which have been marked as decoded.
    decoded: RangeSet<usize>,
    // Per-bit offset from transcript position to VM memory address; only
    // meaningful inside `index`.
    offset: Vec<isize>,
    // Maximum storage capacity in bits (8 * byte capacity).
    max_len: usize,
}
impl RefStorage {
    /// Creates empty storage with a capacity of `max_len` bytes.
    fn new(max_len: usize) -> Self {
        Self {
            index: RangeSet::default(),
            decoded: RangeSet::default(),
            offset: Vec::default(),
            // Internally everything is bit-addressed.
            max_len: 8 * max_len,
        }
    }
    /// Registers `data` as the references backing byte range `index`.
    ///
    /// # Panics
    ///
    /// Panics if the range is empty, length-mismatched with `data`, overlaps
    /// already-added ranges, or exceeds the capacity.
    fn add(&mut self, index: &Range<usize>, data: Vector<U8>) {
        assert!(
            index.start < index.end,
            "Range should be valid for adding to reference storage"
        );
        assert_eq!(
            index.len(),
            data.len(),
            "Provided index and vm references should have the same length"
        );
        let bit_index = 8 * index.start..8 * index.end;
        assert!(
            bit_index.is_disjoint(&self.index),
            "Parts of the provided index have already been computed"
        );
        assert!(
            bit_index.end <= self.max_len,
            "Provided index should be smaller than max_len"
        );
        // Grow the offset table lazily up to the end of the new range.
        if bit_index.end > self.offset.len() {
            self.offset.resize(bit_index.end, 0);
        }
        // Record, per bit, the delta from transcript position to the VM
        // memory address of the backing vector.
        let mem_address = data.to_raw().ptr().as_usize() as isize;
        let offset = mem_address - bit_index.start as isize;
        self.index.union_mut(&bit_index);
        self.offset[bit_index].fill(offset);
    }
    /// Marks the given byte ranges as decoded.
    fn mark_decoded(&mut self, index: &RangeSet<usize>) {
        let bit_index = to_bit_index(index);
        self.decoded.union_mut(&bit_index);
    }
    /// Returns references covering `index`, or an empty vector if the query
    /// is empty or not fully covered by added ranges.
    fn get(&self, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
        let bit_index = to_bit_index(index);
        if bit_index.is_empty() || !bit_index.is_subset(&self.index) {
            return Vec::new();
        }
        // Partition rangeset into ranges mapping to possibly disjunct memory locations.
        //
        // If the offset changes during iteration of a single range, it means that the
        // backing memory is non-contiguous and we need to split that range.
        let mut transcript_refs = Vec::new();
        for idx in bit_index.iter_ranges() {
            let mut start = idx.start;
            let mut end = idx.start;
            let mut offset = self.offset[start];
            for k in idx {
                let next_offset = self.offset[k];
                if next_offset == offset {
                    // Same backing allocation; extend the current run.
                    end += 1;
                    continue;
                }
                // Offset changed: flush the current run and start a new one.
                let len = end - start;
                let ptr = (start as isize + offset) as usize;
                let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
                transcript_refs.push(Vector::from_raw(mem_ref));
                start = k;
                end = k + 1;
                offset = next_offset;
            }
            // Flush the final run of the range.
            let len = end - start;
            let ptr = (start as isize + offset) as usize;
            let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
            transcript_refs.push(Vector::from_raw(mem_ref));
        }
        transcript_refs
    }
    /// Returns the subset of the byte ranges in `index` that has not been
    /// added yet.
    fn compute_missing(&self, index: &RangeSet<usize>) -> RangeSet<usize> {
        let byte_index = to_byte_index(&self.index);
        index.difference(&byte_index)
    }
    /// Returns the decoded ranges, in bytes.
    fn decoded(&self) -> RangeSet<usize> {
        to_byte_index(&self.decoded)
    }
    /// Returns the capacity in bytes.
    fn max_len(&self) -> usize {
        self.max_len / 8
    }
    /// Returns the added ranges, in bytes.
    #[cfg(test)]
    fn index(&self) -> RangeSet<usize> {
        to_byte_index(&self.index)
    }
}
/// Converts a byte-addressed range set into the equivalent bit-addressed one.
fn to_bit_index(index: &RangeSet<usize>) -> RangeSet<usize> {
    index
        .iter_ranges()
        .fold(RangeSet::default(), |mut bits, r| {
            bits.union_mut(&(8 * r.start..8 * r.end));
            bits
        })
}
/// Converts a bit-addressed range set back into byte ranges.
///
/// # Panics
///
/// Panics if any range boundary is not byte-aligned.
fn to_byte_index(index: &RangeSet<usize>) -> RangeSet<usize> {
    let mut byte_index = RangeSet::default();
    for r in index.iter_ranges() {
        assert!(
            r.start % 8 == 0,
            "start range should be divisible by 8"
        );
        assert!(
            r.end % 8 == 0,
            "end range should be divisible by 8"
        );
        byte_index.union_mut(&(r.start / 8..r.end / 8));
    }
    byte_index
}
#[cfg(test)]
mod tests {
    use crate::commit::transcript::RefStorage;
    use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
    use rangeset::{RangeSet, UnionMut};
    use rstest::{fixture, rstest};
    use std::ops::Range;
    // Checks that `add` records the expected index, per-bit offsets,
    // empty decoded set, and bit-scaled capacity.
    #[rstest]
    fn test_storage_add(
        max_len: usize,
        ranges: [Range<usize>; 6],
        offsets: [isize; 6],
        storage: RefStorage,
    ) {
        let bit_ranges: Vec<Range<usize>> = ranges.iter().map(|r| 8 * r.start..8 * r.end).collect();
        let bit_offsets: Vec<isize> = offsets.iter().map(|o| 8 * o).collect();
        let mut expected_index: RangeSet<usize> = RangeSet::default();
        expected_index.union_mut(&bit_ranges[0]);
        expected_index.union_mut(&bit_ranges[1]);
        expected_index.union_mut(&bit_ranges[2]);
        expected_index.union_mut(&bit_ranges[3]);
        expected_index.union_mut(&bit_ranges[4]);
        expected_index.union_mut(&bit_ranges[5]);
        assert_eq!(storage.index, expected_index);
        let end = expected_index.end().unwrap();
        let mut expected_offset = vec![0_isize; end];
        expected_offset[bit_ranges[0].clone()].fill(bit_offsets[0]);
        expected_offset[bit_ranges[1].clone()].fill(bit_offsets[1]);
        expected_offset[bit_ranges[2].clone()].fill(bit_offsets[2]);
        expected_offset[bit_ranges[3].clone()].fill(bit_offsets[3]);
        expected_offset[bit_ranges[4].clone()].fill(bit_offsets[4]);
        expected_offset[bit_ranges[5].clone()].fill(bit_offsets[5]);
        assert_eq!(storage.offset, expected_offset);
        assert_eq!(storage.decoded, RangeSet::default());
        assert_eq!(storage.max_len, 8 * max_len);
    }
    // Checks that `get` returns the backing references, merging only runs
    // that are adjacent both in memory and in the transcript.
    #[rstest]
    fn test_storage_get(ranges: [Range<usize>; 6], offsets: [isize; 6], storage: RefStorage) {
        let mut index = RangeSet::default();
        ranges.iter().for_each(|r| index.union_mut(r));
        let data = storage.get(&index);
        let mut data_recovered = Vec::new();
        for (r, o) in ranges.iter().zip(offsets) {
            data_recovered.push(vec(r.start as isize + o..r.end as isize + o));
        }
        // Merge possibly adjacent vectors.
        //
        // Two vectors are adjacent if
        //
        // - vectors are adjacent in memory.
        // - transcript ranges of those vectors are adjacent, too.
        let mut range_iter = ranges.iter();
        let mut vec_iter = data_recovered.iter();
        let mut data_expected = Vec::new();
        let mut current_vec = vec_iter.next().unwrap().to_raw().to_range();
        let mut current_range = range_iter.next().unwrap();
        for (r, v) in range_iter.zip(vec_iter) {
            let v_range = v.to_raw().to_range();
            let start = v_range.start;
            let end = v_range.end;
            if current_vec.end == start && current_range.end == r.start {
                current_vec.end = end;
            } else {
                let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
                data_expected.push(v);
                current_vec = start..end;
                current_range = r;
            }
        }
        let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
        data_expected.push(v);
        assert_eq!(data, data_expected);
    }
    // Checks that `compute_missing` returns exactly the queried bytes that
    // were never added to the storage.
    #[rstest]
    fn test_storage_compute_missing(storage: RefStorage) {
        let mut range = RangeSet::default();
        range.union_mut(&(6..12));
        range.union_mut(&(18..21));
        range.union_mut(&(22..25));
        range.union_mut(&(50..60));
        let missing = storage.compute_missing(&range);
        let mut missing_expected = RangeSet::default();
        missing_expected.union_mut(&(8..12));
        missing_expected.union_mut(&(20..21));
        missing_expected.union_mut(&(50..60));
        assert_eq!(missing, missing_expected);
    }
    // Checks that decoded ranges round-trip through mark_decoded/decoded.
    #[rstest]
    fn test_mark_decoded(mut storage: RefStorage) {
        let mut range = RangeSet::default();
        range.union_mut(&(14..17));
        range.union_mut(&(30..37));
        storage.mark_decoded(&range);
        let decoded = storage.decoded();
        assert_eq!(range, decoded);
    }
    // Storage capacity in bytes used by all fixtures.
    #[fixture]
    fn max_len() -> usize {
        1000
    }
    // Transcript byte ranges added to the fixture storage.
    #[fixture]
    fn ranges() -> [Range<usize>; 6] {
        let r1 = 0..5;
        let r2 = 5..8;
        let r3 = 12..20;
        let r4 = 22..26;
        let r5 = 30..35;
        let r6 = 35..38;
        [r1, r2, r3, r4, r5, r6]
    }
    // Byte offsets from transcript position to simulated memory address.
    #[fixture]
    fn offsets() -> [isize; 6] {
        [7, 9, 20, 18, 30, 30]
    }
    // expected memory ranges: 8 * ranges + 8 * offsets
    // 1. 56..96 do not merge with next one, because not adjacent in memory
    // 2. 112..136
    // 3. 256..320 do not merge with next one, adjacent in memory, but the ranges
    //    itself are not
    // 4. 320..352
    // 5. 480..520 merge with next one
    // 6  520..544
    //
    //
    // 1. 56..96, length: 5
    // 2. 112..136, length: 3
    // 3. 256..320, length: 8
    // 4. 320..352, length: 4
    // 5. 480..544, length: 8
    #[fixture]
    fn storage(max_len: usize, ranges: [Range<usize>; 6], offsets: [isize; 6]) -> RefStorage {
        let [r1, r2, r3, r4, r5, r6] = ranges;
        let [o1, o2, o3, o4, o5, o6] = offsets;
        let mut storage = RefStorage::new(max_len);
        storage.add(&r1, vec(r1.start as isize + o1..r1.end as isize + o1));
        storage.add(&r2, vec(r2.start as isize + o2..r2.end as isize + o2));
        storage.add(&r3, vec(r3.start as isize + o3..r3.end as isize + o3));
        storage.add(&r4, vec(r4.start as isize + o4..r4.end as isize + o4));
        storage.add(&r5, vec(r5.start as isize + o5..r5.end as isize + o5));
        storage.add(&r6, vec(r6.start as isize + o6..r6.end as isize + o6));
        storage
    }
    // Builds a fake byte vector whose raw slice covers the given (byte)
    // memory range, scaled to bits.
    fn vec(range: Range<isize>) -> Vector<U8> {
        let range = 8 * range.start as usize..8 * range.end as usize;
        Vector::from_raw(Slice::from_range_unchecked(range))
    }
}

View File

@@ -23,11 +23,11 @@ pub(crate) fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + !aad.len().is_multiple_of(16) as usize;
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + !ciphertext.len().is_multiple_of(16) as usize;
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);

View File

@@ -4,15 +4,16 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]
pub(crate) mod commit;
pub mod config;
pub(crate) mod context;
pub(crate) mod ghash;
pub(crate) mod map;
pub(crate) mod msg;
pub(crate) mod mux;
pub mod prover;
pub(crate) mod tag;
pub(crate) mod transcript_internal;
pub mod verifier;
pub(crate) mod zk_aes_ctr;
pub use tlsn_attestation as attestation;
pub use tlsn_core::{connection, hash, transcript};

View File

@@ -1,208 +0,0 @@
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::RangeSet;
/// A map from `usize` positions to items occupying contiguous index ranges.
///
/// Entries are kept sorted by start index and never overlap.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
    // Sorted `(start, item)` pairs; each item covers
    // `start..start + item.length()`.
    map: Vec<(usize, T)>,
}
impl<T> Default for RangeMap<T>
where
    T: Item,
{
    /// Returns an empty map.
    fn default() -> Self {
        Self {
            map: Vec::default(),
        }
    }
}
impl<T> RangeMap<T>
where
    T: Item,
{
    /// Creates a map from sorted, non-overlapping `(start, item)` pairs.
    ///
    /// # Panics
    ///
    /// Panics if the pairs are unsorted or overlap.
    pub(crate) fn new(map: Vec<(usize, T)>) -> Self {
        let mut pos = 0;
        for (idx, item) in &map {
            assert!(
                *idx >= pos,
                "items must be sorted by index and non-overlapping"
            );
            pos = *idx + item.length();
        }
        Self { map }
    }
    /// Returns `true` if the map is empty.
    pub(crate) fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
    /// Returns the keys of the map.
    pub(crate) fn keys(&self) -> impl Iterator<Item = Range<usize>> {
        self.map
            .iter()
            .map(|(idx, item)| *idx..*idx + item.length())
    }
    /// Returns the length of the map.
    pub(crate) fn len(&self) -> usize {
        self.map.iter().map(|(_, item)| item.length()).sum()
    }
    /// Iterates over `(key range, item)` pairs.
    pub(crate) fn iter(&self) -> impl Iterator<Item = (Range<usize>, &T)> {
        self.map
            .iter()
            .map(|(idx, item)| (*idx..*idx + item.length(), item))
    }
    /// Returns the slice of the item covering `range`, or `None` if the
    /// range is empty or not fully contained within a single item.
    pub(crate) fn get(&self, range: Range<usize>) -> Option<T::Slice<'_>> {
        if range.start >= range.end {
            return None;
        }
        // Find the item with the greatest start index <= range.start
        let pos = match self.map.binary_search_by(|(idx, _)| idx.cmp(&range.start)) {
            Ok(i) => i,
            Err(0) => return None,
            Err(i) => i - 1,
        };
        let (base, item) = &self.map[pos];
        // Re-base the query into item-local coordinates; `slice` bounds-checks.
        item.slice(range.start - *base..range.end - *base)
    }
    /// Returns a sub-map covering exactly the ranges in `idx`, or `None` if
    /// any queried range is not fully covered by an item.
    pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
        let mut map = Vec::new();
        for idx in idx.iter_ranges() {
            // Candidate item: the one with the greatest start <= idx.start.
            let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
                Ok(i) => i,
                Err(0) => return None,
                Err(i) => i - 1,
            };
            let (base, item) = self.map.get(pos)?;
            // The whole queried range must lie within this single item.
            if idx.start < *base || idx.end > *base + item.length() {
                return None;
            }
            let start = idx.start - *base;
            let end = start + idx.len();
            map.push((
                idx.start,
                item.slice(start..end)
                    .expect("slice length is checked")
                    .into(),
            ));
        }
        Some(Self { map })
    }
}
impl<T> FromIterator<(usize, T)> for RangeMap<T>
where
    T: Item,
{
    /// Builds a map from `(start, item)` pairs.
    ///
    /// # Panics
    ///
    /// Panics if the pairs are unsorted by start index or overlap.
    fn from_iter<I: IntoIterator<Item = (usize, T)>>(items: I) -> Self {
        let mut map = Vec::new();
        let mut pos = 0;
        for (idx, item) in items {
            assert!(
                idx >= pos,
                "items must be sorted by index and non-overlapping"
            );
            pos = idx + item.length();
            map.push((idx, item));
        }
        Self { map }
    }
}
/// An item that can be stored in a [`RangeMap`]: it has a length and can be
/// sliced by an item-local sub-range.
pub(crate) trait Item: Sized {
    /// Borrowed slice type, convertible back into an owned item.
    type Slice<'a>: Into<Self>
    where
        Self: 'a;
    /// Returns the number of indices the item covers.
    fn length(&self) -> usize;
    /// Returns a slice over `range`, or `None` if it is out of bounds.
    fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>>;
}
// Byte vectors of VM references are stored directly; slicing delegates to
// `Vector::get`.
impl Item for Vector<U8> {
    type Slice<'a> = Vector<U8>;
    fn length(&self) -> usize {
        self.len()
    }
    fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
        self.get(range)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Test-only item: a `Range<usize>` whose length is its span; slicing
    // re-bases the sub-range onto the stored range.
    impl Item for Range<usize> {
        type Slice<'a> = Range<usize>;
        fn length(&self) -> usize {
            self.end - self.start
        }
        fn slice(&self, range: Range<usize>) -> Option<Self> {
            if range.end > self.end - self.start {
                return None;
            }
            Some(range.start + self.start..range.end + self.start)
        }
    }
    // `get` succeeds only within a single item and fails across gaps/items.
    #[test]
    fn test_range_map() {
        let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
        assert_eq!(map.get(0..4), Some(10..14));
        assert_eq!(map.get(10..14), Some(20..24));
        assert_eq!(map.get(20..22), Some(30..32));
        assert_eq!(map.get(0..2), Some(10..12));
        assert_eq!(map.get(11..13), Some(21..23));
        assert_eq!(map.get(0..10), None);
        assert_eq!(map.get(10..20), None);
        assert_eq!(map.get(20..30), None);
    }
    // `index` returns a re-based sub-map, or `None` when a queried range
    // falls outside all items.
    #[test]
    fn test_range_map_index() {
        let map = RangeMap::from_iter([(0, 10..14), (10, 20..24), (20, 30..32)]);
        let idx = RangeSet::from([0..4, 10..14, 20..22]);
        assert_eq!(map.index(&idx), Some(map.clone()));
        let idx = RangeSet::from(25..30);
        assert_eq!(map.index(&idx), None);
        let idx = RangeSet::from(15..20);
        assert_eq!(map.index(&idx), None);
        let idx = RangeSet::from([1..3, 11..12, 13..14, 21..22]);
        assert_eq!(
            map.index(&idx),
            Some(RangeMap::from_iter([
                (1, 11..13),
                (11, 21..22),
                (13, 23..24),
                (21, 31..32)
            ]))
        );
    }
}

14
crates/tlsn/src/msg.rs Normal file
View File

@@ -0,0 +1,14 @@
//! Message types.
use serde::{Deserialize, Serialize};
use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
///
/// Carries the claimed server name together with the handshake data that
/// binds it.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ServerIdentityProof {
    /// Server name.
    pub name: ServerName,
    /// Server identity data.
    pub data: HandshakeData,
}

View File

@@ -3,39 +3,45 @@
mod config;
mod error;
mod future;
mod prove;
pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError;
pub use future::ProverFuture;
use rustls_pki_types::CertificateDer;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use webpki::anchor_from_trusted_cert;
use crate::{Role, context::build_mt_context, mux::attach_mux, tag::verify_tags};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use rustls_pki_types::CertificateDer;
use serio::SinkExt;
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
connection::ServerName,
ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
commit::{ProvingState, TranscriptRefs},
context::build_mt_context,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
@@ -101,6 +107,22 @@ impl Prover<state::Initialized> {
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_sent_data(),
)?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Prover);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(
&mut *vm_lock.zk(),
self.config.protocol_config().max_recv_data(),
)?;
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -116,6 +138,8 @@ impl Prover<state::Initialized> {
mux_ctrl,
mux_fut,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
},
@@ -141,6 +165,8 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
..
@@ -255,6 +281,9 @@ impl Prover<state::Setup> {
.to_transcript()
.expect("transcript is complete");
let (sent_len, recv_len) = transcript.len();
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Prover {
config: self.config,
span: self.span,
@@ -263,9 +292,13 @@ impl Prover<state::Setup> {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred: false,
},
})
}
@@ -299,29 +332,56 @@ impl Prover<state::Committed> {
///
/// * `config` - The disclosure configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
pub async fn prove(&mut self, config: ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed {
mux_fut,
ctx,
vm,
keys,
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred,
..
} = &mut self.state;
let output = mux_fut
.poll_with(prove::prove(
ctx,
vm,
keys,
self.config.server_name(),
transcript,
tls_transcript,
config,
))
// Create and send prove payload.
let server_name = self.config.server_name();
let handshake = config
.server_identity()
.then(|| (server_name.clone(), HandshakeData::new(tls_transcript)));
let partial = if let Some((reveal_sent, reveal_recv)) = config.reveal() {
Some(transcript.to_partial(reveal_sent.clone(), reveal_recv.clone()))
} else {
None
};
let payload = ProveRequest::new(&config, partial, handshake);
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
.await?;
let proving_state = ProvingState::for_prover(
config,
tls_transcript,
transcript,
transcript_refs,
*encodings_transferred,
);
let (output, encodings_executed) = mux_fut
.poll_with(
proving_state
.prove(vm, ctx, zk_aes_ctr_sent, zk_aes_ctr_recv, keys.clone())
.map_err(ProverError::from),
)
.await?;
*encodings_transferred = encodings_executed;
Ok(output)
}

View File

@@ -1,8 +1,6 @@
use std::{error::Error, fmt};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
use std::{error::Error, fmt};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -42,13 +40,6 @@ impl ProverError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn commit<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Commit, source)
}
}
#[derive(Debug)]
@@ -110,8 +101,14 @@ impl From<MpcTlsError> for ProverError {
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
impl From<ZkAesCtrError> for ProverError {
fn from(e: ZkAesCtrError) -> Self {
Self::new(ErrorKind::Zk, e)
}
}
impl From<CommitError> for ProverError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -1,187 +0,0 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use serio::SinkExt;
use tlsn_core::{
ProveConfig, ProveRequest, ProverOutput,
connection::{HandshakeData, ServerName},
transcript::{
ContentType, Direction, TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret,
},
};
use crate::{
prover::ProverError,
transcript_internal::{
TranscriptRefs,
auth::prove_plaintext,
commit::{
encoding::{self, MacStore},
hash::prove_hash,
},
},
};
/// Runs the prover side of the selective-disclosure protocol.
///
/// Sends a [`ProveRequest`] (optional server identity, optional partial
/// transcript, optional commitment request) to the verifier, proves the
/// requested plaintext ranges in zero knowledge, and produces the requested
/// hash and/or encoding commitments.
///
/// # Errors
///
/// Returns a [`ProverError`] on I/O, VM, or commitment failures.
pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
    ctx: &mut Context,
    vm: &mut T,
    keys: &SessionKeys,
    server_name: &ServerName,
    transcript: &Transcript,
    tls_transcript: &TlsTranscript,
    config: &ProveConfig,
) -> Result<ProverOutput, ProverError> {
    let mut output = ProverOutput {
        transcript_commitments: Vec::default(),
        transcript_secrets: Vec::default(),
    };
    // Build the request describing everything the verifier should expect.
    let request = ProveRequest {
        // Server identity is only disclosed when the config asks for it.
        handshake: config.server_identity().then(|| {
            (
                server_name.clone(),
                HandshakeData {
                    certs: tls_transcript
                        .server_cert_chain()
                        .expect("server cert chain is present")
                        .to_vec(),
                    sig: tls_transcript
                        .server_signature()
                        .expect("server signature is present")
                        .clone(),
                    binding: tls_transcript.certificate_binding().clone(),
                },
            )
        }),
        transcript: config
            .reveal()
            .map(|(sent, recv)| transcript.to_partial(sent.clone(), recv.clone())),
        transcript_commit: config.transcript_commit().map(|config| config.to_request()),
    };
    ctx.io_mut()
        .send(request)
        .await
        .map_err(ProverError::from)?;
    let (reveal_sent, reveal_recv) = config.reveal().cloned().unwrap_or_default();
    // Accumulate, per direction, every range that any commitment (hash or
    // encoding) needs, so plaintext is only authenticated where required.
    let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
    if let Some(commit_config) = config.transcript_commit() {
        commit_config
            .iter_hash()
            .for_each(|((direction, idx), _)| match direction {
                Direction::Sent => commit_sent.union_mut(idx),
                Direction::Received => commit_recv.union_mut(idx),
            });
        commit_config
            .iter_encoding()
            .for_each(|(direction, idx)| match direction {
                Direction::Sent => commit_sent.union_mut(idx),
                Direction::Received => commit_recv.union_mut(idx),
            });
    }
    // Authenticate the revealed/committed plaintext ranges against the TLS
    // records. Only ApplicationData records carry transcript plaintext.
    let transcript_refs = TranscriptRefs {
        sent: prove_plaintext(
            vm,
            keys.client_write_key,
            keys.client_write_iv,
            transcript.sent(),
            tls_transcript
                .sent()
                .iter()
                .filter(|record| record.typ == ContentType::ApplicationData),
            &reveal_sent,
            &commit_sent,
        )
        .map_err(ProverError::commit)?,
        recv: prove_plaintext(
            vm,
            keys.server_write_key,
            keys.server_write_iv,
            transcript.received(),
            tls_transcript
                .recv()
                .iter()
                .filter(|record| record.typ == ContentType::ApplicationData),
            &reveal_recv,
            &commit_recv,
        )
        .map_err(ProverError::commit)?,
    };
    // Hash commitments are set up before execution; the futures resolve after
    // `execute_all` below.
    let hash_commitments = if let Some(commit_config) = config.transcript_commit()
        && commit_config.has_hash()
    {
        Some(
            prove_hash(
                vm,
                &transcript_refs,
                commit_config
                    .iter_hash()
                    .map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
            )
            .map_err(ProverError::commit)?,
        )
    } else {
        None
    };
    vm.execute_all(ctx).await.map_err(ProverError::zk)?;
    // Encoding commitments need the MACs produced by execution, so they run
    // only after `execute_all` has completed.
    if let Some(commit_config) = config.transcript_commit()
        && commit_config.has_encoding()
    {
        let mut sent_ranges = RangeSet::default();
        let mut recv_ranges = RangeSet::default();
        for (dir, idx) in commit_config.iter_encoding() {
            match dir {
                Direction::Sent => sent_ranges.union_mut(idx),
                Direction::Received => recv_ranges.union_mut(idx),
            }
        }
        let sent_map = transcript_refs
            .sent
            .index(&sent_ranges)
            .expect("indices are valid");
        let recv_map = transcript_refs
            .recv
            .index(&recv_ranges)
            .expect("indices are valid");
        let (commitment, tree) = encoding::receive(
            ctx,
            vm,
            *commit_config.encoding_hash_alg(),
            &sent_map,
            &recv_map,
            commit_config.iter_encoding(),
        )
        .await?;
        output
            .transcript_commitments
            .push(TranscriptCommitment::Encoding(commitment));
        output
            .transcript_secrets
            .push(TranscriptSecret::Encoding(tree));
    }
    // Collect the now-resolved hash commitments and pair each with its secret.
    if let Some((hash_fut, hash_secrets)) = hash_commitments {
        let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
        for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
            output
                .transcript_commitments
                .push(TranscriptCommitment::Hash(commitment));
            output
                .transcript_secrets
                .push(TranscriptSecret::Hash(secret));
        }
    }
    Ok(output)
}

View File

@@ -9,8 +9,10 @@ use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
zk_aes_ctr::ZkAesCtr,
};
/// Entry state
@@ -23,6 +25,8 @@ pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) mpc_tls: MpcTlsLeader,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}
@@ -35,9 +39,13 @@ pub struct Committed {
pub(crate) mux_fut: MuxFuture,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -1,16 +0,0 @@
pub(crate) mod auth;
pub(crate) mod commit;
use mpz_memory_core::{Vector, binary::U8};
use crate::map::RangeMap;
/// Maps transcript ranges to VM references.
pub(crate) type ReferenceMap = RangeMap<Vector<U8>>;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
pub(crate) struct TranscriptRefs {
pub(crate) sent: ReferenceMap,
pub(crate) recv: ReferenceMap,
}

View File

@@ -1,639 +0,0 @@
use std::sync::Arc;
use aes::Aes128;
use ctr::{
Ctr32BE,
cipher::{KeyIvInit, StreamCipher, StreamCipherSeek},
};
use mpz_circuits::circuits::{AES128, xor};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
Array, DecodeFutureTyped, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{Difference, RangeSet, Union};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
/// Allocates and assigns plaintext in the VM for the union of the `reveal`
/// and `commit` ranges, proving authenticity where needed.
///
/// When the entire plaintext is revealed, the session `key`/`iv` are decoded
/// instead of proving encryption in ZK (the verifier can then re-encrypt
/// locally). Otherwise, ranges in `commit` but not `reveal` are marked
/// private, revealed ranges public, and the corresponding ciphertext is
/// computed and decoded so the verifier can check it.
///
/// Returns the map from transcript ranges to VM references.
pub(crate) fn prove_plaintext<'a>(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    plaintext: &[u8],
    records: impl IntoIterator<Item = &'a Record>,
    reveal: &RangeSet<usize>,
    commit: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
    let is_reveal_all = reveal == (0..plaintext.len());
    let alloc_ranges = if is_reveal_all {
        commit.clone()
    } else {
        // The plaintext is only partially revealed, so we need to authenticate in ZK.
        commit.union(reveal)
    };
    let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
    let records = RecordParams::from_iter(records).collect::<Vec<_>>();
    if is_reveal_all {
        // Full disclosure: decode key/iv for the verifier; the decode futures
        // themselves are not needed on the prover side, hence the drops.
        drop(vm.decode(key).map_err(PlaintextAuthError::vm)?);
        drop(vm.decode(iv).map_err(PlaintextAuthError::vm)?);
        for (range, slice) in plaintext_refs.iter() {
            vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
            vm.assign(*slice, plaintext[range].to_vec())
                .map_err(PlaintextAuthError::vm)?;
            vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
        }
    } else {
        // Committed-but-not-revealed bytes stay private to the prover.
        let private = commit.difference(reveal);
        for (_, slice) in plaintext_refs
            .index(&private)
            .expect("all ranges are allocated")
            .iter()
        {
            vm.mark_private(*slice).map_err(PlaintextAuthError::vm)?;
        }
        for (_, slice) in plaintext_refs
            .index(reveal)
            .expect("all ranges are allocated")
            .iter()
        {
            vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
        }
        for (range, slice) in plaintext_refs.iter() {
            vm.assign(*slice, plaintext[range].to_vec())
                .map_err(PlaintextAuthError::vm)?;
            vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
        }
        // Compute the ciphertext in ZK and decode it so the verifier can
        // compare it to the TLS record ciphertext.
        let ciphertext = alloc_ciphertext(vm, key, iv, plaintext_refs.clone(), &records)?;
        for (_, slice) in ciphertext.iter() {
            drop(vm.decode(*slice).map_err(PlaintextAuthError::vm)?);
        }
    }
    Ok(plaintext_refs)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn verify_plaintext<'a>(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
plaintext: &'a [u8],
ciphertext: &'a [u8],
records: impl IntoIterator<Item = &'a Record>,
reveal: &RangeSet<usize>,
commit: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof<'a>), PlaintextAuthError> {
let is_reveal_all = reveal == (0..plaintext.len());
let alloc_ranges = if is_reveal_all {
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
let records = RecordParams::from_iter(records).collect::<Vec<_>>();
let plaintext_proof = if is_reveal_all {
let key = vm.decode(key).map_err(PlaintextAuthError::vm)?;
let iv = vm.decode(iv).map_err(PlaintextAuthError::vm)?;
for (range, slice) in plaintext_refs.iter() {
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
PlaintextProof(ProofInner::WithKey {
key,
iv,
records,
plaintext,
ciphertext,
})
} else {
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
.iter()
{
vm.mark_blind(*slice).map_err(PlaintextAuthError::vm)?;
}
for (range, slice) in plaintext_refs
.index(reveal)
.expect("all ranges are allocated")
.iter()
{
vm.mark_public(*slice).map_err(PlaintextAuthError::vm)?;
vm.assign(*slice, plaintext[range].to_vec())
.map_err(PlaintextAuthError::vm)?;
}
for (_, slice) in plaintext_refs.iter() {
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
let ciphertext_map = alloc_ciphertext(vm, key, iv, plaintext_refs.clone(), &records)?;
let mut ciphertexts = Vec::new();
for (range, chunk) in ciphertext_map.iter() {
ciphertexts.push((
&ciphertext[range],
vm.decode(*chunk).map_err(PlaintextAuthError::vm)?,
));
}
PlaintextProof(ProofInner::WithZk { ciphertexts })
};
Ok((plaintext_refs, plaintext_proof))
}
/// Allocates one contiguous VM vector covering all of `ranges` and returns a
/// map from each transcript range to its chunk of that allocation.
fn alloc_plaintext(
    vm: &mut dyn Vm<Binary>,
    ranges: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
    let total = ranges.len();
    if total == 0 {
        // Nothing requested: an empty map, no allocation.
        return Ok(ReferenceMap::default());
    }
    // Single allocation, carved into per-range chunks below.
    let buffer = vm.alloc_vec::<U8>(total).map_err(PlaintextAuthError::vm)?;
    let mut offset = 0;
    let mut entries = Vec::new();
    for range in ranges.iter_ranges() {
        let chunk = buffer
            .get(offset..offset + range.len())
            .expect("length was checked");
        offset += range.len();
        entries.push((range.start, chunk));
    }
    Ok(ReferenceMap::from_iter(entries))
}
/// Builds the ciphertext for the given plaintext references by XORing them
/// with a freshly allocated AES-CTR keystream, returning a map from the same
/// transcript ranges to the ciphertext references.
fn alloc_ciphertext<'a>(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    plaintext: ReferenceMap,
    records: impl IntoIterator<Item = &'a RecordParams>,
) -> Result<ReferenceMap, PlaintextAuthError> {
    if plaintext.is_empty() {
        return Ok(ReferenceMap::default());
    }
    let ranges = RangeSet::from(plaintext.keys().collect::<Vec<_>>());
    let keystream = alloc_keystream(vm, key, iv, &ranges, records)?;
    // XOR circuit: plaintext args first, then the matching keystream args.
    let mut builder = Call::builder(Arc::new(xor(ranges.len() * 8)));
    for (_, slice) in plaintext.iter() {
        builder = builder.arg(*slice);
    }
    for slice in keystream {
        builder = builder.arg(slice);
    }
    let call = builder.build().expect("call should be valid");
    let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
    // Carve the flat ciphertext vector back into per-range chunks.
    let mut pos = 0;
    Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
        move |range| {
            let chunk = ciphertext
                .get(pos..pos + range.len())
                .expect("length was checked");
            pos += range.len();
            (range.start, chunk)
        },
    )))
}
/// Allocates AES-CTR keystream slices covering `ranges`, walking the records
/// in order so each slice uses the correct explicit nonce and counter block.
///
/// Nonces and counter blocks are allocated lazily — only for records and
/// blocks actually intersected by `ranges`. Returns an error if a range
/// extends past the total length of the records.
fn alloc_keystream<'a>(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    ranges: &RangeSet<usize>,
    records: impl IntoIterator<Item = &'a RecordParams>,
) -> Result<Vec<Vector<U8>>, PlaintextAuthError> {
    let mut keystream = Vec::new();
    // `pos` is the transcript offset of the current record's first byte.
    let mut pos = 0;
    let mut range_iter = ranges.iter_ranges();
    let mut current_range = range_iter.next();
    for record in records {
        // Cached per record / per block so repeated slices reuse them.
        let mut explicit_nonce = None;
        let mut current_block = None;
        loop {
            let Some(range) = current_range.take().or_else(|| range_iter.next()) else {
                // All ranges consumed: done.
                return Ok(keystream);
            };
            let record_range = pos..pos + record.len;
            if range.start >= record_range.end {
                // Range lives in a later record; put it back and advance.
                current_range = Some(range);
                break;
            }
            // Range with record offset applied.
            let offset_range = range.start - pos..range.end - pos;
            let explicit_nonce = if let Some(explicit_nonce) = explicit_nonce {
                explicit_nonce
            } else {
                let nonce = alloc_explicit_nonce(vm, record.explicit_nonce.clone())?;
                explicit_nonce = Some(nonce);
                nonce
            };
            const BLOCK_SIZE: usize = 16;
            let block_num = offset_range.start / BLOCK_SIZE;
            let block = if let Some((current_block_num, block)) = current_block.take()
                && current_block_num == block_num
            {
                block
            } else {
                let block = alloc_block(vm, key, iv, explicit_nonce, block_num)?;
                current_block = Some((block_num, block));
                block
            };
            // Range within the block.
            let block_range_start = offset_range.start % BLOCK_SIZE;
            // Clamp to the record end and to the end of this 16-byte block.
            let len =
                (range.end.min(record_range.end) - range.start).min(BLOCK_SIZE - block_range_start);
            let block_range = block_range_start..block_range_start + len;
            keystream.push(block.get(block_range).expect("range is checked"));
            // If the range extends past the block, process the tail.
            if range.start + len < range.end {
                current_range = Some(range.start + len..range.end);
            }
        }
        pos += record.len;
    }
    // A range remained after all records were consumed.
    Err(ErrorRepr::OutOfBounds.into())
}
/// Allocates the 8-byte TLS explicit nonce as a public, committed VM value.
fn alloc_explicit_nonce(
    vm: &mut dyn Vm<Binary>,
    explicit_nonce: Vec<u8>,
) -> Result<Vector<U8>, PlaintextAuthError> {
    const EXPLICIT_NONCE_LEN: usize = 8;
    // The explicit nonce is carried in the clear in every TLS record, so it
    // is public to both parties.
    let nonce_ref = vm
        .alloc_vec::<U8>(EXPLICIT_NONCE_LEN)
        .map_err(PlaintextAuthError::vm)?;
    vm.mark_public(nonce_ref).map_err(PlaintextAuthError::vm)?;
    vm.assign(nonce_ref, explicit_nonce)
        .map_err(PlaintextAuthError::vm)?;
    vm.commit(nonce_ref).map_err(PlaintextAuthError::vm)?;
    Ok(nonce_ref)
}
/// Allocates one 16-byte AES-CTR keystream block for the given block index,
/// wiring up the AES-128 circuit with key, iv, explicit nonce and counter.
fn alloc_block(
    vm: &mut dyn Vm<Binary>,
    key: Array<U8, 16>,
    iv: Array<U8, 4>,
    explicit_nonce: Vector<U8>,
    block: usize,
) -> Result<Vector<U8>, PlaintextAuthError> {
    let ctr: Array<U8, 4> = vm.alloc().map_err(PlaintextAuthError::vm)?;
    vm.mark_public(ctr).map_err(PlaintextAuthError::vm)?;
    // TLS AES-CTR application data starts at counter 2 (big-endian).
    const START_CTR: u32 = 2;
    vm.assign(ctr, (START_CTR + block as u32).to_be_bytes())
        .map_err(PlaintextAuthError::vm)?;
    vm.commit(ctr).map_err(PlaintextAuthError::vm)?;
    let block: Array<U8, 16> = vm
        .call(
            Call::builder(AES128.clone())
                .arg(key)
                .arg(iv)
                .arg(explicit_nonce)
                .arg(ctr)
                .build()
                .expect("call should be valid"),
        )
        .map_err(PlaintextAuthError::vm)?;
    Ok(Vector::from(block))
}
/// Minimal per-record parameters needed for keystream reconstruction.
struct RecordParams {
    // Explicit nonce carried in the record header.
    explicit_nonce: Vec<u8>,
    // Ciphertext length of the record in bytes.
    len: usize,
}
impl RecordParams {
    /// Extracts `RecordParams` from full `Record`s lazily.
    fn from_iter<'a>(records: impl IntoIterator<Item = &'a Record>) -> impl Iterator<Item = Self> {
        records.into_iter().map(|record| Self {
            explicit_nonce: record.explicit_nonce.clone(),
            len: record.ciphertext.len(),
        })
    }
}
/// Deferred plaintext-authenticity check; must be consumed via `verify`
/// after the VM has been executed (the decode futures resolve then).
#[must_use]
pub(crate) struct PlaintextProof<'a>(ProofInner<'a>);
impl<'a> PlaintextProof<'a> {
    /// Verifies the plaintext, either by re-encrypting locally with the
    /// decoded key/iv or by comparing the ZK-computed ciphertext chunks
    /// against the expected record bytes.
    pub(crate) fn verify(self) -> Result<(), PlaintextAuthError> {
        match self.0 {
            ProofInner::WithKey {
                mut key,
                mut iv,
                records,
                plaintext,
                ciphertext,
            } => {
                let key = key
                    .try_recv()
                    .map_err(PlaintextAuthError::vm)?
                    .ok_or(ErrorRepr::MissingDecoding)?;
                let iv = iv
                    .try_recv()
                    .map_err(PlaintextAuthError::vm)?
                    .ok_or(ErrorRepr::MissingDecoding)?;
                verify_plaintext_with_key(key, iv, &records, plaintext, ciphertext)?;
            }
            ProofInner::WithZk { ciphertexts } => {
                for (expected, mut actual) in ciphertexts {
                    let actual = actual
                        .try_recv()
                        .map_err(PlaintextAuthError::vm)?
                        .ok_or(PlaintextAuthError(ErrorRepr::MissingDecoding))?;
                    if actual != expected {
                        return Err(PlaintextAuthError(ErrorRepr::InvalidPlaintext));
                    }
                }
            }
        }
        Ok(())
    }
}
/// The two proof strategies: full reveal (key decoded) vs. partial (ZK).
enum ProofInner<'a> {
    WithKey {
        key: DecodeFutureTyped<BitVec, [u8; 16]>,
        iv: DecodeFutureTyped<BitVec, [u8; 4]>,
        records: Vec<RecordParams>,
        plaintext: &'a [u8],
        ciphertext: &'a [u8],
    },
    WithZk {
        // (expected, actual)
        #[allow(clippy::type_complexity)]
        ciphertexts: Vec<(&'a [u8], DecodeFutureTyped<BitVec, Vec<u8>>)>,
    },
}
/// XORs `input` in place with the AES-128-CTR keystream derived from `key`,
/// the 4-byte implicit `iv` and the first 8 bytes of `explicit_nonce`,
/// starting at counter block 2 (matching `alloc_block`).
fn aes_ctr_apply_keystream(key: &[u8; 16], iv: &[u8; 4], explicit_nonce: &[u8], input: &mut [u8]) {
    const START_CTR: u32 = 2;
    // Assemble the 16-byte CTR IV: implicit iv || explicit nonce || counter.
    let mut full_iv = [0u8; 16];
    let (implicit, rest) = full_iv.split_at_mut(4);
    implicit.copy_from_slice(iv);
    rest[..8].copy_from_slice(&explicit_nonce[..8]);
    let mut cipher = Ctr32BE::<Aes128>::new(key.into(), &full_iv.into());
    // Skip the first two counter blocks before applying the keystream.
    cipher
        .try_seek(START_CTR * 16)
        .expect("start counter is less than keystream length");
    cipher.apply_keystream(input);
}
/// Re-encrypts `plaintext` record by record with the decoded `key`/`iv` and
/// checks that it matches `ciphertext` byte-for-byte.
fn verify_plaintext_with_key<'a>(
    key: [u8; 16],
    iv: [u8; 4],
    records: impl IntoIterator<Item = &'a RecordParams>,
    plaintext: &[u8],
    ciphertext: &[u8],
) -> Result<(), PlaintextAuthError> {
    let mut offset = 0;
    // Scratch buffer reused across records to avoid per-record allocations.
    let mut scratch = Vec::new();
    for record in records {
        let span = offset..offset + record.len;
        scratch.clear();
        scratch.extend_from_slice(&plaintext[span.clone()]);
        aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut scratch);
        if scratch != ciphertext[span] {
            return Err(PlaintextAuthError(ErrorRepr::InvalidPlaintext));
        }
        offset += record.len;
    }
    Ok(())
}
/// Error raised while proving or verifying plaintext authenticity.
#[derive(Debug, thiserror::Error)]
#[error("plaintext authentication error: {0}")]
pub(crate) struct PlaintextAuthError(#[from] ErrorRepr);
impl PlaintextAuthError {
    /// Wraps an arbitrary VM-layer error.
    fn vm<E>(err: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    {
        Self(ErrorRepr::Vm(err.into()))
    }
}
/// Concrete failure cases, kept private behind `PlaintextAuthError`.
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
    #[error("vm error: {0}")]
    Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
    #[error("plaintext out of bounds of records. This should never happen and is an internal bug.")]
    OutOfBounds,
    #[error("missing decoding")]
    MissingDecoding,
    #[error("plaintext does not match ciphertext")]
    InvalidPlaintext,
}
#[cfg(test)]
#[allow(clippy::all)]
mod tests {
    use super::*;
    use mpz_common::context::test_st_context;
    use mpz_ideal_vm::IdealVm;
    use mpz_vm_core::prelude::*;
    use rand::{Rng, SeedableRng, rngs::StdRng};
    use rstest::*;
    use std::ops::Range;
    /// Builds an ideal VM with a public, committed key and iv.
    fn build_vm(key: [u8; 16], iv: [u8; 4]) -> (IdealVm, Array<U8, 16>, Array<U8, 4>) {
        let mut vm = IdealVm::new();
        let key_ref = vm.alloc::<Array<U8, 16>>().unwrap();
        let iv_ref = vm.alloc::<Array<U8, 4>>().unwrap();
        vm.mark_public(key_ref).unwrap();
        vm.mark_public(iv_ref).unwrap();
        vm.assign(key_ref, key).unwrap();
        vm.assign(iv_ref, iv).unwrap();
        vm.commit(key_ref).unwrap();
        vm.commit(iv_ref).unwrap();
        (vm, key_ref, iv_ref)
    }
    /// Reference implementation: concatenates the keystream bytes of `ranges`
    /// across the given records, computed with the software AES-CTR cipher.
    fn expected_aes_ctr<'a>(
        key: [u8; 16],
        iv: [u8; 4],
        records: impl IntoIterator<Item = &'a RecordParams>,
        ranges: &RangeSet<usize>,
    ) -> Vec<u8> {
        let mut keystream = Vec::new();
        let mut pos = 0;
        for record in records {
            let mut record_keystream = vec![0u8; record.len];
            aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
            // Clip each requested range to this record's span.
            for mut range in ranges.iter_ranges() {
                range.start = range.start.max(pos);
                range.end = range.end.min(pos + record.len);
                if range.start < range.end {
                    keystream
                        .extend_from_slice(&record_keystream[range.start - pos..range.end - pos]);
                }
            }
            pos += record.len;
        }
        keystream
    }
    #[rstest]
    #[case::single_record_empty([0], [])]
    #[case::multiple_empty_records_empty([0, 0], [])]
    #[case::multiple_records_empty([128, 64], [])]
    #[case::single_block_full([16], [0..16])]
    #[case::single_block_partial([16], [2..14])]
    #[case::partial_block_full([15], [0..15])]
    #[case::out_of_bounds([16], [0..17])]
    #[case::multiple_records_full([128, 63, 33, 15, 4], [0..243])]
    #[case::multiple_records_partial([128, 63, 33, 15, 4], [1..15, 16..17, 18..19, 126..130, 224..225, 242..243])]
    #[tokio::test]
    async fn test_alloc_keystream(
        #[case] record_lens: impl IntoIterator<Item = usize>,
        #[case] ranges: impl IntoIterator<Item = Range<usize>>,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let mut key = [0u8; 16];
        let mut iv = [0u8; 4];
        rng.fill(&mut key);
        rng.fill(&mut iv);
        let mut total_len = 0;
        let records = record_lens
            .into_iter()
            .map(|len| {
                let mut explicit_nonce = [0u8; 8];
                rng.fill(&mut explicit_nonce);
                total_len += len;
                RecordParams {
                    explicit_nonce: explicit_nonce.to_vec(),
                    len,
                }
            })
            .collect::<Vec<_>>();
        let ranges = RangeSet::from(ranges.into_iter().collect::<Vec<_>>());
        let is_out_of_bounds = ranges.end().unwrap_or(0) > total_len;
        let (mut ctx, _) = test_st_context(1024);
        let (mut vm, key_ref, iv_ref) = build_vm(key, iv);
        // Out-of-bounds ranges must fail with exactly `OutOfBounds`.
        let keystream = match alloc_keystream(&mut vm, key_ref, iv_ref, &ranges, &records) {
            Ok(_) if is_out_of_bounds => panic!("should be out of bounds"),
            Ok(keystream) => keystream,
            Err(PlaintextAuthError(ErrorRepr::OutOfBounds)) if is_out_of_bounds => {
                return;
            }
            Err(e) => panic!("unexpected error: {:?}", e),
        };
        vm.execute(&mut ctx).await.unwrap();
        let keystream: Vec<u8> = keystream
            .iter()
            .flat_map(|slice| vm.get(*slice).unwrap().unwrap())
            .collect();
        assert_eq!(keystream.len(), ranges.len());
        let expected = expected_aes_ctr(key, iv, &records, &ranges);
        assert_eq!(keystream, expected);
    }
    #[rstest]
    #[case::single_record_empty([0])]
    #[case::single_record([32])]
    #[case::multiple_records([128, 63, 33, 15, 4])]
    #[case::multiple_records_with_empty([128, 63, 33, 0, 15, 4])]
    fn test_verify_plaintext_with_key(
        #[case] record_lens: impl IntoIterator<Item = usize>,
        #[values(false, true)] tamper: bool,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let mut key = [0u8; 16];
        let mut iv = [0u8; 4];
        rng.fill(&mut key);
        rng.fill(&mut iv);
        let mut total_len = 0;
        let records = record_lens
            .into_iter()
            .map(|len| {
                let mut explicit_nonce = [0u8; 8];
                rng.fill(&mut explicit_nonce);
                total_len += len;
                RecordParams {
                    explicit_nonce: explicit_nonce.to_vec(),
                    len,
                }
            })
            .collect::<Vec<_>>();
        let mut plaintext = vec![0u8; total_len];
        rng.fill(plaintext.as_mut_slice());
        let mut ciphertext = plaintext.clone();
        expected_aes_ctr(key, iv, &records, &(0..total_len).into())
            .iter()
            .zip(ciphertext.iter_mut())
            .for_each(|(key, pt)| {
                *pt ^= *key;
            });
        if tamper {
            // Flip one bit of the first plaintext byte (if any) so
            // verification must fail. `if let` instead of `Option::map`:
            // the previous `.map(...)` discarded a `#[must_use]` Option and
            // was flagged by the compiler.
            if let Some(pt) = plaintext.first_mut() {
                *pt ^= 1;
            }
        }
        match verify_plaintext_with_key(key, iv, &records, &plaintext, &ciphertext) {
            Ok(_) if tamper && !plaintext.is_empty() => panic!("should be invalid"),
            Err(e) if !tamper => panic!("unexpected error: {:?}", e),
            _ => {}
        }
    }
}

View File

@@ -1,4 +0,0 @@
//! Plaintext commitment and proof of encryption.
pub(crate) mod encoding;
pub(crate) mod hash;

View File

@@ -1,283 +0,0 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_common::Context;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
EncodingTree, EncodingTreeError, new_encoder,
},
},
};
use crate::{
map::{Item, RangeMap},
transcript_internal::ReferenceMap,
};
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
/// Wire message carrying the masked encodings for both directions.
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
    sent: Vec<u8>,
    recv: Vec<u8>,
}
/// Transfers encodings for the provided plaintext ranges.
///
/// Verifier side: derives a fresh `EncoderSecret`, masks the per-byte
/// encodings with the VM keys, sends them to the prover, and receives back
/// the Merkle root of the prover's encoding tree.
pub(crate) async fn transfer<K: KeyStore>(
    ctx: &mut Context,
    store: &K,
    sent: &ReferenceMap,
    recv: &ReferenceMap,
) -> Result<(EncoderSecret, EncodingCommitment), EncodingError> {
    let secret = EncoderSecret::new(rand::rng().random(), store.delta().as_block().to_bytes());
    let encoder = new_encoder(&secret);
    // Collects the encodings for the provided plaintext ranges.
    fn collect_encodings(
        encoder: &impl Encoder,
        store: &impl KeyStore,
        direction: Direction,
        map: &ReferenceMap,
    ) -> Vec<u8> {
        let mut encodings = Vec::with_capacity(map.len() * ENCODING_SIZE);
        for (range, chunk) in map.iter() {
            let start = encodings.len();
            encoder.encode_range(direction, range, &mut encodings);
            let keys = store
                .get_keys(*chunk)
                .expect("keys are present for provided plaintext ranges");
            // Mask each encoding byte with the corresponding key byte.
            encodings[start..]
                .iter_mut()
                .zip(keys.iter().flat_map(|key| key.as_block().as_bytes()))
                .for_each(|(encoding, key)| {
                    *encoding ^= *key;
                });
        }
        encodings
    }
    let encodings = Encodings {
        sent: collect_encodings(&encoder, store, Direction::Sent, sent),
        recv: collect_encodings(&encoder, store, Direction::Received, recv),
    };
    // Raise the frame limit so the (potentially large) payload fits.
    let frame_limit = ctx
        .io()
        .limit()
        .saturating_add(encodings.sent.len() + encodings.recv.len());
    ctx.io_mut().with_limit(frame_limit).send(encodings).await?;
    let root = ctx.io_mut().expect_next().await?;
    Ok((secret, EncodingCommitment { root }))
}
/// Receives and commits to the encodings for the provided plaintext ranges.
pub(crate) async fn receive<M: MacStore>(
ctx: &mut Context,
store: &M,
hash_alg: HashAlgId,
sent: &ReferenceMap,
recv: &ReferenceMap,
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
let hasher: &(dyn HashAlgorithm + Send + Sync) = match hash_alg {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ErrorRepr::UnsupportedHashAlgorithm(alg).into());
}
};
let (sent_len, recv_len) = (sent.len(), recv.len());
let frame_limit = ctx
.io()
.limit()
.saturating_add(ENCODING_SIZE * (sent_len + recv_len));
let encodings: Encodings = ctx.io_mut().with_limit(frame_limit).expect_next().await?;
if encodings.sent.len() != sent_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_len,
got: encodings.sent.len() / ENCODING_SIZE,
}
.into());
}
if encodings.recv.len() != recv_len * ENCODING_SIZE {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_len,
got: encodings.recv.len() / ENCODING_SIZE,
}
.into());
}
// Collects a map of plaintext ranges to their encodings.
fn collect_map(
store: &impl MacStore,
mut encodings: Vec<u8>,
map: &ReferenceMap,
) -> RangeMap<EncodingSlice> {
let mut encoding_map = Vec::new();
let mut pos = 0;
for (range, chunk) in map.iter() {
let macs = store
.get_macs(*chunk)
.expect("MACs are present for provided plaintext ranges");
let encoding = &mut encodings[pos..pos + range.len() * ENCODING_SIZE];
encoding
.iter_mut()
.zip(macs.iter().flat_map(|mac| mac.as_bytes()))
.for_each(|(encoding, mac)| {
*encoding ^= *mac;
});
encoding_map.push((range.start, EncodingSlice::from(&(*encoding))));
pos += range.len() * ENCODING_SIZE;
}
RangeMap::new(encoding_map)
}
let provider = Provider {
sent: collect_map(store, encodings.sent, sent),
recv: collect_map(store, encodings.recv, recv),
};
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
ctx.io_mut().send(root.clone()).await?;
let commitment = EncodingCommitment { root };
Ok((commitment, tree))
}
/// Access to the verifier's correlated-OT keys and global delta.
pub(crate) trait KeyStore {
    fn delta(&self) -> &Delta;
    fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]>;
}
impl KeyStore for crate::verifier::Zk {
    fn delta(&self) -> &Delta {
        crate::verifier::Zk::delta(self)
    }
    fn get_keys(&self, data: Vector<U8>) -> Option<&[Key]> {
        // Converts the inner Result into an Option per the trait contract.
        self.get_keys(data).ok()
    }
}
/// Access to the prover's MACs for committed VM data.
pub(crate) trait MacStore {
    fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]>;
}
impl MacStore for crate::prover::Zk {
    fn get_macs(&self, data: Vector<U8>) -> Option<&[Mac]> {
        self.get_macs(data).ok()
    }
}
/// Serves unmasked encodings to the encoding tree builder, per direction.
#[derive(Debug)]
struct Provider {
    sent: RangeMap<EncodingSlice>,
    recv: RangeMap<EncodingSlice>,
}
impl EncodingProvider for Provider {
    fn provide_encoding(
        &self,
        direction: Direction,
        range: Range<usize>,
        dest: &mut Vec<u8>,
    ) -> Result<(), EncodingProviderError> {
        let encodings = match direction {
            Direction::Sent => &self.sent,
            Direction::Received => &self.recv,
        };
        let encoding = encodings.get(range).ok_or(EncodingProviderError)?;
        dest.extend_from_slice(encoding);
        Ok(())
    }
}
/// Owned encoding bytes for one contiguous plaintext range
/// (ENCODING_SIZE bytes per plaintext byte).
#[derive(Debug)]
struct EncodingSlice(Vec<u8>);
impl From<&[u8]> for EncodingSlice {
    fn from(value: &[u8]) -> Self {
        Self(value.to_vec())
    }
}
impl Item for EncodingSlice {
    type Slice<'a>
        = &'a [u8]
    where
        Self: 'a;
    // Length is measured in plaintext bytes, not encoding bytes.
    fn length(&self) -> usize {
        self.0.len() / ENCODING_SIZE
    }
    // Slices by plaintext-byte range, scaled to encoding bytes.
    fn slice<'a>(&'a self, range: Range<usize>) -> Option<Self::Slice<'a>> {
        self.0
            .get(range.start * ENCODING_SIZE..range.end * ENCODING_SIZE)
    }
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
/// Concrete failure cases, kept private behind `EncodingError`.
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
    #[error("I/O error: {0}")]
    Io(std::io::Error),
    #[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
    IncorrectMacCount {
        direction: Direction,
        expected: usize,
        got: usize,
    },
    #[error("encoding tree error: {0}")]
    EncodingTree(EncodingTreeError),
    #[error("unsupported hash algorithm: {0}")]
    UnsupportedHashAlgorithm(HashAlgId),
}
impl From<std::io::Error> for EncodingError {
    fn from(value: std::io::Error) -> Self {
        Self(ErrorRepr::Io(value))
    }
}
impl From<EncodingTreeError> for EncodingError {
    fn from(value: EncodingTreeError) -> Self {
        Self(ErrorRepr::EncodingTree(value))
    }
}

View File

@@ -1,244 +0,0 @@
//! Plaintext hash commitments.
use std::collections::HashMap;
use mpz_core::bitvec::BitVec;
use mpz_hash::{blake3::Blake3, sha256::Sha256};
use mpz_memory_core::{
DecodeFutureTyped, MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
Direction,
hash::{PlaintextHash, PlaintextHashSecret},
},
};
use crate::{Role, transcript_internal::TranscriptRefs};
/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub(crate) struct HashCommitFuture {
    // One pending decode per requested (direction, ranges, algorithm) hash.
    #[allow(clippy::type_complexity)]
    futs: Vec<(
        Direction,
        RangeSet<usize>,
        HashAlgId,
        DecodeFutureTyped<BitVec, Vec<u8>>,
    )>,
}
impl HashCommitFuture {
    /// Tries to receive the value, returning an error if the value is not
    /// ready.
    pub(crate) fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
        let mut output = Vec::new();
        for (direction, idx, alg, mut fut) in self.futs {
            let hash = fut
                .try_recv()
                .map_err(|_| HashCommitError::decode())?
                .ok_or_else(HashCommitError::decode)?;
            output.push(PlaintextHash {
                direction,
                idx,
                hash: TypedHash {
                    alg,
                    value: Hash::try_from(hash).map_err(HashCommitError::convert)?,
                },
            });
        }
        Ok(output)
    }
}
/// Prove plaintext hash commitments.
///
/// Prover side: allocates a random blinder per commitment, assigns it, and
/// returns the decode futures plus the secrets (blinders) to keep.
pub(crate) fn prove_hash(
    vm: &mut dyn Vm<Binary>,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<(HashCommitFuture, Vec<PlaintextHashSecret>), HashCommitError> {
    let mut futs = Vec::new();
    let mut secrets = Vec::new();
    for (direction, idx, alg, hash_ref, blinder_ref) in
        hash_commit_inner(vm, Role::Prover, refs, idxs)?
    {
        // Fresh random blinder keeps the hash from leaking the plaintext.
        let blinder: Blinder = rand::random();
        vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
        vm.commit(blinder_ref)?;
        let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
        futs.push((direction, idx.clone(), alg, hash_fut));
        secrets.push(PlaintextHashSecret {
            direction,
            idx,
            blinder,
            alg,
        });
    }
    Ok((HashCommitFuture { futs }, secrets))
}
/// Verify plaintext hash commitments.
///
/// Verifier side: the blinder is blind (prover-supplied); only the hash
/// decode futures are returned.
pub(crate) fn verify_hash(
    vm: &mut dyn Vm<Binary>,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<HashCommitFuture, HashCommitError> {
    let mut futs = Vec::new();
    for (direction, idx, alg, hash_ref, blinder_ref) in
        hash_commit_inner(vm, Role::Verifier, refs, idxs)?
    {
        vm.commit(blinder_ref)?;
        let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
        futs.push((direction, idx, alg, hash_fut));
    }
    Ok(HashCommitFuture { futs })
}
/// Cached circuit hasher instances, keyed by algorithm in `hash_commit_inner`.
#[derive(Clone)]
enum Hasher {
    Sha256(Sha256),
    Blake3(Blake3),
}
/// Commit plaintext hashes of the transcript.
///
/// Shared prover/verifier core: for each (direction, ranges, algorithm)
/// allocates a 16-byte blinder (private to the prover, blind to the
/// verifier), feeds the referenced plaintext plus the blinder into the
/// circuit hasher, and returns the hash and blinder references. Hasher
/// initialization state is cached per algorithm to avoid rebuilding it.
#[allow(clippy::type_complexity)]
fn hash_commit_inner(
    vm: &mut dyn Vm<Binary>,
    role: Role,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<
    Vec<(
        Direction,
        RangeSet<usize>,
        HashAlgId,
        Array<U8, 32>,
        Vector<U8>,
    )>,
    HashCommitError,
> {
    let mut output = Vec::new();
    let mut hashers = HashMap::new();
    for (direction, idx, alg) in idxs {
        let blinder = vm.alloc_vec::<U8>(16)?;
        match role {
            Role::Prover => vm.mark_private(blinder)?,
            Role::Verifier => vm.mark_blind(blinder)?,
        }
        let hash = match alg {
            HashAlgId::SHA256 => {
                // Reuse the cached initial hasher state when available.
                let mut hasher = if let Some(Hasher::Sha256(hasher)) = hashers.get(&alg).cloned() {
                    hasher
                } else {
                    let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
                    hashers.insert(alg, Hasher::Sha256(hasher.clone()));
                    hasher
                };
                let refs = match direction {
                    Direction::Sent => &refs.sent,
                    Direction::Received => &refs.recv,
                };
                for range in idx.iter_ranges() {
                    hasher.update(&refs.get(range).expect("plaintext refs are valid"));
                }
                // The blinder is hashed last, after all plaintext ranges.
                hasher.update(&blinder);
                hasher.finalize(vm).map_err(HashCommitError::hasher)?
            }
            HashAlgId::BLAKE3 => {
                let mut hasher = if let Some(Hasher::Blake3(hasher)) = hashers.get(&alg).cloned() {
                    hasher
                } else {
                    let hasher = Blake3::new(vm).map_err(HashCommitError::hasher)?;
                    hashers.insert(alg, Hasher::Blake3(hasher.clone()));
                    hasher
                };
                let refs = match direction {
                    Direction::Sent => &refs.sent,
                    Direction::Received => &refs.recv,
                };
                for range in idx.iter_ranges() {
                    hasher
                        .update(vm, &refs.get(range).expect("plaintext refs are valid"))
                        .map_err(HashCommitError::hasher)?;
                }
                hasher
                    .update(vm, &blinder)
                    .map_err(HashCommitError::hasher)?;
                hasher.finalize(vm).map_err(HashCommitError::hasher)?
            }
            alg => {
                return Err(HashCommitError::unsupported_alg(alg));
            }
        };
        output.push((direction, idx, alg, hash, blinder));
    }
    Ok(output)
}
/// Error type for hash commitments.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct HashCommitError(#[from] ErrorRepr);
impl HashCommitError {
    /// A decode future failed or was not yet resolved.
    fn decode() -> Self {
        Self(ErrorRepr::Decode)
    }
    /// The decoded bytes could not be converted into a `Hash`.
    fn convert(e: &'static str) -> Self {
        Self(ErrorRepr::Convert(e))
    }
    /// Wraps an error from the circuit hasher.
    fn hasher<E>(e: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        Self(ErrorRepr::Hasher(e.into()))
    }
    /// The requested hash algorithm has no circuit implementation.
    fn unsupported_alg(alg: HashAlgId) -> Self {
        Self(ErrorRepr::UnsupportedAlg { alg })
    }
}
/// Concrete failure cases, kept private behind `HashCommitError`.
#[derive(Debug, thiserror::Error)]
#[error("hash commit error: {0}")]
enum ErrorRepr {
    #[error("VM error: {0}")]
    Vm(VmError),
    #[error("failed to decode hash")]
    Decode,
    #[error("failed to convert hash: {0}")]
    Convert(&'static str),
    #[error("unsupported hash algorithm: {alg}")]
    UnsupportedAlg { alg: HashAlgId },
    #[error("hasher error: {0}")]
    Hasher(Box<dyn std::error::Error + Send + Sync>),
}
impl From<VmError> for HashCommitError {
    fn from(value: VmError) -> Self {
        Self(ErrorRepr::Vm(value))
    }
}

View File

@@ -1,11 +1,8 @@
//! Verifier.
pub(crate) mod config;
mod config;
mod error;
pub mod state;
mod verify;
use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
@@ -14,9 +11,8 @@ pub use tlsn_core::{
webpki::ServerCertVerifier,
};
use crate::{
Role, config::ProtocolConfig, context::build_mt_context, mux::attach_mux, tag::verify_tags,
};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
@@ -26,14 +22,24 @@ use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use tlsn_core::{
ProveRequest,
connection::{ConnectionInfo, ServerName},
transcript::TlsTranscript,
transcript::{ContentType, TlsTranscript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument};
use crate::{
Role,
commit::{ProvingState, TranscriptRefs},
config::ProtocolConfig,
context::build_mt_context,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
@@ -102,12 +108,23 @@ impl Verifier<state::Initialized> {
})
.await?;
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, ctx);
let delta = Delta::random(&mut rand::rng());
let (vm, mut mpc_tls) = build_mpc_tls(&self.config, &protocol_config, delta, ctx);
// Allocate resources for MPC-TLS in the VM.
let mut keys = mpc_tls.alloc()?;
let vm_lock = vm.try_lock().expect("VM is not locked");
translate_keys(&mut keys, &vm_lock)?;
// Allocate for committing to plaintext.
let mut zk_aes_ctr_sent = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_sent.set_key(keys.client_write_key, keys.client_write_iv);
zk_aes_ctr_sent.alloc(&mut *vm_lock.zk(), protocol_config.max_sent_data())?;
let mut zk_aes_ctr_recv = ZkAesCtr::new(Role::Verifier);
zk_aes_ctr_recv.set_key(keys.server_write_key, keys.server_write_iv);
zk_aes_ctr_recv.alloc(&mut *vm_lock.zk(), protocol_config.max_recv_data())?;
drop(vm_lock);
debug!("setting up mpc-tls");
@@ -122,7 +139,10 @@ impl Verifier<state::Initialized> {
state: state::Setup {
mux_ctrl,
mux_fut,
delta,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
},
@@ -160,7 +180,10 @@ impl Verifier<state::Setup> {
let state::Setup {
mux_ctrl,
mut mux_fut,
delta,
mpc_tls,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
vm,
keys,
} = self.state;
@@ -210,16 +233,47 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Verifier {
config: self.config,
span: self.span,
state: state::Committed {
mux_ctrl,
mux_fut,
delta,
ctx,
vm,
keys,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name: None,
encodings_transferred: false,
},
})
}
@@ -244,33 +298,45 @@ impl Verifier<state::Committed> {
let state::Committed {
mux_fut,
ctx,
delta,
vm,
keys,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name,
encodings_transferred,
..
} = &mut self.state;
let cert_verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let request = mux_fut
let payload: ProveRequest = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let output = mux_fut
.poll_with(verify::verify(
ctx,
let proving_state = ProvingState::for_verifier(
payload,
tls_transcript,
transcript_refs,
verified_server_name.clone(),
*encodings_transferred,
);
let (output, encodings_executed) = mux_fut
.poll_with(proving_state.verify(
vm,
keys,
&cert_verifier,
tls_transcript,
request,
ctx,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys.clone(),
*delta,
self.config.root_store(),
))
.await?;
*verified_server_name = output.server_name.clone();
*encodings_transferred = encodings_executed;
Ok(output)
}
@@ -294,11 +360,11 @@ impl Verifier<state::Committed> {
fn build_mpc_tls(
config: &VerifierConfig,
protocol_config: &ProtocolConfig,
delta: Delta,
ctx: Context,
) -> (Arc<Mutex<Deap<Mpc, Zk>>>, MpcTlsFollower) {
let mut rng = rand::rng();
let delta = Delta::random(&mut rng);
let base_ot_send = mpz_ot::chou_orlandi::Sender::default();
let base_ot_recv = mpz_ot::chou_orlandi::Receiver::default();
let rcot_send = mpz_ot::kos::Sender::new(

View File

@@ -12,7 +12,7 @@ use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
#[builder(pattern = "owned")]
pub struct VerifierConfig {
protocol_config_validator: ProtocolConfigValidator,
#[builder(default, setter(strip_option))]
#[builder(setter(strip_option))]
root_store: Option<RootCertStore>,
}

View File

@@ -1,8 +1,6 @@
use std::{error::Error, fmt};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use crate::transcript_internal::commit::encoding::EncodingError;
use std::{error::Error, fmt};
/// Error for [`Verifier`](crate::Verifier).
#[derive(Debug, thiserror::Error)]
@@ -22,13 +20,6 @@ impl VerifierError {
}
}
pub(crate) fn config<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Config, source)
}
pub(crate) fn mpc<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
@@ -42,13 +33,6 @@ impl VerifierError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn verify<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Verify, source)
}
}
#[derive(Debug)]
@@ -58,7 +42,6 @@ enum ErrorKind {
Mpc,
Zk,
Commit,
Verify,
}
impl fmt::Display for VerifierError {
@@ -71,7 +54,6 @@ impl fmt::Display for VerifierError {
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Verify => f.write_str("verification error")?,
}
if let Some(source) = &self.source {
@@ -112,8 +94,14 @@ impl From<MpcTlsError> for VerifierError {
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
impl From<ZkAesCtrError> for VerifierError {
fn from(e: ZkAesCtrError) -> Self {
Self::new(ErrorKind::Zk, e)
}
}
impl From<CommitError> for VerifierError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -2,10 +2,15 @@
use std::sync::Arc;
use crate::mux::{MuxControl, MuxFuture};
use crate::{
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
zk_aes_ctr::ZkAesCtr,
};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use tlsn_core::transcript::TlsTranscript;
use mpz_memory_core::correlated::Delta;
use tlsn_core::{connection::ServerName, transcript::TlsTranscript};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -23,7 +28,10 @@ opaque_debug::implement!(Initialized);
pub struct Setup {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) delta: Delta,
pub(crate) mpc_tls: MpcTlsFollower,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) vm: Arc<Mutex<Deap<Mpc, Zk>>>,
}
@@ -32,10 +40,16 @@ pub struct Setup {
pub struct Committed {
pub(crate) mux_ctrl: MuxControl,
pub(crate) mux_fut: MuxFuture,
pub(crate) delta: Delta,
pub(crate) ctx: Context,
pub(crate) vm: Zk,
pub(crate) keys: SessionKeys,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) verified_server_name: Option<ServerName>,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -1,179 +0,0 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
ProveRequest, VerifierOutput,
transcript::{
ContentType, Direction, PartialTranscript, Record, TlsTranscript, TranscriptCommitment,
},
webpki::ServerCertVerifier,
};
use crate::{
transcript_internal::{
TranscriptRefs,
auth::verify_plaintext,
commit::{
encoding::{self, KeyStore},
hash::verify_hash,
},
},
verifier::VerifierError,
};
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
ctx: &mut Context,
vm: &mut T,
keys: &SessionKeys,
cert_verifier: &ServerCertVerifier,
tls_transcript: &TlsTranscript,
request: ProveRequest,
) -> Result<VerifierOutput, VerifierError> {
let ProveRequest {
handshake,
transcript,
transcript_commit,
} = request;
let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
let ciphertext_recv = collect_ciphertext(tls_transcript.recv());
let has_reveal = transcript.is_some();
let transcript = if let Some(transcript) = transcript {
if transcript.len_sent() != ciphertext_sent.len()
|| transcript.len_received() != ciphertext_recv.len()
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
transcript
} else {
PartialTranscript::new(ciphertext_sent.len(), ciphertext_recv.len())
};
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
cert_verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
let (mut commit_sent, mut commit_recv) = (RangeSet::default(), RangeSet::default());
if let Some(commit_config) = transcript_commit.as_ref() {
commit_config
.iter_hash()
.for_each(|(direction, idx, _)| match direction {
Direction::Sent => commit_sent.union_mut(idx),
Direction::Received => commit_recv.union_mut(idx),
});
if let Some((sent, recv)) = commit_config.encoding() {
commit_sent.union_mut(sent);
commit_recv.union_mut(recv);
}
}
let (sent_refs, sent_proof) = verify_plaintext(
vm,
keys.client_write_key,
keys.client_write_iv,
transcript.sent_unsafe(),
&ciphertext_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
transcript.sent_authed(),
&commit_sent,
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = verify_plaintext(
vm,
keys.server_write_key,
keys.server_write_iv,
transcript.received_unsafe(),
&ciphertext_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
transcript.received_authed(),
&commit_recv,
)
.map_err(VerifierError::zk)?;
let transcript_refs = TranscriptRefs {
sent: sent_refs,
recv: recv_refs,
};
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit.as_ref()
&& commit_config.has_hash()
{
hash_commitments = Some(
verify_hash(vm, &transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
vm.execute_all(ctx).await.map_err(VerifierError::zk)?;
sent_proof.verify().map_err(VerifierError::verify)?;
recv_proof.verify().map_err(VerifierError::verify)?;
let mut encoder_secret = None;
if let Some(commit_config) = transcript_commit
&& let Some((sent, recv)) = commit_config.encoding()
{
let sent_map = transcript_refs
.sent
.index(sent)
.expect("ranges were authenticated");
let recv_map = transcript_refs
.recv
.index(recv)
.expect("ranges were authenticated");
let (secret, commitment) = encoding::transfer(ctx, vm, &sent_map, &recv_map).await?;
encoder_secret = Some(secret);
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript: has_reveal.then_some(transcript),
encoder_secret,
transcript_commitments,
})
}
fn collect_ciphertext<'a>(records: impl IntoIterator<Item = &'a Record>) -> Vec<u8> {
let mut ciphertext = Vec::new();
records
.into_iter()
.filter(|record| record.typ == ContentType::ApplicationData)
.for_each(|record| {
ciphertext.extend_from_slice(&record.ciphertext);
});
ciphertext
}

View File

@@ -0,0 +1,214 @@
//! Zero-knowledge AES-CTR encryption.
use cipher::{
Cipher, CipherError, Keystream,
aes::{Aes128, AesError},
};
use mpz_memory_core::{
Array, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, prelude::*};
use crate::Role;
// Shapes of the AES-CTR keystream components (sizes in bytes).
type Nonce = Array<U8, 8>; // explicit nonce
type Ctr = Array<U8, 4>; // big-endian block counter
type Block = Array<U8, 16>; // AES block
// First counter value used for payload blocks. Presumably counters below 2
// are reserved by the record layer (e.g. the GCM tag mask uses counter 1) —
// TODO confirm against the TLS record-layer implementation.
const START_CTR: u32 = 2;
/// ZK AES-CTR encryption.
///
/// Proves correct AES-CTR encryption inside the VM: the plaintext region is
/// marked private for the prover and blind for the verifier (see `alloc`).
#[derive(Debug)]
pub(crate) struct ZkAesCtr {
    // Whether this instance acts as the prover or the verifier.
    role: Role,
    // Underlying AES-128 cipher used to derive the keystream.
    aes: Aes128,
    // Allocation/usage state machine (see `State`).
    state: State,
}
impl ZkAesCtr {
    /// Creates a new ZK AES-CTR instance.
    ///
    /// The cipher starts with a default key; call [`Self::set_key`] and
    /// [`Self::alloc`] before encrypting.
    pub(crate) fn new(role: Role) -> Self {
        Self {
            role,
            aes: Aes128::default(),
            state: State::Init,
        }
    }
    /// Returns the role.
    pub(crate) fn role(&self) -> Role {
        self.role
    }
    /// Allocates `len` bytes for encryption.
    ///
    /// Reserves VM memory for the plaintext (marked private for the prover,
    /// blind for the verifier) and preprocesses a matching amount of
    /// keystream. Callable only once, from the `Init` state.
    pub(crate) fn alloc(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        len: usize,
    ) -> Result<(), ZkAesCtrError> {
        // `take` leaves `State::Error` behind, so failing here (or in any
        // step below) poisons the instance instead of leaving it
        // half-allocated.
        let State::Init = self.state.take() else {
            Err(ErrorRepr::State {
                reason: "must be in init state to allocate",
            })?
        };
        // Round up to the nearest block size.
        let len = 16 * len.div_ceil(16);
        let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
        let keystream = self.aes.alloc_keystream(vm, len)?;
        // Only the prover knows the plaintext; the verifier sees it blind.
        match self.role {
            Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
            Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
        }
        self.state = State::Ready { input, keystream };
        Ok(())
    }
    /// Sets the key and IV for the cipher.
    pub(crate) fn set_key(&mut self, key: Array<U8, 16>, iv: Array<U8, 4>) {
        self.aes.set_key(key);
        self.aes.set_iv(iv);
    }
    /// Proves the encryption of `len` bytes.
    ///
    /// Here we only assign certain values in the VM but the actual proving
    /// happens later when the plaintext is assigned and the VM is executed.
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `explicit_nonce` - Explicit nonce.
    /// * `len` - Length of the plaintext in bytes.
    ///
    /// # Returns
    ///
    /// A VM reference to the plaintext and the ciphertext.
    pub(crate) fn encrypt(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        explicit_nonce: Vec<u8>,
        len: usize,
    ) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
        let State::Ready { input, keystream } = &mut self.state else {
            Err(ErrorRepr::State {
                reason: "must be in ready state to encrypt",
            })?
        };
        // The explicit nonce must be exactly 8 bytes.
        let explicit_nonce: [u8; 8] =
            explicit_nonce
                .try_into()
                .map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
                    expected: 8,
                    actual: explicit_nonce.len(),
                })?;
        // Pad the request up to a whole number of AES blocks.
        let block_count = len.div_ceil(16);
        let padded_len = block_count * 16;
        let padding_len = padded_len - len;
        if padded_len > input.len() {
            Err(ErrorRepr::InsufficientPreprocessing {
                expected: padded_len,
                actual: input.len(),
            })?
        }
        // NOTE(review): this carves the request off the *end* of the
        // preallocated input region — confirm the keystream consumed below
        // corresponds to the same region.
        let mut input = input.split_off(input.len() - padded_len);
        let keystream = keystream.consume(padded_len)?;
        let mut output = keystream.apply(vm, input)?;
        // Assign counter block inputs.
        let mut ctr = START_CTR..;
        keystream.assign(vm, explicit_nonce, move || {
            ctr.next().expect("range is unbounded").to_be_bytes()
        })?;
        // Assign zeroes to the padding.
        if padding_len > 0 {
            let padding = input.split_off(input.len() - padding_len);
            // To simplify the impl, we don't mark the padding as public, that's why only
            // the prover assigns it.
            if let Role::Prover = self.role {
                vm.assign(padding, vec![0; padding_len])
                    .map_err(ZkAesCtrError::vm)?;
            }
            vm.commit(padding).map_err(ZkAesCtrError::vm)?;
            // Trim the padding blocks from the returned ciphertext.
            output.truncate(len);
        }
        Ok((input, output))
    }
}
// State machine for `ZkAesCtr`.
enum State {
    // Created, nothing allocated yet.
    Init,
    // Allocation done; `input` and `keystream` are consumed by `encrypt`.
    Ready {
        input: Vector<U8>,
        keystream: Keystream<Nonce, Ctr, Block>,
    },
    // A fallible operation failed (or `take` was called and the state was
    // never restored); the instance is unusable.
    Error,
}
impl State {
fn take(&mut self) -> Self {
std::mem::replace(self, State::Error)
}
}
impl std::fmt::Debug for State {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
State::Init => write!(f, "Init"),
State::Ready { .. } => write!(f, "Ready"),
State::Error => write!(f, "Error"),
}
}
}
/// Error produced by [`ZkAesCtr`] operations.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ZkAesCtrError(#[from] ErrorRepr);
impl ZkAesCtrError {
    // Wraps an arbitrary failure reported by the VM.
    fn vm<E>(err: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    {
        ZkAesCtrError(ErrorRepr::Vm(err.into()))
    }
}
#[derive(Debug, thiserror::Error)]
// NOTE(review): enum-level `#[error(...)]` together with per-variant
// attributes — confirm the thiserror version in use accepts this combination.
#[error("zk aes error")]
enum ErrorRepr {
    /// Operation attempted in the wrong state (see `State`).
    #[error("invalid state: {reason}")]
    State { reason: &'static str },
    /// Failure from the underlying cipher implementation.
    #[error("cipher error: {0}")]
    Cipher(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Failure from the VM.
    #[error("vm error: {0}")]
    Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// The explicit nonce had the wrong length (8 bytes expected).
    #[error("invalid explicit nonce length: expected {expected}, got {actual}")]
    ExplicitNonceLength { expected: usize, actual: usize },
    /// More bytes were requested than were allocated up front.
    #[error("insufficient preprocessing: expected {expected}, got {actual}")]
    InsufficientPreprocessing { expected: usize, actual: usize },
}
impl From<AesError> for ZkAesCtrError {
fn from(err: AesError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}
impl From<CipherError> for ZkAesCtrError {
fn from(err: CipherError) -> Self {
Self(ErrorRepr::Cipher(Box::new(err)))
}
}

View File

@@ -1,17 +1,11 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::RangeSet;
use tlsn::{
config::{CertificateDer, ProtocolConfig, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::{HashAlgId, HashProvider},
prover::{ProveConfig, Prover, ProverConfig, TlsConfig},
transcript::{
Direction, Transcript, TranscriptCommitConfig, TranscriptCommitment,
TranscriptCommitmentKind, TranscriptSecret,
},
transcript::{TranscriptCommitConfig, TranscriptCommitment},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_core::ProverOutput;
use tlsn_server_fixture::bind;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
@@ -35,80 +29,11 @@ async fn test() {
let (socket_0, socket_1) = tokio::io::duplex(2 << 23);
let ((full_transcript, prover_output), verifier_output) =
tokio::join!(prover(socket_0), verifier(socket_1));
let partial_transcript = verifier_output.transcript.unwrap();
let ServerName::Dns(server_name) = verifier_output.server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete());
assert_eq!(
partial_transcript
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);
assert_eq!(
partial_transcript
.received_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);
let encoding_tree = prover_output
.transcript_secrets
.iter()
.find_map(|secret| {
if let TranscriptSecret::Encoding(tree) = secret {
Some(tree)
} else {
None
}
})
.unwrap();
let encoding_commitment = prover_output
.transcript_commitments
.iter()
.find_map(|commitment| {
if let TranscriptCommitment::Encoding(commitment) = commitment {
Some(commitment)
} else {
None
}
})
.unwrap();
let prove_sent = RangeSet::from(1..full_transcript.sent().len() - 1);
let prove_recv = RangeSet::from(1..full_transcript.received().len() - 1);
let idxs = [
(Direction::Sent, prove_sent.clone()),
(Direction::Received, prove_recv.clone()),
];
let proof = encoding_tree.proof(idxs.iter()).unwrap();
let (auth_sent, auth_recv) = proof
.verify_with_provider(
&HashProvider::default(),
&verifier_output.encoder_secret.unwrap(),
encoding_commitment,
full_transcript.sent(),
full_transcript.received(),
)
.unwrap();
assert_eq!(auth_sent, prove_sent);
assert_eq!(auth_recv, prove_recv);
tokio::join!(prover(socket_0), verifier(socket_1));
}
#[instrument(skip(verifier_socket))]
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
) -> (Transcript, ProverOutput) {
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_socket: T) {
let (client_socket, server_socket) = tokio::io::duplex(2 << 16);
let server_task = tokio::spawn(bind(server_socket.compat()));
@@ -161,25 +86,9 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let mut builder = TranscriptCommitConfig::builder(prover.transcript());
for kind in [
TranscriptCommitmentKind::Encoding,
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
] {
builder
.commit_with_kind(&(0..sent_tx_len), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(0..recv_tx_len), Direction::Received, kind)
.unwrap();
builder
.commit_with_kind(&(1..sent_tx_len - 1), Direction::Sent, kind)
.unwrap();
builder
.commit_with_kind(&(1..recv_tx_len - 1), Direction::Received, kind)
.unwrap();
}
// Commit to everything
builder.commit_sent(&(0..sent_tx_len)).unwrap();
builder.commit_recv(&(0..recv_tx_len)).unwrap();
let transcript_commit = builder.build().unwrap();
@@ -193,17 +102,13 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
builder.transcript_commit(transcript_commit);
let config = builder.build().unwrap();
let transcript = prover.transcript().clone();
let output = prover.prove(&config).await.unwrap();
prover.close().await.unwrap();
(transcript, output)
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}
#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
) -> VerifierOutput {
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(socket: T) {
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
@@ -220,16 +125,31 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
.unwrap(),
);
let mut verifier = verifier
.setup(socket.compat())
.await
.unwrap()
.run()
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await
.unwrap();
let output = verifier.verify(&VerifyConfig::default()).await.unwrap();
verifier.close().await.unwrap();
let transcript = transcript.unwrap();
output
let ServerName::Dns(server_name) = server_name.unwrap();
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!transcript.is_complete());
assert_eq!(
transcript.sent_authed().iter_ranges().next().unwrap(),
0..10
);
assert_eq!(
transcript.received_authed().iter_ranges().next().unwrap(),
0..10
);
assert!(matches!(
transcript_commitments[0],
TranscriptCommitment::Encoding(_)
));
}

View File

@@ -6,15 +6,11 @@ build-std = ["panic_abort", "std"]
[target.wasm32-unknown-unknown]
rustflags = [
"-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-Clink-arg=--shared-memory",
"-C",
"target-feature=+atomics,+bulk-memory,+mutable-globals,+simd128",
"-C",
# 4GB
"-Clink-arg=--max-memory=4294967296",
"-Clink-arg=--import-memory",
"-Clink-arg=--export=__wasm_init_tls",
"-Clink-arg=--export=__tls_size",
"-Clink-arg=--export=__tls_align",
"-Clink-arg=--export=__tls_base",
"link-arg=--max-memory=4294967296",
"--cfg",
'getrandom_backend="wasm_js"',
]

View File

@@ -1,6 +1,6 @@
[package]
name = "tlsn-wasm"
version = "0.1.0-alpha.13"
version = "0.1.0-alpha.13-pre"
edition = "2021"
repository = "https://github.com/tlsnotary/tlsn.git"
description = "A core WebAssembly package for TLSNotary."

View File

@@ -126,7 +126,7 @@ impl JsProver {
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
info!("Finalized");