Compare commits

..

10 Commits

Author SHA1 Message Date
sinu
a56f2dd02b rename ProvePayload to ProveRequest 2025-09-11 12:51:23 -07:00
sinu
38a1ec3f72 import order 2025-09-11 09:43:38 -07:00
sinu
1501bc661f consolidate encoding stuff 2025-09-11 09:42:03 -07:00
sinu
0e2c2cb045 revert additional derives 2025-09-11 09:37:44 -07:00
sinu
a662fb7511 fix import order 2025-09-11 09:32:50 -07:00
th4s
be2e1ab95a adapt tests to new base 2025-09-11 18:15:14 +02:00
th4s
d34d135bfe adapt to new base 2025-09-11 10:55:26 +02:00
th4s
091d26bb63 fmt nightly 2025-09-11 10:13:08 +02:00
th4s
f031fe9a8d allow encoding protocol being executed only once 2025-09-11 10:13:08 +02:00
th4s
12c9a5eb34 refactor(tlsn): improve proving flow prover and verifier
- add `TlsTranscript` fixture for testing
- refactor and simplify the commit flow
- change how the TLS transcript is committed
  - tag verification is still done for the whole transcript
  - plaintext authentication is only done where needed
  - encoding adjustments are only transmitted for needed ranges
  - makes `prove` and `verify` functions more efficient and callable more than once
- decoding now supports server-write-key decoding without authentication
- add more tests
2025-09-11 10:13:06 +02:00
54 changed files with 4132 additions and 3815 deletions

View File

@@ -18,10 +18,10 @@ env:
# We need a higher number of parallel rayon tasks than the default (which is 4)
# in order to prevent a deadlock, c.f.
# - https://github.com/tlsnotary/tlsn/issues/548
# - https://github.com/privacy-ethereum/mpz/issues/178
# - https://github.com/privacy-scaling-explorations/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.90.0
RUST_VERSION: 1.89.0
jobs:
clippy:
@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4
- name: Install rust toolchain
uses: dtolnay/rust-toolchain@master
uses: dtolnay/rust-toolchain@stable
with:
toolchain: ${{ env.RUST_VERSION }}
components: clippy

View File

@@ -23,6 +23,7 @@ jobs:
- name: "rustdoc"
run: crates/wasm/build-docs.sh
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
if: ${{ github.ref == 'refs/heads/dev' }}

2108
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "70348c1" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "3d90b6c" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
@@ -100,7 +100,7 @@ bytes = { version = "1.4" }
cfg-if = { version = "1" }
chromiumoxide = { version = "0.7" }
chrono = { version = "0.4" }
cipher = { version = "0.4" }
cipher-crypto = { package = "cipher", version = "0.4" }
clap = { version = "4.5" }
criterion = { version = "0.5" }
ctr = { version = "0.9" }
@@ -110,7 +110,7 @@ elliptic-curve = { version = "0.13" }
enum-try-as-inner = { version = "0.1" }
env_logger = { version = "0.10" }
futures = { version = "0.3" }
futures-rustls = { version = "0.25" }
futures-rustls = { version = "0.26" }
generic-array = { version = "0.14" }
ghash = { version = "0.5" }
hex = { version = "0.4" }
@@ -123,6 +123,7 @@ inventory = { version = "0.3" }
itybity = { version = "0.2" }
js-sys = { version = "0.3" }
k256 = { version = "0.13" }
lipsum = { version = "0.9" }
log = { version = "0.4" }
once_cell = { version = "1.19" }
opaque-debug = { version = "0.3" }

View File

@@ -23,9 +23,9 @@ thiserror = { workspace = true }
tiny-keccak = { workspace = true, features = ["keccak"] }
[dev-dependencies]
alloy-primitives = { version = "1.3.1", default-features = false }
alloy-signer = { version = "1.0", default-features = false }
alloy-signer-local = { version = "1.0", default-features = false }
alloy-primitives = { version = "0.8.22", default-features = false }
alloy-signer = { version = "0.12", default-features = false }
alloy-signer-local = { version = "0.12", default-features = false }
rand06-compat = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }

View File

@@ -31,4 +31,4 @@ mpz-ot = { workspace = true }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }
cipher-crypto = { workspace = true }

View File

@@ -344,8 +344,8 @@ mod tests {
start_ctr: usize,
msg: Vec<u8>,
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let mut full_iv = [0u8; 16];
@@ -365,7 +365,7 @@ mod tests {
fn aes128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
use ::aes::Aes128 as TestAes128;
use ::cipher::{BlockEncrypt, KeyInit};
use cipher_crypto::{BlockEncrypt, KeyInit};
let mut msg = msg.into();
let cipher = TestAes128::new(&key.into());

View File

@@ -6,7 +6,10 @@ use rustls_pki_types as webpki_types;
use serde::{Deserialize, Serialize};
use tls_core::msgs::{codec::Codec, enums::NamedGroup, handshake::ServerECDHParams};
use crate::webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError};
use crate::{
transcript::TlsTranscript,
webpki::{CertificateDer, ServerCertVerifier, ServerCertVerifierError},
};
/// TLS version.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -312,6 +315,25 @@ pub struct HandshakeData {
}
impl HandshakeData {
/// Creates a new instance.
///
/// # Arguments
///
/// * `transcript` - The TLS transcript.
pub fn new(transcript: &TlsTranscript) -> Self {
Self {
certs: transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: transcript.certificate_binding().clone(),
}
}
/// Verifies the handshake data.
///
/// # Arguments

View File

@@ -95,7 +95,7 @@ impl Display for HashAlgId {
}
/// A typed hash value.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TypedHash {
/// The algorithm of the hash.
pub alg: HashAlgId,
@@ -191,11 +191,6 @@ impl Hash {
len: value.len(),
}
}
/// Returns a byte slice of the hash value.
pub fn as_bytes(&self) -> &[u8] {
&self.value[..self.len]
}
}
impl rs_merkle::Hash for Hash {

View File

@@ -130,6 +130,15 @@ impl<'a> ProveConfigBuilder<'a> {
self.reveal(Direction::Received, ranges)
}
/// Reveals the full transcript range for a given direction.
pub fn reveal_all(
&mut self,
direction: Direction,
) -> Result<&mut Self, ProveConfigBuilderError> {
let len = self.transcript.len_of_direction(direction);
self.reveal(direction, &(0..len))
}
/// Builds the configuration.
pub fn build(self) -> Result<ProveConfig, ProveConfigBuilderError> {
Ok(ProveConfig {
@@ -190,10 +199,10 @@ pub struct VerifyConfigBuilderError(#[from] VerifyConfigBuilderErrorRepr);
#[derive(Debug, thiserror::Error)]
enum VerifyConfigBuilderErrorRepr {}
/// Payload sent to the verifier.
/// Request to prove statements about the connection.
#[doc(hidden)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ProvePayload {
pub struct ProveRequest {
/// Handshake data.
pub handshake: Option<(ServerName, HandshakeData)>,
/// Transcript data.
@@ -202,6 +211,29 @@ pub struct ProvePayload {
pub transcript_commit: Option<TranscriptCommitRequest>,
}
impl ProveRequest {
/// Creates a new prove payload.
///
/// # Arguments
///
/// * `config` - The prove config.
/// * `transcript` - The partial transcript.
/// * `handshake` - The server name and handshake data.
pub fn new(
config: &ProveConfig,
transcript: Option<PartialTranscript>,
handshake: Option<(ServerName, HandshakeData)>,
) -> Self {
let transcript_commit = config.transcript_commit().map(|config| config.to_request());
Self {
handshake,
transcript,
transcript_commit,
}
}
}
/// Prover output.
#[derive(Serialize, Deserialize)]
pub struct ProverOutput {

View File

@@ -1,6 +1,6 @@
//! Transcript commitments.
use std::{collections::HashSet, fmt};
use std::fmt;
use rangeset::ToRangeSet;
use serde::{Deserialize, Serialize};
@@ -69,8 +69,6 @@ pub enum TranscriptSecret {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitConfig {
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
@@ -85,16 +83,6 @@ impl TranscriptCommitConfig {
&self.encoding_hash_alg
}
/// Returns `true` if the configuration has any encoding commitments.
pub fn has_encoding(&self) -> bool {
self.has_encoding
}
/// Returns `true` if the configuration has any hash commitments.
pub fn has_hash(&self) -> bool {
self.has_hash
}
/// Returns an iterator over the encoding commitment indices.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.commits.iter().filter_map(|(idx, kind)| match kind {
@@ -114,7 +102,10 @@ impl TranscriptCommitConfig {
/// Returns a request for the transcript commitments.
pub fn to_request(&self) -> TranscriptCommitRequest {
TranscriptCommitRequest {
encoding: self.has_encoding,
encoding: self
.iter_encoding()
.map(|(dir, idx)| (*dir, idx.clone()))
.collect(),
hash: self
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg))
@@ -131,10 +122,8 @@ impl TranscriptCommitConfig {
pub struct TranscriptCommitConfigBuilder<'a> {
transcript: &'a Transcript,
encoding_hash_alg: HashAlgId,
has_encoding: bool,
has_hash: bool,
default_kind: TranscriptCommitmentKind,
commits: HashSet<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
commits: Vec<((Direction, RangeSet<usize>), TranscriptCommitmentKind)>,
}
impl<'a> TranscriptCommitConfigBuilder<'a> {
@@ -143,10 +132,8 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
Self {
transcript,
encoding_hash_alg: HashAlgId::BLAKE3,
has_encoding: false,
has_hash: false,
default_kind: TranscriptCommitmentKind::Encoding,
commits: HashSet::default(),
commits: Vec::default(),
}
}
@@ -188,14 +175,12 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
),
));
}
let value = ((direction, idx), kind);
match kind {
TranscriptCommitmentKind::Encoding => self.has_encoding = true,
TranscriptCommitmentKind::Hash { .. } => self.has_hash = true,
if !self.commits.contains(&value) {
self.commits.push(value);
}
self.commits.insert(((direction, idx), kind));
Ok(self)
}
@@ -241,8 +226,6 @@ impl<'a> TranscriptCommitConfigBuilder<'a> {
pub fn build(self) -> Result<TranscriptCommitConfig, TranscriptCommitConfigBuilderError> {
Ok(TranscriptCommitConfig {
encoding_hash_alg: self.encoding_hash_alg,
has_encoding: self.has_encoding,
has_hash: self.has_hash,
commits: Vec::from_iter(self.commits),
})
}
@@ -289,19 +272,14 @@ impl fmt::Display for TranscriptCommitConfigBuilderError {
/// Request to compute transcript commitments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptCommitRequest {
encoding: bool,
encoding: Vec<(Direction, RangeSet<usize>)>,
hash: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
}
impl TranscriptCommitRequest {
/// Returns `true` if an encoding commitment is requested.
pub fn encoding(&self) -> bool {
self.encoding
}
/// Returns `true` if a hash commitment is requested.
pub fn has_hash(&self) -> bool {
!self.hash.is_empty()
/// Returns an iterator over the encoding commitments.
pub fn iter_encoding(&self) -> impl Iterator<Item = &(Direction, RangeSet<usize>)> {
self.encoding.iter()
}
/// Returns an iterator over the hash commitments.

View File

@@ -12,7 +12,7 @@ const BIT_ENCODING_SIZE: usize = 16;
const BYTE_ENCODING_SIZE: usize = 128;
/// Secret used by an encoder to generate encodings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EncoderSecret {
seed: [u8; 32],
delta: [u8; BIT_ENCODING_SIZE],

View File

@@ -22,9 +22,6 @@ const DEFAULT_COMMITMENT_KINDS: &[TranscriptCommitmentKind] = &[
TranscriptCommitmentKind::Hash {
alg: HashAlgId::SHA256,
},
TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
},
TranscriptCommitmentKind::Encoding,
];
@@ -640,9 +637,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
fn test_reveal_with_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -650,6 +645,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {
@@ -687,9 +683,7 @@ mod tests {
}
#[rstest]
#[case::sha256(HashAlgId::SHA256)]
#[case::blake3(HashAlgId::BLAKE3)]
fn test_reveal_with_inconsistent_hash_commitment(#[case] alg: HashAlgId) {
fn test_reveal_with_inconsistent_hash_commitment() {
let mut rng = rand::rngs::StdRng::seed_from_u64(0);
let provider = HashProvider::default();
let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -697,6 +691,7 @@ mod tests {
let direction = Direction::Sent;
let idx = RangeSet::from(0..10);
let blinder: Blinder = rng.random();
let alg = HashAlgId::SHA256;
let hasher = provider.get(&alg).unwrap();
let commitment = PlaintextHash {

View File

@@ -24,7 +24,6 @@ hex = { workspace = true }
hyper = { workspace = true, features = ["client", "http1"] }
hyper-util = { workspace = true, features = ["full"] }
k256 = { workspace = true, features = ["ecdsa"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = [
"rt",
@@ -37,17 +36,11 @@ tokio = { workspace = true, features = [
tokio-util = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
noir = { git = "https://github.com/zkmopro/noir-rs", tag = "v1.0.0-beta.8", features = ["barretenberg"] }
blake3 = { workspace = true }
[[example]]
name = "interactive"
path = "interactive/interactive.rs"
[[example]]
name = "interactive_zk"
path = "interactive_zk/interactive_zk.rs"
[[example]]
name = "attestation_prove"
path = "attestation/prove.rs"

View File

@@ -255,7 +255,7 @@ async fn notarize(
transcript_commitments,
transcript_secrets,
..
} = prover.prove(&disclosure_config).await?;
} = prover.prove(disclosure_config).await?;
// Build an attestation request.
let mut builder = AttestationRequest::builder(config);

View File

@@ -174,7 +174,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -1,5 +0,0 @@
!noir/target/
# Ignore everything inside noir/target
noir/target/*
# Except noir.json
!noir/target/noir.json

View File

@@ -1,167 +0,0 @@
# Interactive Zero-Knowledge Age Verification with TLSNotary
This example demonstrates **privacy-preserving age verification** using TLSNotary and zero-knowledge proofs. It allows a prover to demonstrate they are 18+ years old without revealing their actual birth date or any other personal information.
## 🔍 How It Works (simplified overview)
```mermaid
sequenceDiagram
participant S as Tax Server<br/>(fixture)
participant P as Prover
participant V as Verifier
P->>S: Request tax data (with auth token) (MPC-TLS)
S->>P: Tax data including `date_of_birth` (MPC-TLS)
P->>V: Share transcript with redactions
P->>V: Commit to blinded hash of birth date
P->>P: Generate ZK proof of age ≥ 18
P->>V: Send ZK proof
V->>V: Verify transcript & ZK proof
V->>V: ✅ Confirm: Prover is 18+ (no birth date revealed)
```
### The Process
1. **MPC-TLS Session**: The Prover fetches tax information containing their birth date, while the Verifier jointly verifies the TLS session to ensure the data comes from the authentic server.
2. **Selective Disclosure**:
* The authorization token is **redacted**: the Verifier sees the plaintext request but not the token.
* The birth date is **committed** as a blinded hash: the Verifier cannot see the date, but the Prover is cryptographically bound to it.
(Depending on the use case more data can be redacted or revealed)
3. **Zero-Knowledge Proof**: The Prover generates a ZK proof that the committed birth date corresponds to an age ≥ 18.
4. **Verification**: The Verifier checks both the TLS transcript and the ZK proof, confirming age ≥ 18 without learning the actual date of birth.
### Example Data
The tax server returns data like this:
```json
{
"tax_year": 2024,
"taxpayer": {
"idnr": "12345678901",
"first_name": "Max",
"last_name": "Mustermann",
"date_of_birth": "1985-03-12",
// ...
}
}
```
## 🔐 Zero-Knowledge Proof Details
The ZK circuit proves: **"I know a birth date that hashes to the committed value AND indicates I am 18+ years old"**
**Public Inputs:**
- ✅ Verification date
- ✅ Committed blinded hash of birth date
**Private Inputs (Hidden):**
- 🔒 Actual birth date plaintext
- 🔒 Random blinder used in hash commitment
**What the Verifier Learns:**
- ✅ The prover is 18+ years old
- ✅ The birth date is authentic (from the MPC-TLS session)
Everything else remains private.
## 🏃 Run the Example
1. **Start the test server** (from repository root):
```bash
RUST_LOG=info PORT=4000 cargo run --bin tlsn-server-fixture
```
2. **Run the age verification** (in a new terminal):
```bash
SERVER_PORT=4000 cargo run --release --example interactive_zk
```
3. **For detailed logs**:
```bash
RUST_LOG=debug,yamux=info,uid_mux=info SERVER_PORT=4000 cargo run --release --example interactive_zk
```
### Expected Output
```
Successfully verified https://test-server.io:4000/elster
Age verified in ZK: 18+ ✅
Verified sent data:
GET https://test-server.io:4000/elster HTTP/1.1
host: test-server.io
connection: close
authorization: 🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈🙈
Verified received data:
🙈🙈🙈🙈🙈🙈🙈🙈[truncated for brevity]...🙈🙈🙈🙈🙈"tax_year":2024🙈🙈🙈🙈🙈...
```
> 💡 **Note**: In this demo, both Prover and Verifier run on the same machine. In production, they would operate on separate systems.
> 💡 **Note**: This demo assumes that the tax server serves correct data, and that only the submitter of the tax data has access to the specified page.
## 🛠 Development
### Project Structure
```
interactive_zk/
├── prover.rs # Prover implementation
├── verifier.rs # Verifier implementation
├── types.rs # Shared types
└── interactive_zk.rs # Main example runner
├── noir/ # Zero-knowledge circuit
│ ├── src/main.n # Noir circuit code
│ ├── target/ # Compiled circuit artifacts
│ └── Nargo.toml # Noir project config
│ └── Prover.toml # Example input for `nargo execute`
│ └── generate_test_data.rs # Rust script to generate Noir test data
└── README.md
```
### Noir Circuit Commands
We use [Mopro's `noir_rs`](https://zkmopro.org/docs/crates/noir-rs/) for ZK proof generation. The **circuit is pre-compiled and ready to use**. You don't need to install Noir tools to run the example. But if you want to change or test the circuit in isolation, you can use the following instructions.
Before you proceed, we recommend to double check that your Noir tooling matches the versions used in Mopro's `noir_rs`:
```sh
# Install correct Noir and BB versions (important for compatibility!)
noirup --version 1.0.0-beta.8
bbup -v 1.0.0-nightly.20250723
```
If you don't have `noirup` and `bbup` installed yet, check [Noir's Quick Start](https://noir-lang.org/docs/getting_started/quick_start).
To compile the circuit, go to the `noir` folder and run `nargo compile`.
To check and experiment with the Noir circuit, you can use these commands:
* Execute Circuit: Compile the circuit and run it with sample data from `Prover.toml`:
```sh
nargo execute
```
* Generate Verification Key: Create the verification key needed to verify proofs
```sh
bb write_vk -b ./target/noir.json -o ./target
```
* Generate Proof: Create a zero-knowledge proof using the circuit and witness data.
```sh
bb prove --bytecode_path ./target/noir.json --witness_path ./target/noir.gz -o ./target
```
* Verify Proof: Verify that a proof is valid using the verification key.
```sh
bb verify -k ./target/vk -p ./target/proof
```
* Run the Noir tests:
```sh
nargo test --show-output
```
To create extra tests, you can use `./generate_test_data.rs` to help with generating correct blinders and hashes.
## 📚 Learn More
- [TLSNotary Documentation](https://docs.tlsnotary.org/)
- [Noir Language Guide](https://noir-lang.org/)
- [Zero-Knowledge Proofs Explained](https://ethereum.org/en/zero-knowledge-proofs/)
- [Mopro ZK Toolkit](https://zkmopro.org/)

View File

@@ -1,59 +0,0 @@
mod prover;
mod types;
mod verifier;
use prover::prover;
use std::{
env,
net::{IpAddr, SocketAddr},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::SERVER_DOMAIN;
use verifier::verifier;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
let server_port: u16 = env::var("SERVER_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(DEFAULT_FIXTURE_PORT);
// We use SERVER_DOMAIN here to make sure it matches the domain in the test
// server's certificate.
let uri = format!("https://{SERVER_DOMAIN}:{server_port}/elster");
let server_ip: IpAddr = server_host
.parse()
.map_err(|e| format!("Invalid IP address '{}': {}", server_host, e))?;
let server_addr = SocketAddr::from((server_ip, server_port));
// Connect prover and verifier.
let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
let (prover_extra_socket, verifier_extra_socket) = tokio::io::duplex(1 << 23);
let (_, transcript) = tokio::try_join!(
prover(prover_socket, prover_extra_socket, &server_addr, &uri),
verifier(verifier_socket, verifier_extra_socket)
)?;
println!("---");
println!("Successfully verified {}", &uri);
println!("Age verified in ZK: 18+ ✅\n");
println!(
"Verified sent data:\n{}",
bytes_to_redacted_string(transcript.sent_unsafe())
);
println!(
"Verified received data:\n{}",
bytes_to_redacted_string(transcript.received_unsafe())
);
Ok(())
}
/// Render redacted bytes as `🙈`.
pub fn bytes_to_redacted_string(bytes: &[u8]) -> String {
String::from_utf8_lossy(bytes).replace('\0', "🙈")
}

View File

@@ -1,7 +0,0 @@
[package]
name = "noir"
type = "bin"
authors = [""]
[dependencies]
date = { tag = "v0.5.4", git = "https://github.com/madztheo/noir-date.git" }

View File

@@ -1,8 +0,0 @@
blinder = [108, 93, 120, 205, 15, 35, 159, 124, 243, 96, 22, 128, 16, 149, 219, 216]
committed_hash = [186, 158, 101, 39, 49, 48, 26, 83, 242, 96, 10, 221, 121, 174, 62, 50, 136, 132, 232, 58, 25, 32, 66, 196, 99, 85, 66, 85, 255, 1, 202, 254]
date_of_birth = "1985-03-12"
[proof_date]
day = "29"
month = "08"
year = "2025"

View File

@@ -1,64 +0,0 @@
#!/usr/bin/env -S cargo +nightly -Zscript
---
[package]
name = "generate_test_data"
version = "0.0.0"
edition = "2021"
publish = false
[dependencies]
blake3 = "1.5"
rand = "0.8"
chrono = "0.4"
---
use chrono::Datelike;
use chrono::Local;
use rand::RngCore;
fn main() {
// 1. Birthdate string (fixed)
let dob_str = "1985-03-12"; // 10 bytes long
let proof_date = Local::now().date_naive();
let proof_year = proof_date.year();
let proof_month = proof_date.month();
let proof_day = proof_date.day();
// 2. Generate random 16-byte blinder
let mut blinder = [0u8; 16];
rand::thread_rng().fill_bytes(&mut blinder);
// 3. Concatenate blinder + dob string bytes
let mut preimage = Vec::with_capacity(26);
preimage.extend_from_slice(dob_str.as_bytes());
preimage.extend_from_slice(&blinder);
// 4. Hash it
let hash = blake3::hash(&preimage);
let hash = hash.as_bytes();
let blinder = blinder
.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join(", ");
let committed_hash = hash
.iter()
.map(|b| b.to_string())
.collect::<Vec<_>>()
.join(", ");
println!(
"
// Private input
let date_of_birth = \"{dob_str}\";
let blinder = [{blinder}];
// Public input
let proof_date = date::Date {{ year: {proof_year}, month: {proof_month}, day: {proof_day} }};
let committed_hash = [{committed_hash}];
main(proof_date, committed_hash, date_of_birth, blinder);
"
);
}

View File

@@ -1,82 +0,0 @@
use dep::date::Date;
fn main(
// Public inputs
proof_date: pub date::Date, // "2025-08-29"
committed_hash: pub [u8; 32], // Hash of (blinder || dob string)
// Private inputs
date_of_birth: str<10>, // "1985-03-12"
blinder: [u8; 16], // Random 16-byte blinder
) {
let is_18 = check_18(date_of_birth, proof_date);
let correct_hash = check_hash(date_of_birth, blinder, committed_hash);
assert(correct_hash);
assert(is_18);
}
fn check_18(date_of_birth: str<10>, proof_date: date::Date) -> bool {
let dob = parse_birth_date(date_of_birth);
let is_18 = dob.add_years(18).lt(proof_date);
println(f"Is 18? {is_18}");
is_18
}
fn check_hash(date_of_birth: str<10>, blinder: [u8; 16], committed_hash: [u8; 32]) -> bool {
let hash_input: [u8; 26] = make_hash_input(date_of_birth, blinder);
let computed_hash = std::hash::blake3(hash_input);
let correct_hash = computed_hash == committed_hash;
println(f"Correct hash? {correct_hash}");
correct_hash
}
fn make_hash_input(dob: str<10>, blinder: [u8; 16]) -> [u8; 26] {
let mut input: [u8; 26] = [0; 26];
for i in 0..10 {
input[i] = dob.as_bytes()[i];
}
for i in 0..16 {
input[10 + i] = blinder[i];
}
input
}
pub fn parse_birth_date(birth_date: str<10>) -> date::Date {
let date: [u8; 10] = birth_date.as_bytes();
let date_str: str<8> =
[date[0], date[1], date[2], date[3], date[5], date[6], date[8], date[9]].as_str_unchecked();
Date::from_str_long_year(date_str)
}
#[test]
fn test_max_is_over_18() {
// Private input
let date_of_birth = "1985-03-12";
let blinder = [109, 224, 222, 179, 60, 44, 41, 65, 166, 94, 111, 216, 73, 231, 63, 83];
// Public input
let proof_date = date::Date { year: 2025, month: 9, day: 26 };
let committed_hash = [
114, 34, 41, 235, 91, 156, 13, 57, 254, 112, 250, 35, 104, 217, 20, 182, 240, 170, 57, 39,
187, 154, 14, 39, 91, 67, 50, 199, 149, 231, 78, 46,
];
main(proof_date, committed_hash, date_of_birth, blinder);
}
#[test(should_fail)]
fn test_under_18() {
// Private input
let date_of_birth = "2010-08-01";
let blinder = [160, 23, 57, 158, 141, 195, 155, 132, 109, 242, 48, 220, 70, 217, 229, 189];
// Public input
let proof_date = date::Date { year: 2025, month: 8, day: 29 };
let committed_hash = [
16, 132, 194, 62, 232, 90, 157, 153, 4, 231, 1, 54, 226, 3, 87, 174, 129, 177, 80, 69, 37,
222, 209, 91, 168, 156, 9, 109, 108, 144, 168, 109,
];
main(proof_date, committed_hash, date_of_birth, blinder);
}

File diff suppressed because one or more lines are too long

View File

@@ -1,379 +0,0 @@
use std::net::SocketAddr;
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use chrono::{Datelike, Local, NaiveDate};
use http_body_util::Empty;
use hyper::{body::Bytes, header, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use k256::sha2::Digest;
use noir::{
barretenberg::{
prove::prove_ultra_honk, srs::setup_srs_from_bytecode,
verify::get_ultra_honk_verification_key,
},
witness::from_vec_str_to_witness_map,
};
use serde_json::Value;
use spansy::{
http::{BodyContent, Requests, Responses},
Spanned,
};
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{CertificateDer, ProtocolConfig, RootCertStore},
connection::ServerName,
hash::HashAlgId,
prover::{ProveConfig, ProveConfigBuilder, Prover, ProverConfig, TlsConfig},
transcript::{
hash::{PlaintextHash, PlaintextHashSecret},
Direction, TranscriptCommitConfig, TranscriptCommitConfigBuilder, TranscriptCommitmentKind,
TranscriptSecret,
},
};
use tlsn_examples::MAX_RECV_DATA;
use tokio::io::AsyncWriteExt;
use tlsn_examples::MAX_SENT_DATA;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;
#[instrument(skip(verifier_socket, verifier_extra_socket))]
pub async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
verifier_socket: T,
mut verifier_extra_socket: T,
server_addr: &SocketAddr,
uri: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let uri = uri.parse::<Uri>()?;
if uri.scheme().map(|s| s.as_str()) != Some("https") {
return Err("URI must use HTTPS scheme".into());
}
let server_domain = uri.authority().ok_or("URI must have authority")?.host();
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let mut tls_config_builder = TlsConfig::builder();
tls_config_builder.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
});
let tls_config = tls_config_builder.build()?;
// Set up protocol configuration for prover.
let mut prover_config_builder = ProverConfig::builder();
prover_config_builder
.server_name(ServerName::Dns(server_domain.try_into()?))
.tls_config(tls_config)
.protocol_config(
ProtocolConfig::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?,
);
let prover_config = prover_config_builder.build()?;
// Create prover and connect to verifier.
//
// Perform the setup phase with the verifier.
let prover = Prover::new(prover_config)
.setup(verifier_socket.compat())
.await?;
// Connect to TLS Server.
let tls_client_socket = tokio::net::TcpStream::connect(server_addr).await?;
// Pass server connection into the prover.
let (mpc_tls_connection, prover_fut) = prover.connect(tls_client_socket.compat()).await?;
// Wrap the connection in a TokioIo compatibility layer to use it with hyper.
let mpc_tls_connection = TokioIo::new(mpc_tls_connection.compat());
// Spawn the Prover to run in the background.
let prover_task = tokio::spawn(prover_fut);
// MPC-TLS Handshake.
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(mpc_tls_connection).await?;
// Spawn the connection to run in the background.
tokio::spawn(connection);
// MPC-TLS: Send Request and wait for Response.
let request = Request::builder()
.uri(uri.clone())
.header("Host", server_domain)
.header("Connection", "close")
.header(header::AUTHORIZATION, "Bearer random_auth_token")
.method("GET")
.body(Empty::<Bytes>::new())?;
let response = request_sender.send_request(request).await?;
if response.status() != StatusCode::OK {
return Err(format!("MPC-TLS request failed with status {}", response.status()).into());
}
// Create proof for the Verifier.
let mut prover = prover_task.await??;
let transcript = prover.transcript().clone();
let mut prove_config_builder = ProveConfig::builder(&transcript);
// Reveal the DNS name.
prove_config_builder.server_identity();
let sent: &[u8] = transcript.sent();
let received: &[u8] = transcript.received();
let sent_len = sent.len();
let recv_len = received.len();
tracing::info!("Sent length: {}, Received length: {}", sent_len, recv_len);
// Reveal the entire HTTP request except for the authorization bearer token
reveal_request(sent, &mut prove_config_builder)?;
// Create hash commitment for the date of birth field from the response
let mut transcript_commitment_builder = TranscriptCommitConfig::builder(&transcript);
transcript_commitment_builder.default_kind(TranscriptCommitmentKind::Hash {
alg: HashAlgId::BLAKE3,
});
reveal_received(
received,
&mut prove_config_builder,
&mut transcript_commitment_builder,
)?;
let transcripts_commitment_config = transcript_commitment_builder.build()?;
prove_config_builder.transcript_commit(transcripts_commitment_config);
let prove_config = prove_config_builder.build()?;
// MPC-TLS prove
let prover_output = prover.prove(&prove_config).await?;
prover.close().await?;
// Prove birthdate is more than 18 years ago.
let received_commitments = received_commitments(&prover_output.transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("No received commitments found")?; // committed hash (of date of birth string)
let received_secrets = received_secrets(&prover_output.transcript_secrets);
let received_secret = received_secrets
.first()
.ok_or("No received secrets found")?; // hash blinder
let start_time = std::time::Instant::now();
let proof_input = prepare_zk_proof_input(received, received_commitment, received_secret)?;
let prepare_duration = start_time.elapsed();
tracing::info!("🔢 prepare_zk_proof_input took: {:?}", prepare_duration);
let start_time = std::time::Instant::now();
let proof_bundle = generate_zk_proof(&proof_input)?;
let generate_duration = start_time.elapsed();
tracing::info!("⚡ generate_zk_proof took: {:?}", generate_duration);
// Sent zk proof bundle to verifier
let serialized_proof = bincode::serialize(&proof_bundle)?;
verifier_extra_socket.write_all(&serialized_proof).await?;
verifier_extra_socket.shutdown().await?;
Ok(())
}
// Reveal everything from the request, except for the authorization token.
fn reveal_request(
request: &[u8],
builder: &mut ProveConfigBuilder<'_>,
) -> Result<(), Box<dyn std::error::Error>> {
let reqs = Requests::new_from_slice(request).collect::<Result<Vec<_>, _>>()?;
let req = reqs.first().ok_or("No requests found")?;
if req.request.method.as_str() != "GET" {
return Err(format!("Expected GET method, found {}", req.request.method.as_str()).into());
}
let authorization_header = req
.headers_with_name(header::AUTHORIZATION.as_str())
.next()
.ok_or("Authorization header not found")?;
let start_pos = authorization_header
.span()
.indices()
.min()
.ok_or("Could not find authorization header start position")?
+ header::AUTHORIZATION.as_str().len()
+ 2;
let end_pos =
start_pos + authorization_header.span().len() - header::AUTHORIZATION.as_str().len() - 2;
builder.reveal_sent(&(0..start_pos))?;
builder.reveal_sent(&(end_pos..request.len()))?;
Ok(())
}
fn reveal_received(
received: &[u8],
builder: &mut ProveConfigBuilder<'_>,
transcript_commitment_builder: &mut TranscriptCommitConfigBuilder,
) -> Result<(), Box<dyn std::error::Error>> {
let resp = Responses::new_from_slice(received).collect::<Result<Vec<_>, _>>()?;
let response = resp.first().ok_or("No responses found")?;
let body = response.body.as_ref().ok_or("Response body not found")?;
let BodyContent::Json(json) = &body.content else {
return Err("Expected JSON body content".into());
};
// reveal tax year
let tax_year = json
.get("tax_year")
.ok_or("tax_year field not found in JSON")?;
let start_pos = tax_year
.span()
.indices()
.min()
.ok_or("Could not find tax_year start position")?
- 11;
let end_pos = tax_year
.span()
.indices()
.max()
.ok_or("Could not find tax_year end position")?
+ 1;
builder.reveal_recv(&(start_pos..end_pos))?;
// commit to hash of date of birth
let dob = json
.get("taxpayer.date_of_birth")
.ok_or("taxpayer.date_of_birth field not found in JSON")?;
transcript_commitment_builder.commit_recv(dob.span())?;
Ok(())
}
// extract secret from prover output
fn received_secrets(transcript_secrets: &[TranscriptSecret]) -> Vec<&PlaintextHashSecret> {
transcript_secrets
.iter()
.filter_map(|secret| match secret {
TranscriptSecret::Hash(hash) if hash.direction == Direction::Received => Some(hash),
_ => None,
})
.collect()
}
#[derive(Debug)]
pub struct ZKProofInput {
dob: Vec<u8>,
proof_date: NaiveDate,
blinder: Vec<u8>,
committed_hash: Vec<u8>,
}
// Verify that the blinded, committed hash is correct
fn prepare_zk_proof_input(
received: &[u8],
received_commitment: &PlaintextHash,
received_secret: &PlaintextHashSecret,
) -> Result<ZKProofInput, Box<dyn std::error::Error>> {
assert_eq!(received_commitment.direction, Direction::Received);
assert_eq!(received_commitment.hash.alg, HashAlgId::BLAKE3);
let hash = &received_commitment.hash;
let dob_start = received_commitment
.idx
.min()
.ok_or("No start index for DOB")?;
let dob_end = received_commitment
.idx
.end()
.ok_or("No end index for DOB")?;
let dob = received[dob_start..dob_end].to_vec();
let blinder = received_secret.blinder.as_bytes().to_vec();
let committed_hash = hash.value.as_bytes().to_vec();
let proof_date = Local::now().date_naive();
assert_eq!(received_secret.direction, Direction::Received);
assert_eq!(received_secret.alg, HashAlgId::BLAKE3);
let mut hasher = blake3::Hasher::new();
hasher.update(&dob);
hasher.update(&blinder);
let computed_hash = hasher.finalize().as_bytes().to_vec();
if committed_hash != computed_hash.as_slice() {
return Err("Computed hash does not match committed hash".into());
}
Ok(ZKProofInput {
dob,
proof_date,
committed_hash,
blinder,
})
}
fn generate_zk_proof(
proof_input: &ZKProofInput,
) -> Result<ZKProofBundle, Box<dyn std::error::Error>> {
tracing::info!("🔒 Generating ZK proof with Noir...");
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
// 1. Load bytecode from program.json
let json: Value = serde_json::from_str(PROGRAM_JSON)?;
let bytecode = json["bytecode"]
.as_str()
.ok_or("bytecode field not found in program.json")?;
let mut inputs: Vec<String> = vec![];
inputs.push(proof_input.proof_date.day().to_string());
inputs.push(proof_input.proof_date.month().to_string());
inputs.push(proof_input.proof_date.year().to_string());
inputs.extend(proof_input.committed_hash.iter().map(|b| b.to_string()));
inputs.extend(proof_input.dob.iter().map(|b| b.to_string()));
inputs.extend(proof_input.blinder.iter().map(|b| b.to_string()));
let proof_date = proof_input.proof_date.to_string();
tracing::info!(
"Public inputs : Proof date ({}) and committed hash ({})",
proof_date,
hex::encode(&proof_input.committed_hash)
);
tracing::info!(
"Private inputs: Blinder ({}) and Date of Birth ({})",
hex::encode(&proof_input.blinder),
String::from_utf8_lossy(&proof_input.dob)
);
tracing::debug!("Witness inputs {:?}", inputs);
let input_refs: Vec<&str> = inputs.iter().map(String::as_str).collect();
let witness = from_vec_str_to_witness_map(input_refs)?;
// Setup SRS
setup_srs_from_bytecode(bytecode, None, false)?;
// Verification key
let vk = get_ultra_honk_verification_key(bytecode, false)?;
// Generate proof
let proof = prove_ultra_honk(bytecode, witness.clone(), vk.clone(), false)?;
tracing::info!("✅ Proof generated ({} bytes)", proof.len());
let proof_bundle = ZKProofBundle { vk, proof };
Ok(proof_bundle)
}

View File

@@ -1,21 +0,0 @@
use serde::{Deserialize, Serialize};
use tlsn::transcript::{hash::PlaintextHash, Direction, TranscriptCommitment};
#[derive(Serialize, Deserialize, Debug)]
pub struct ZKProofBundle {
pub vk: Vec<u8>,
pub proof: Vec<u8>,
}
// extract commitment from prover output
pub fn received_commitments(
transcript_commitments: &[TranscriptCommitment],
) -> Vec<&PlaintextHash> {
transcript_commitments
.iter()
.filter_map(|commitment| match commitment {
TranscriptCommitment::Hash(hash) if hash.direction == Direction::Received => Some(hash),
_ => None,
})
.collect()
}

View File

@@ -1,184 +0,0 @@
use crate::types::received_commitments;
use super::types::ZKProofBundle;
use chrono::{Local, NaiveDate};
use noir::barretenberg::verify::{get_ultra_honk_verification_key, verify_ultra_honk};
use serde_json::Value;
use tls_server_fixture::CA_CERT_DER;
use tlsn::{
config::{CertificateDer, ProtocolConfigValidator, RootCertStore},
connection::ServerName,
hash::HashAlgId,
transcript::{Direction, PartialTranscript},
verifier::{Verifier, VerifierConfig, VerifierOutput, VerifyConfig},
};
use tlsn_examples::{MAX_RECV_DATA, MAX_SENT_DATA};
use tlsn_server_fixture_certs::SERVER_DOMAIN;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::instrument;
#[instrument(skip(socket, extra_socket))]
pub async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
socket: T,
mut extra_socket: T,
) -> Result<PartialTranscript, Box<dyn std::error::Error>> {
// Set up Verifier.
let config_validator = ProtocolConfigValidator::builder()
.max_sent_data(MAX_SENT_DATA)
.max_recv_data(MAX_RECV_DATA)
.build()?;
// Create a root certificate store with the server-fixture's self-signed
// certificate. This is only required for offline testing with the
// server-fixture.
let verifier_config = VerifierConfig::builder()
.root_store(RootCertStore {
roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
})
.protocol_config_validator(config_validator)
.build()?;
let verifier = Verifier::new(verifier_config);
// Receive authenticated data.
let VerifierOutput {
server_name,
transcript,
transcript_commitments,
..
} = verifier
.verify(socket.compat(), &VerifyConfig::default())
.await?;
let server_name = server_name.ok_or("Prover should have revealed server name")?;
let transcript = transcript.ok_or("Prover should have revealed transcript data")?;
// Create hash commitment for the date of birth field from the response
let sent = transcript.sent_unsafe().to_vec();
let sent_data = String::from_utf8(sent.clone())
.map_err(|e| format!("Verifier expected valid UTF-8 sent data: {}", e))?;
if !sent_data.contains(SERVER_DOMAIN) {
return Err(format!(
"Verification failed: Expected host {} not found in sent data",
SERVER_DOMAIN
)
.into());
}
// Check received data.
let received_commitments = received_commitments(&transcript_commitments);
let received_commitment = received_commitments
.first()
.ok_or("Missing received hash commitment")?;
assert!(received_commitment.direction == Direction::Received);
assert!(received_commitment.hash.alg == HashAlgId::BLAKE3);
let committed_hash = &received_commitment.hash;
// Check Session info: server name.
let ServerName::Dns(server_name) = server_name;
if server_name.as_str() != SERVER_DOMAIN {
return Err(format!(
"Server name mismatch: expected {}, got {}",
SERVER_DOMAIN,
server_name.as_str()
)
.into());
}
// Receive ZKProof information from prover
let mut buf = Vec::new();
extra_socket.read_to_end(&mut buf).await?;
if buf.is_empty() {
return Err("No ZK proof data received from prover".into());
}
let msg: ZKProofBundle = bincode::deserialize(&buf)
.map_err(|e| format!("Failed to deserialize ZK proof bundle: {}", e))?;
// Verify zk proof
const PROGRAM_JSON: &str = include_str!("./noir/target/noir.json");
let json: Value = serde_json::from_str(PROGRAM_JSON)
.map_err(|e| format!("Failed to parse Noir circuit: {}", e))?;
let bytecode = json["bytecode"]
.as_str()
.ok_or("Bytecode field missing in noir.json")?;
let vk = get_ultra_honk_verification_key(bytecode, false)
.map_err(|e| format!("Failed to get verification key: {}", e))?;
if vk != msg.vk {
return Err("Verification key mismatch between computed and provided by prover".into());
}
let proof = msg.proof.clone();
// Validate proof has enough data.
// The proof should start with the public inputs:
// * We expect at least 3 * 32 bytes for the three date fields (day, month,
// year)
// * and 32*32 bytes for the hash
let min_bytes = (32 + 3) * 32;
if proof.len() < min_bytes {
return Err(format!(
"Proof too short: expected at least {} bytes, got {}",
min_bytes,
proof.len()
)
.into());
}
// Check that the proof date is correctly included in the proof
let proof_date_day: u32 = u32::from_be_bytes(proof[28..32].try_into()?);
let proof_date_month: u32 = u32::from_be_bytes(proof[60..64].try_into()?);
let proof_date_year: i32 = i32::from_be_bytes(proof[92..96].try_into()?);
let proof_date_from_proof =
NaiveDate::from_ymd_opt(proof_date_year, proof_date_month, proof_date_day)
.ok_or("Invalid proof date in proof")?;
let today = Local::now().date_naive();
if (today - proof_date_from_proof).num_days() < 0 {
return Err(format!(
"The proof date can only be today or in the past: provided {}, today {}",
proof_date_from_proof, today
)
.into());
}
// Check that the committed hash in the proof matches the hash from the
// commitment
let committed_hash_in_proof: Vec<u8> = proof
.chunks(32)
.skip(3) // skip the first 3 chunks
.take(32)
.map(|chunk| *chunk.last().unwrap_or(&0))
.collect();
let expected_hash = committed_hash.value.as_bytes().to_vec();
if committed_hash_in_proof != expected_hash {
tracing::error!(
"❌ The hash in the proof does not match the committed hash in MPC-TLS: {} != {}",
hex::encode(&committed_hash_in_proof),
hex::encode(&expected_hash)
);
return Err("Hash in proof does not match committed hash in MPC-TLS".into());
}
tracing::info!(
"✅ The hash in the proof matches the committed hash in MPC-TLS ({})",
hex::encode(&expected_hash)
);
// Finally verify the proof
let is_valid = verify_ultra_honk(msg.proof, msg.vk)
.map_err(|e| format!("ZKProof Verification failed: {}", e))?;
if !is_valid {
tracing::error!("❌ Age verification ZKProof failed to verify");
return Err("Age verification ZKProof failed to verify".into());
}
tracing::info!("✅ Age verification ZKProof successfully verified");
Ok(transcript)
}

View File

@@ -93,7 +93,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -107,7 +107,7 @@ async fn prover(provider: &IoProvider) {
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -72,5 +72,4 @@ pub(crate) struct ServerFinishedVd {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct CloseConnection;

View File

@@ -193,7 +193,7 @@ where
};
// Divide by block length and round up.
let block_count = input.len() / 16 + !input.len().is_multiple_of(16) as usize;
let block_count = input.len() / 16 + (input.len() % 16 != 0) as usize;
if block_count > MAX_POWER {
return Err(ErrorRepr::InputLength {
@@ -282,11 +282,11 @@ fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + !aad.len().is_multiple_of(16) as usize;
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + !ciphertext.len().is_multiple_of(16) as usize;
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);

View File

@@ -1,37 +0,0 @@
{
"tax_year": 2024,
"taxpayer": {
"idnr": "12345678901",
"first_name": "Max",
"last_name": "Mustermann",
"date_of_birth": "1985-03-12",
"address": {
"street": "Musterstraße 1",
"postal_code": "10115",
"city": "Berlin"
}
},
"income": {
"employment_income": 54200.00,
"other_income": 1200.00,
"capital_gains": 350.00
},
"deductions": {
"pension_insurance": 4200.00,
"health_insurance": 3600.00,
"donations": 500.00,
"work_related_expenses": 1100.00
},
"assessment": {
"taxable_income": 49200.00,
"income_tax": 9156.00,
"solidarity_surcharge": 503.58,
"total_tax": 9659.58,
"prepaid_tax": 9500.00,
"refund": 159.58
},
"submission": {
"submitted_at": "2025-03-01T14:22:30Z",
"submitted_by": "ElsterOnline-Portal"
}
}

View File

@@ -47,7 +47,6 @@ fn app(state: AppState) -> Router {
.route("/formats/json", get(json))
.route("/formats/html", get(html))
.route("/protected", get(protected_route))
.route("/elster", get(elster_route))
.layer(TraceLayer::new_for_http())
.with_state(Arc::new(Mutex::new(state)))
}
@@ -197,12 +196,6 @@ async fn protected_route(_: AuthenticatedUser) -> Result<Json<Value>, StatusCode
get_json_value(include_str!("data/protected_data.json"))
}
async fn elster_route(_: AuthenticatedUser) -> Result<Json<Value>, StatusCode> {
info!("Handling /elster");
get_json_value(include_str!("data/elster.json"))
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -40,6 +40,9 @@ mpz-ot = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }
aes = { workspace = true }
cipher-crypto = { workspace = true }
ctr = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
opaque-debug = { workspace = true }
@@ -57,6 +60,8 @@ rangeset = { workspace = true }
webpki-roots = { workspace = true }
[dev-dependencies]
lipsum = { workspace = true }
sha2 = { workspace = true }
rstest = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
@@ -65,3 +70,5 @@ tokio-util = { workspace = true, features = ["compat"] }
hyper = { workspace = true, features = ["client"] }
http-body-util = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tlsn-core = { workspace = true, features = ["fixtures"] }
mpz-ot = { workspace = true, features = ["ideal"] }

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,708 @@
//! Authentication of the transcript plaintext and creation of the transcript
//! references.
use std::ops::Range;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{DecodeError, DecodeFutureTyped, MemoryExt, binary::Binary};
use mpz_vm_core::Vm;
use rangeset::{Disjoint, RangeSet, Union, UnionMut};
use tlsn_core::{
hash::HashAlgId,
transcript::{ContentType, Direction, PartialTranscript, Record, TlsTranscript},
};
use crate::{
Role,
commit::transcript::TranscriptRefs,
zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
};
/// Transcript Authenticator.
pub(crate) struct Authenticator {
encoding: Index,
hash: Index,
decoding: Index,
proving: Index,
}
impl Authenticator {
/// Creates a new authenticator.
///
/// # Arguments
///
/// * `encoding` - Ranges for encoding commitments.
/// * `hash` - Ranges for hash commitments.
/// * `partial` - The partial transcript.
pub(crate) fn new<'a>(
encoding: impl Iterator<Item = &'a (Direction, RangeSet<usize>)>,
hash: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
partial: Option<&PartialTranscript>,
) -> Self {
// Compute encoding index.
let mut encoding_sent = RangeSet::default();
let mut encoding_recv = RangeSet::default();
for (d, idx) in encoding {
match d {
Direction::Sent => encoding_sent.union_mut(idx),
Direction::Received => encoding_recv.union_mut(idx),
}
}
let encoding = Index::new(encoding_sent, encoding_recv);
// Compute hash index.
let mut hash_sent = RangeSet::default();
let mut hash_recv = RangeSet::default();
for (d, idx, _) in hash {
match d {
Direction::Sent => hash_sent.union_mut(idx),
Direction::Received => hash_recv.union_mut(idx),
}
}
let hash = Index {
sent: hash_sent,
recv: hash_recv,
};
// Compute decoding index.
let mut decoding_sent = RangeSet::default();
let mut decoding_recv = RangeSet::default();
if let Some(partial) = partial {
decoding_sent.union_mut(partial.sent_authed());
decoding_recv.union_mut(partial.received_authed());
}
let decoding = Index::new(decoding_sent, decoding_recv);
// Compute proving index.
let mut proving_sent = RangeSet::default();
let mut proving_recv = RangeSet::default();
proving_sent.union_mut(decoding.sent());
proving_sent.union_mut(encoding.sent());
proving_sent.union_mut(hash.sent());
proving_recv.union_mut(decoding.recv());
proving_recv.union_mut(encoding.recv());
proving_recv.union_mut(hash.recv());
let proving = Index::new(proving_sent, proving_recv);
Self {
encoding,
hash,
decoding,
proving,
}
}
/// Authenticates the sent plaintext, returning a proof of encryption and
/// writes the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_sent` - ZK AES Cipher for sent traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_sent(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_sent: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let missing_index = transcript_refs.compute_missing(Direction::Sent, self.proving.sent());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_sent,
Direction::Sent,
sent,
transcript_refs,
missing_index,
)
}
/// Authenticates the received plaintext, returning a proof of encryption
/// and writes the plaintext VM references to the transcript references.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes_recv` - ZK AES Cipher for received traffic.
/// * `transcript` - The TLS transcript.
/// * `transcript_refs` - The transcript references.
pub(crate) fn auth_recv(
&mut self,
vm: &mut dyn Vm<Binary>,
zk_aes_recv: &mut ZkAesCtr,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) -> Result<RecordProof, AuthError> {
let decoding_recv = self.decoding.recv();
let fully_decoded = decoding_recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
// If we only have decoding ranges, and the parts we are going to decode will
// complete to the full received transcript, then we do not need to
// authenticate, because this will be done by
// `crate::commit::decode::verify_transcript`, as it uses the server write
// key and iv for verification.
if decoding_recv == self.proving.recv() && fully_decoded == full_range {
return Ok(RecordProof::default());
}
let missing_index =
transcript_refs.compute_missing(Direction::Received, self.proving.recv());
// If there is nothing new to prove, return early.
if missing_index == RangeSet::default() {
return Ok(RecordProof::default());
}
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
authenticate(
vm,
zk_aes_recv,
Direction::Received,
recv,
transcript_refs,
missing_index,
)
}
/// Returns the sent and received encoding ranges.
pub(crate) fn encoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.encoding.sent(), self.encoding.recv())
}
/// Returns the sent and received hash ranges.
pub(crate) fn hash(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.hash.sent(), self.hash.recv())
}
/// Returns the sent and received decoding ranges.
pub(crate) fn decoding(&self) -> (&RangeSet<usize>, &RangeSet<usize>) {
(self.decoding.sent(), self.decoding.recv())
}
}
/// Authenticates parts of the transcript in zk.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `zk_aes` - ZK AES Cipher.
/// * `direction` - The direction of the application data.
/// * `app_data` - The application data.
/// * `transcript_refs` - The transcript references.
/// * `missing_index` - The index which needs to be proven.
fn authenticate<'a>(
vm: &mut dyn Vm<Binary>,
zk_aes: &mut ZkAesCtr,
direction: Direction,
app_data: impl Iterator<Item = &'a Record>,
transcript_refs: &mut TranscriptRefs,
missing_index: RangeSet<usize>,
) -> Result<RecordProof, AuthError> {
let mut record_idx = Range::default();
let mut ciphertexts = Vec::new();
for record in app_data {
let record_len = record.ciphertext.len();
record_idx.end += record_len;
if missing_index.is_disjoint(&record_idx) {
record_idx.start += record_len;
continue;
}
let (plaintext_ref, ciphertext_ref) =
zk_aes.encrypt(vm, record.explicit_nonce.clone(), record.ciphertext.len())?;
if let Role::Prover = zk_aes.role() {
let Some(plaintext) = record.plaintext.clone() else {
return Err(AuthError(ErrorRepr::MissingPlainText));
};
vm.assign(plaintext_ref, plaintext).map_err(AuthError::vm)?;
}
vm.commit(plaintext_ref).map_err(AuthError::vm)?;
let ciphertext = vm.decode(ciphertext_ref).map_err(AuthError::vm)?;
transcript_refs.add(direction, &record_idx, plaintext_ref);
ciphertexts.push((ciphertext, record.ciphertext.clone()));
record_idx.start += record_len;
}
let proof = RecordProof { ciphertexts };
Ok(proof)
}
#[derive(Debug, Clone, Default)]
struct Index {
sent: RangeSet<usize>,
recv: RangeSet<usize>,
}
impl Index {
fn new(sent: RangeSet<usize>, recv: RangeSet<usize>) -> Self {
Self { sent, recv }
}
fn sent(&self) -> &RangeSet<usize> {
&self.sent
}
fn recv(&self) -> &RangeSet<usize> {
&self.recv
}
}
/// Proof of encryption.
#[derive(Debug, Default)]
#[must_use]
#[allow(clippy::type_complexity)]
pub(crate) struct RecordProof {
ciphertexts: Vec<(DecodeFutureTyped<BitVec, Vec<u8>>, Vec<u8>)>,
}
impl RecordProof {
/// Verifies the proof.
pub(crate) fn verify(self) -> Result<(), AuthError> {
let Self { ciphertexts } = self;
for (mut ciphertext, expected) in ciphertexts {
let ciphertext = ciphertext
.try_recv()
.map_err(AuthError::vm)?
.ok_or(AuthError(ErrorRepr::MissingDecoding))?;
if ciphertext != expected {
return Err(AuthError(ErrorRepr::InvalidCiphertext));
}
}
Ok(())
}
}
/// Error for [`Authenticator`].
#[derive(Debug, thiserror::Error)]
#[error("transcript authentication error: {0}")]
pub(crate) struct AuthError(#[source] ErrorRepr);
impl AuthError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("zk-aes error: {0}")]
ZkAes(ZkAesCtrError),
#[error("decode error: {0}")]
Decode(DecodeError),
#[error("plaintext is missing in record")]
MissingPlainText,
#[error("decoded value is missing")]
MissingDecoding,
#[error("invalid ciphertext")]
InvalidCiphertext,
}
impl From<ZkAesCtrError> for AuthError {
fn from(value: ZkAesCtrError) -> Self {
Self(ErrorRepr::ZkAes(value))
}
}
impl From<DecodeError> for AuthError {
fn from(value: DecodeError) -> Self {
Self(ErrorRepr::Decode(value))
}
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
auth::{Authenticator, ErrorRepr},
transcript::TranscriptRefs,
},
zk_aes_ctr::ZkAesCtr,
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY, RECORD_SIZE},
hash::HashAlgId,
transcript::{ContentType, Direction, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_authenticator_sent(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_sent(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(600..1600));
prove_range.union_mut(&(800..2000));
prove_range.union_mut(&(2600..3700));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Sent), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Sent), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_recv(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, RECV_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, RECV_LEN).unwrap();
let _ = auth_prover
.auth_recv(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
let proof = auth_verifier
.auth_recv(
&mut verifier,
&mut zk_verifier,
&transcript,
&mut refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
proof.verify().unwrap();
let mut prove_range: RangeSet<usize> = RangeSet::default();
prove_range.union_mut(&(4000..4200));
prove_range.union_mut(&(5000..5800));
prove_range.union_mut(&(6800..RECV_LEN));
let mut expected_ranges = RangeSet::default();
for r in prove_range.iter_ranges() {
let floor = r.start / RECORD_SIZE;
let ceil = r.end.div_ceil(RECORD_SIZE);
let expected = floor * RECORD_SIZE..ceil * RECORD_SIZE;
expected_ranges.union_mut(&expected);
}
assert_eq!(refs_prover.index(Direction::Received), expected_ranges);
assert_eq!(refs_verifier.index(Direction::Received), expected_ranges);
}
#[rstest]
#[tokio::test]
async fn test_authenticator_sent_verify_fail(
encoding: Vec<(Direction, RangeSet<usize>)>,
hashes: Vec<(Direction, RangeSet<usize>, HashAlgId)>,
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let (sent_decdoding, recv_decdoding) = decoding;
let partial = transcript
.to_transcript()
.unwrap()
.to_partial(sent_decdoding, recv_decdoding);
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut refs_prover = transcript_refs.clone();
let mut refs_verifier = transcript_refs;
let (key, iv) = keys(&mut prover, KEY, IV, Role::Prover);
let mut auth_prover = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_prover = ZkAesCtr::new(Role::Prover);
zk_prover.set_key(key, iv);
zk_prover.alloc(&mut prover, SENT_LEN).unwrap();
let (key, iv) = keys(&mut verifier, KEY, IV, Role::Verifier);
let mut auth_verifier = Authenticator::new(encoding.iter(), hashes.iter(), Some(&partial));
let mut zk_verifier = ZkAesCtr::new(Role::Verifier);
zk_verifier.set_key(key, iv);
zk_verifier.alloc(&mut verifier, SENT_LEN).unwrap();
let _ = auth_prover
.auth_sent(&mut prover, &mut zk_prover, &transcript, &mut refs_prover)
.unwrap();
// Forge verifier transcript to check if verify fails.
// Use an index which is part of the proving range.
let forged = forged();
let proof = auth_verifier
.auth_sent(&mut verifier, &mut zk_verifier, &forged, &mut refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let err = proof.verify().unwrap_err();
assert!(matches!(err.0, ErrorRepr::InvalidCiphertext));
}
fn keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let sent = 600..1600;
let recv = 4000..4200;
(sent.into(), recv.into())
}
#[fixture]
fn encoding() -> Vec<(Direction, RangeSet<usize>)> {
let sent = 800..2000;
let recv = 5000..5800;
let encoding = vec![
(Direction::Sent, sent.into()),
(Direction::Received, recv.into()),
];
encoding
}
#[fixture]
fn hashes() -> Vec<(Direction, RangeSet<usize>, HashAlgId)> {
let sent = 2600..3700;
let recv = 6800..RECV_LEN;
let alg = HashAlgId::SHA256;
let hashes = vec![
(Direction::Sent, sent.into(), alg),
(Direction::Received, recv.into(), alg),
];
hashes
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 610;
let mut sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
sent[WRONG_BYTE_INDEX] = sent[WRONG_BYTE_INDEX].wrapping_add(1);
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
TranscriptRefs::new(sent_len, recv_len)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}

View File

@@ -0,0 +1,615 @@
//! Selective disclosure.
use mpz_memory_core::{
Array, MemoryExt,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{Intersection, RangeSet, Subset, Union};
use tlsn_core::transcript::{ContentType, Direction, PartialTranscript, TlsTranscript};
use crate::commit::TranscriptRefs;
/// Decodes parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `transcript_refs` - The transcript references.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
transcript_refs: &mut TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding_ranges;
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for slice in sent_refs.into_iter() {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Sent, sent);
// If possible use server write key for decoding.
let fully_decoded = recv.union(&transcript_refs.decoded(Direction::Received));
let full_range = 0..transcript_refs.max_len(Direction::Received);
if fully_decoded == full_range {
// Drop the future, we don't need it.
drop(vm.decode(key).map_err(DecodeError::vm)?);
drop(vm.decode(iv).map_err(DecodeError::vm)?);
transcript_refs.mark_decoded(Direction::Received, &full_range.into());
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for slice in recv_refs {
// Drop the future, we don't need it.
drop(vm.decode(slice).map_err(DecodeError::vm));
}
transcript_refs.mark_decoded(Direction::Received, recv);
}
Ok(())
}
/// Verifies parts of the transcript.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `key` - The server write key.
/// * `iv` - The server write iv.
/// * `decoding_ranges` - The decoding ranges.
/// * `partial` - The partial transcript.
/// * `transcript_refs` - The transcript references.
/// * `transcript` - The TLS transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
key: Array<U8, 16>,
iv: Array<U8, 4>,
decoding_ranges: (&RangeSet<usize>, &RangeSet<usize>),
partial: Option<&PartialTranscript>,
transcript_refs: &mut TranscriptRefs,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let (sent, recv) = decoding_ranges;
let mut authenticated_data = Vec::new();
// Add sent transcript parts.
let sent_refs = transcript_refs.get(Direction::Sent, sent);
for data in sent_refs.into_iter() {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
// Add received transcript parts, if possible using key and iv.
if let (Some(key), Some(iv)) = (
vm.get(key).map_err(DecodeError::vm)?,
vm.get(iv).map_err(DecodeError::vm)?,
) {
let plaintext = verify_with_keys(key, iv, recv, transcript)?;
authenticated_data.extend_from_slice(&plaintext);
} else {
let recv_refs = transcript_refs.get(Direction::Received, recv);
for data in recv_refs {
let plaintext = vm
.get(data)
.map_err(DecodeError::vm)?
.ok_or(DecodeError(ErrorRepr::MissingPlaintext))?;
authenticated_data.extend_from_slice(&plaintext);
}
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in sent.iter_ranges() {
purported_data.extend_from_slice(&partial.sent_unsafe()[range]);
}
for range in recv.iter_ranges() {
purported_data.extend_from_slice(&partial.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(DecodeError(ErrorRepr::InconsistentTranscript));
}
Ok(())
}
/// Checks the transcript length.
///
/// # Arguments
///
/// * `partial` - The partial transcript.
/// * `transcript` - The TLS transcript.
pub(crate) fn check_transcript_length(
partial: Option<&PartialTranscript>,
transcript: &TlsTranscript,
) -> Result<(), DecodeError> {
let Some(partial) = partial else {
return Err(DecodeError(ErrorRepr::MissingPartialTranscript));
};
let sent_len: usize = transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len: usize = transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
// Check ranges.
if partial.len_sent() != sent_len || partial.len_received() != recv_len {
return Err(DecodeError(ErrorRepr::VerifyTranscriptLength));
}
Ok(())
}
fn verify_with_keys(
key: [u8; 16],
iv: [u8; 4],
decoding_ranges: &RangeSet<usize>,
transcript: &TlsTranscript,
) -> Result<Vec<u8>, DecodeError> {
let mut plaintexts = Vec::with_capacity(decoding_ranges.len());
let mut position = 0_usize;
let recv_data = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv_data {
let current = position..position + record.ciphertext.len();
if !current.is_subset(decoding_ranges) {
position += record.ciphertext.len();
continue;
}
let nonce = record
.explicit_nonce
.clone()
.try_into()
.expect("explicit nonce should be 8 bytes");
let plaintext = aes_apply_keystream(key, iv, nonce, &record.ciphertext);
let record_decoding_range = decoding_ranges.intersection(&current);
for r in record_decoding_range.iter_ranges() {
let shifted = r.start - position..r.end - position;
plaintexts.extend_from_slice(&plaintext[shifted]);
}
position += record.ciphertext.len()
}
Ok(plaintexts)
}
fn aes_apply_keystream(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], msg: &[u8]) -> Vec<u8> {
use aes::Aes128;
use cipher_crypto::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr32BE;
let start_ctr = 2;
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(&iv);
full_iv[4..12].copy_from_slice(&explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(&key.into(), &full_iv.into());
let mut out = msg.to_vec();
cipher
.try_seek(start_ctr * 16)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut out);
out
}
/// A decoding error.
#[derive(Debug, thiserror::Error)]
#[error("decode error: {0}")]
pub(crate) struct DecodeError(#[source] ErrorRepr);
impl DecodeError {
fn vm<E>(err: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Self(ErrorRepr::Vm(err.into()))
}
}
#[derive(Debug, thiserror::Error)]
enum ErrorRepr {
#[error("vm error: {0}")]
Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("missing partial transcript")]
MissingPartialTranscript,
#[error("length of partial transcript does not match expected length")]
VerifyTranscriptLength,
#[error("provided transcript does not match exptected")]
InconsistentTranscript,
#[error("trying to get plaintext, but it is missing")]
MissingPlaintext,
}
#[cfg(test)]
mod tests {
use crate::{
Role,
commit::{
TranscriptRefs,
decode::{DecodeError, ErrorRepr, decode_transcript, verify_transcript},
},
};
use lipsum::{LIBER_PRIMUS, lipsum};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
Array, MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
fixtures::transcript::{IV, KEY},
transcript::{ContentType, Direction, PartialTranscript, TlsTranscript},
};
#[rstest]
#[tokio::test]
async fn test_decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding.clone());
decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_fail(
decoding: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding.clone());
let err = decode(decoding, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
#[rstest]
#[tokio::test]
async fn test_decode_all(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&transcript, decoding_full.clone());
decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_decode_all_fail(
decoding_full: (RangeSet<usize>, RangeSet<usize>),
forged: TlsTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) {
let partial = partial(&forged, decoding_full.clone());
let err = decode(decoding_full, partial, transcript, transcript_refs)
.await
.unwrap_err();
assert!(matches!(err.0, ErrorRepr::InconsistentTranscript));
}
async fn decode(
decoding: (RangeSet<usize>, RangeSet<usize>),
partial: PartialTranscript,
transcript: TlsTranscript,
transcript_refs: TranscriptRefs,
) -> Result<(), DecodeError> {
let (sent, recv) = decoding;
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut transcript_refs_verifier = transcript_refs.clone();
let mut transcript_refs_prover = transcript_refs;
let key: [u8; 16] = KEY;
let iv: [u8; 4] = IV;
let (key_prover, iv_prover) = assign_keys(&mut prover, key, iv, Role::Prover);
let (key_verifier, iv_verifier) = assign_keys(&mut verifier, key, iv, Role::Verifier);
assign_transcript(
&mut prover,
Role::Prover,
&transcript,
&mut transcript_refs_prover,
);
assign_transcript(
&mut verifier,
Role::Verifier,
&transcript,
&mut transcript_refs_verifier,
);
decode_transcript(
&mut prover,
key_prover,
iv_prover,
(&sent, &recv),
&mut transcript_refs_prover,
)
.unwrap();
decode_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
&mut transcript_refs_verifier,
)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v),
)
.unwrap();
verify_transcript(
&mut verifier,
key_verifier,
iv_verifier,
(&sent, &recv),
Some(&partial),
&mut transcript_refs_verifier,
&transcript,
)
}
fn assign_keys(
vm: &mut dyn Vm<Binary>,
key_value: [u8; 16],
iv_value: [u8; 4],
role: Role,
) -> (Array<U8, 16>, Array<U8, 4>) {
let key: Array<U8, 16> = vm.alloc().unwrap();
let iv: Array<U8, 4> = vm.alloc().unwrap();
if let Role::Prover = role {
vm.mark_private(key).unwrap();
vm.mark_private(iv).unwrap();
vm.assign(key, key_value).unwrap();
vm.assign(iv, iv_value).unwrap();
} else {
vm.mark_blind(key).unwrap();
vm.mark_blind(iv).unwrap();
}
vm.commit(key).unwrap();
vm.commit(iv).unwrap();
(key, iv)
}
fn assign_transcript(
vm: &mut dyn Vm<Binary>,
role: Role,
transcript: &TlsTranscript,
transcript_refs: &mut TranscriptRefs,
) {
let mut pos = 0_usize;
let sent = transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in sent {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Sent, &index, plaintext_ref);
pos += record.ciphertext.len();
}
pos = 0;
let recv = transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData);
for record in recv {
let len = record.ciphertext.len();
let cipher_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
vm.mark_public(cipher_ref).unwrap();
vm.assign(cipher_ref, record.ciphertext.clone()).unwrap();
vm.commit(cipher_ref).unwrap();
let plaintext_ref: Vector<U8> = vm.alloc_vec(len).unwrap();
if let Role::Prover = role {
vm.mark_private(plaintext_ref).unwrap();
vm.assign(plaintext_ref, record.plaintext.clone().unwrap())
.unwrap();
} else {
vm.mark_blind(plaintext_ref).unwrap();
}
vm.commit(plaintext_ref).unwrap();
let index = pos..pos + record.ciphertext.len();
transcript_refs.add(Direction::Received, &index, plaintext_ref);
pos += record.ciphertext.len();
}
}
fn partial(
transcript: &TlsTranscript,
decoding: (RangeSet<usize>, RangeSet<usize>),
) -> PartialTranscript {
let (sent, recv) = decoding;
transcript.to_transcript().unwrap().to_partial(sent, recv)
}
#[fixture]
fn decoding() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
sent.union_mut(&(600..1100));
sent.union_mut(&(3450..4000));
recv.union_mut(&(2000..3000));
recv.union_mut(&(4800..4900));
recv.union_mut(&(6000..7000));
(sent, recv)
}
#[fixture]
fn decoding_full(transcript: TlsTranscript) -> (RangeSet<usize>, RangeSet<usize>) {
let transcript = transcript.to_transcript().unwrap();
let (len_sent, len_recv) = transcript.len();
let sent = (0..len_sent).into();
let recv = (0..len_recv).into();
(sent, recv)
}
#[fixture]
fn transcript() -> TlsTranscript {
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn forged() -> TlsTranscript {
const WRONG_BYTE_INDEX: usize = 2200;
let sent = LIBER_PRIMUS.as_bytes()[..SENT_LEN].to_vec();
let mut recv = lipsum(RECV_LEN).into_bytes();
recv.truncate(RECV_LEN);
recv[WRONG_BYTE_INDEX] = recv[WRONG_BYTE_INDEX].wrapping_add(1);
tlsn_core::fixtures::transcript::transcript_fixture(&sent, &recv)
}
#[fixture]
fn transcript_refs(transcript: TlsTranscript) -> TranscriptRefs {
let sent_len = transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
TranscriptRefs::new(sent_len, recv_len)
}
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
const SENT_LEN: usize = 4096;
const RECV_LEN: usize = 8192;
}

View File

@@ -0,0 +1,530 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_memory_core::{
MemoryType, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::Vm;
use rangeset::{RangeSet, Subset, UnionMut};
use serde::{Deserialize, Serialize};
use tlsn_core::{
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256, TypedHash},
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingProvider, EncodingProviderError, EncodingTree,
EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Bytes of encoding, per byte.
pub(crate) const ENCODING_SIZE: usize = 128;
pub(crate) trait EncodingVm<T: MemoryType>: EncodingMemory<T> + Vm<T> {}
impl<T: MemoryType, U> EncodingVm<T> for U where U: EncodingMemory<T> + Vm<T> {}
pub(crate) trait EncodingMemory<T: MemoryType> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8>;
}
impl<T> EncodingMemory<Binary> for mpz_zk::Prover<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let macs = self.get_macs(v).expect("macs should be available");
encodings.extend(macs.iter().flat_map(|mac| mac.as_bytes()));
}
encodings
}
}
impl<T> EncodingMemory<Binary> for mpz_zk::Verifier<T> {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let len = values.iter().map(|v| v.len()).sum::<usize>() * ENCODING_SIZE;
let mut encodings = Vec::with_capacity(len);
for &v in values {
let keys = self.get_keys(v).expect("keys should be available");
encodings.extend(keys.iter().flat_map(|key| key.as_block().as_bytes()));
}
encodings
}
}
/// The encoding adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Encodings {
pub(crate) sent: Vec<u8>,
pub(crate) recv: Vec<u8>,
}
/// Creates encoding commitments.
#[derive(Debug)]
pub(crate) struct EncodingCreator {
hash_id: Option<HashAlgId>,
sent: RangeSet<usize>,
recv: RangeSet<usize>,
idxs: Vec<(Direction, RangeSet<usize>)>,
}
impl EncodingCreator {
/// Creates a new encoding creator.
///
/// # Arguments
///
/// * `hash_id` - The id of the hash algorithm.
/// * `idxs` - The indices for encoding commitments.
pub(crate) fn new(hash_id: Option<HashAlgId>, idxs: Vec<(Direction, RangeSet<usize>)>) -> Self {
let mut sent = RangeSet::default();
let mut recv = RangeSet::default();
for (direction, idx) in idxs.iter() {
for range in idx.iter_ranges() {
match direction {
Direction::Sent => sent.union_mut(&range),
Direction::Received => recv.union_mut(&range),
}
}
}
Self {
hash_id,
sent,
recv,
idxs,
}
}
/// Receives the encodings using the provided MACs from the encoding memory.
///
/// The MACs must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `encodings` - The encoding adjustments.
/// * `transcript_refs` - The transcript references.
pub(crate) fn receive(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
encodings: Encodings,
transcript_refs: &TranscriptRefs,
) -> Result<(TypedHash, EncodingTree), EncodingError> {
let Some(id) = self.hash_id else {
return Err(EncodingError(ErrorRepr::MissingHashId));
};
let hasher: &(dyn HashAlgorithm + Send + Sync) = match id {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(EncodingError(ErrorRepr::UnsupportedHashAlg(alg)));
}
};
let Encodings {
sent: mut sent_adjust,
recv: mut recv_adjust,
} = encodings;
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_adjust, &mut recv_adjust)?;
let provider = Provider::new(sent_adjust, &self.sent, recv_adjust, &self.recv);
let tree = EncodingTree::new(hasher, self.idxs.iter(), &provider)?;
let root = tree.root();
Ok((root, tree))
}
/// Transfers the encodings using the provided secret and keys from the
/// encoding memory.
///
/// The keys must be consistent with the global delta used in the encodings.
///
/// # Arguments
///
/// * `encoding_mem` - The encoding memory.
/// * `secret` - The encoder secret.
/// * `transcript_refs` - The transcript references.
pub(crate) fn transfer(
&self,
encoding_mem: &dyn EncodingMemory<Binary>,
secret: EncoderSecret,
transcript_refs: &TranscriptRefs,
) -> Result<Encodings, EncodingError> {
let encoder = new_encoder(&secret);
let mut sent_zero = Vec::with_capacity(self.sent.len() * ENCODING_SIZE);
let mut recv_zero = Vec::with_capacity(self.recv.len() * ENCODING_SIZE);
for range in self.sent.iter_ranges() {
encoder.encode_range(Direction::Sent, range, &mut sent_zero);
}
for range in self.recv.iter_ranges() {
encoder.encode_range(Direction::Received, range, &mut recv_zero);
}
let sent_refs = transcript_refs.get(Direction::Sent, &self.sent);
let sent = encoding_mem.get_encodings(&sent_refs);
let recv_refs = transcript_refs.get(Direction::Received, &self.recv);
let recv = encoding_mem.get_encodings(&recv_refs);
adjust(&sent, &recv, &mut sent_zero, &mut recv_zero)?;
let encodings = Encodings {
sent: sent_zero,
recv: recv_zero,
};
Ok(encodings)
}
}
/// Adjust encodings by transcript references.
///
/// # Arguments
///
/// * `sent` - The encodings for the sent bytes.
/// * `recv` - The encodings for the received bytes.
/// * `sent_adjust` - The adjustment bytes for the encodings of the sent bytes.
/// * `recv_adjust` - The adjustment bytes for the encodings of the received
/// bytes.
fn adjust(
sent: &[u8],
recv: &[u8],
sent_adjust: &mut [u8],
recv_adjust: &mut [u8],
) -> Result<(), EncodingError> {
assert_eq!(sent.len() % ENCODING_SIZE, 0);
assert_eq!(recv.len() % ENCODING_SIZE, 0);
if sent_adjust.len() != sent.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Sent,
expected: sent.len(),
got: sent_adjust.len(),
}
.into());
}
if recv_adjust.len() != recv.len() {
return Err(ErrorRepr::IncorrectAdjustCount {
direction: Direction::Received,
expected: recv.len(),
got: recv_adjust.len(),
}
.into());
}
sent_adjust
.iter_mut()
.zip(sent)
.for_each(|(adjust, enc)| *adjust ^= enc);
recv_adjust
.iter_mut()
.zip(recv)
.for_each(|(adjust, enc)| *adjust ^= enc);
Ok(())
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
sent_range: RangeSet<usize>,
recv: Vec<u8>,
recv_range: RangeSet<usize>,
}
impl Provider {
fn new(
sent: Vec<u8>,
sent_range: &RangeSet<usize>,
recv: Vec<u8>,
recv_range: &RangeSet<usize>,
) -> Self {
assert_eq!(
sent.len(),
sent_range.len() * ENCODING_SIZE,
"length of sent encodings and their index length do not match"
);
assert_eq!(
recv.len(),
recv_range.len() * ENCODING_SIZE,
"length of received encodings and their index length do not match"
);
Self {
sent,
sent_range: sent_range.clone(),
recv,
recv_range: recv_range.clone(),
}
}
fn adjust(
&self,
direction: Direction,
range: &Range<usize>,
) -> Result<Range<usize>, EncodingProviderError> {
let internal_range = match direction {
Direction::Sent => &self.sent_range,
Direction::Received => &self.recv_range,
};
if !range.is_subset(internal_range) {
return Err(EncodingProviderError);
}
let shift = internal_range
.iter()
.take_while(|&el| el < range.start)
.count();
let translated = Range {
start: shift,
end: shift + range.len(),
};
Ok(translated)
}
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let range = self.adjust(direction, &range)?;
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub(crate) struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("incorrect adjustment count for {direction}: expected {expected}, got {got}")]
IncorrectAdjustCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
#[error("missing hash id")]
MissingHashId,
#[error("unsupported hash algorithm for encoding commitment: {0}")]
UnsupportedHashAlg(HashAlgId),
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::ops::Range;
use crate::commit::{
encoding::{ENCODING_SIZE, EncodingCreator, Encodings, Provider},
transcript::TranscriptRefs,
};
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_memory_core::{
FromRaw, Slice, ToRaw, Vector,
binary::{Binary, U8},
};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use tlsn_core::{
hash::{HashAlgId, HashProvider},
transcript::{
Direction,
encoding::{EncoderSecret, EncodingCommitment, EncodingProvider},
},
};
#[rstest]
fn test_encoding_adjust(
index: (RangeSet<usize>, RangeSet<usize>),
transcript_refs: TranscriptRefs,
encoding_idxs: Vec<(Direction, RangeSet<usize>)>,
) {
let creator = EncodingCreator::new(Some(HashAlgId::SHA256), encoding_idxs);
let mock_memory = MockEncodingMemory;
let delta = Delta::new(Block::ONES);
let seed = [1_u8; 32];
let secret = EncoderSecret::new(seed, delta.as_block().to_bytes());
let adjustments = creator
.transfer(&mock_memory, secret, &transcript_refs)
.unwrap();
let (root, tree) = creator
.receive(&mock_memory, adjustments, &transcript_refs)
.unwrap();
// Check correctness of encoding protocol.
let mut idxs = Vec::new();
let (sent_range, recv_range) = index;
idxs.push((Direction::Sent, sent_range.clone()));
idxs.push((Direction::Received, recv_range.clone()));
let commitment = EncodingCommitment { root, secret };
let proof = tree.proof(idxs.iter()).unwrap();
// Here we set the trancscript plaintext to just be zeroes, which is possible
// because it is not determined by the encodings so far.
let sent = vec![0_u8; transcript_refs.max_len(Direction::Sent)];
let recv = vec![0_u8; transcript_refs.max_len(Direction::Received)];
let (idx_sent, idx_recv) = proof
.verify_with_provider(&HashProvider::default(), &commitment, &sent, &recv)
.unwrap();
assert_eq!(idx_sent, idxs[0].1);
assert_eq!(idx_recv, idxs[1].1);
}
#[rstest]
fn test_encoding_provider(index: (RangeSet<usize>, RangeSet<usize>), encodings: Encodings) {
let (sent_range, recv_range) = index;
let Encodings { sent, recv } = encodings;
let provider = Provider::new(sent, &sent_range, recv, &recv_range);
let mut encodings_sent = Vec::new();
let mut encodings_recv = Vec::new();
provider
.provide_encoding(Direction::Sent, 16..24, &mut encodings_sent)
.unwrap();
provider
.provide_encoding(Direction::Received, 56..64, &mut encodings_recv)
.unwrap();
let expected_sent = generate_encodings((16..24).into());
let expected_recv = generate_encodings((56..64).into());
assert_eq!(expected_sent, encodings_sent);
assert_eq!(expected_recv, encodings_recv);
}
#[fixture]
fn transcript_refs(index: (RangeSet<usize>, RangeSet<usize>)) -> TranscriptRefs {
let mut transcript_refs = TranscriptRefs::new(40, 64);
let dummy = |range: Range<usize>| {
Vector::<U8>::from_raw(Slice::from_range_unchecked(8 * range.start..8 * range.end))
};
for range in index.0.iter_ranges() {
transcript_refs.add(Direction::Sent, &range, dummy(range.clone()));
}
for range in index.1.iter_ranges() {
transcript_refs.add(Direction::Received, &range, dummy(range.clone()));
}
transcript_refs
}
#[fixture]
fn encodings(index: (RangeSet<usize>, RangeSet<usize>)) -> Encodings {
let sent = generate_encodings(index.0);
let recv = generate_encodings(index.1);
Encodings { sent, recv }
}
#[fixture]
fn encoding_idxs(
index: (RangeSet<usize>, RangeSet<usize>),
) -> Vec<(Direction, RangeSet<usize>)> {
let (sent, recv) = index;
vec![(Direction::Sent, sent), (Direction::Received, recv)]
}
#[fixture]
fn index() -> (RangeSet<usize>, RangeSet<usize>) {
let mut sent = RangeSet::default();
sent.union_mut(&(0..8));
sent.union_mut(&(16..24));
sent.union_mut(&(32..40));
let mut recv = RangeSet::default();
recv.union_mut(&(40..48));
recv.union_mut(&(56..64));
(sent, recv)
}
#[derive(Clone, Copy)]
struct MockEncodingMemory;
impl EncodingMemory<Binary> for MockEncodingMemory {
fn get_encodings(&self, values: &[Vector<U8>]) -> Vec<u8> {
let ranges: Vec<Range<usize>> = values
.iter()
.map(|r| {
let range = r.to_raw().to_range();
range.start / 8..range.end / 8
})
.collect();
let ranges: RangeSet<usize> = ranges.into();
generate_encodings(ranges)
}
}
fn generate_encodings(index: RangeSet<usize>) -> Vec<u8> {
let mut out = Vec::new();
for el in index.iter() {
out.extend_from_slice(&[el as u8; ENCODING_SIZE]);
}
out
}
}

View File

@@ -3,7 +3,7 @@
use std::collections::HashMap;
use mpz_core::bitvec::BitVec;
use mpz_hash::{blake3::Blake3, sha256::Sha256};
use mpz_hash::sha256::Sha256;
use mpz_memory_core::{
DecodeFutureTyped, MemoryExt, Vector,
binary::{Binary, U8},
@@ -20,34 +20,171 @@ use tlsn_core::{
use crate::{Role, commit::transcript::TranscriptRefs};
/// Future which will resolve to the committed hash values.
/// Creates plaintext hashes.
#[derive(Debug)]
pub(crate) struct HashCommitFuture {
#[allow(clippy::type_complexity)]
futs: Vec<(
Direction,
RangeSet<usize>,
HashAlgId,
DecodeFutureTyped<BitVec, Vec<u8>>,
)>,
pub(crate) struct PlaintextHasher {
ranges: Vec<HashRange>,
}
impl HashCommitFuture {
impl PlaintextHasher {
/// Creates a new instance.
///
/// # Arguments
///
/// * `indices` - The hash indices.
pub(crate) fn new<'a>(
indices: impl Iterator<Item = &'a (Direction, RangeSet<usize>, HashAlgId)>,
) -> Self {
let mut ranges = Vec::new();
for (direction, index, id) in indices {
let hash_range = HashRange::new(*direction, index.clone(), *id);
ranges.push(hash_range);
}
Self { ranges }
}
/// Prove plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn prove(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<(HashFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Prover, transcript_refs)?;
let mut futures = Vec::new();
let mut secrets = Vec::new();
for ((range, hash_ref), blinder_ref) in self.ranges.iter().zip(hash_refs).zip(blinders) {
let blinder: Blinder = rand::random();
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut));
secrets.push(PlaintextHashSecret {
direction: range.direction,
idx: range.range.clone(),
blinder,
alg: range.id,
});
}
let hashes = HashFuture { futures };
Ok((hashes, secrets))
}
/// Verify plaintext hash commitments.
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `transcript_refs` - The transcript references.
pub(crate) fn verify(
&self,
vm: &mut dyn Vm<Binary>,
transcript_refs: &TranscriptRefs,
) -> Result<HashFuture, HashCommitError> {
let (hash_refs, blinders) = commit(vm, &self.ranges, Role::Verifier, transcript_refs)?;
let mut futures = Vec::new();
for ((range, hash_ref), blinder) in self.ranges.iter().zip(hash_refs).zip(blinders) {
vm.commit(blinder)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futures.push((range.clone(), hash_fut))
}
let hashes = HashFuture { futures };
Ok(hashes)
}
}
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn commit(
vm: &mut dyn Vm<Binary>,
ranges: &[HashRange],
role: Role,
refs: &TranscriptRefs,
) -> Result<(Vec<Array<U8, 32>>, Vec<Vector<U8>>), HashCommitError> {
let mut hashers = HashMap::new();
let mut hash_refs = Vec::new();
let mut blinders = Vec::new();
for HashRange {
direction,
range,
id,
} in ranges.iter()
{
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
}
let hash = match *id {
HashAlgId::SHA256 => {
let mut hasher = if let Some(hasher) = hashers.get(id).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(id, hasher.clone());
hasher
};
for plaintext in refs.get(*direction, range) {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
id => {
return Err(HashCommitError::unsupported_alg(id));
}
};
hash_refs.push(hash);
blinders.push(blinder);
}
Ok((hash_refs, blinders))
}
/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub(crate) struct HashFuture {
futures: Vec<(HashRange, DecodeFutureTyped<BitVec, Vec<u8>>)>,
}
impl HashFuture {
/// Tries to receive the value, returning an error if the value is not
/// ready.
pub(crate) fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
let mut output = Vec::new();
for (direction, idx, alg, mut fut) in self.futs {
for (hash_range, mut fut) in self.futures {
let hash = fut
.try_recv()
.map_err(|_| HashCommitError::decode())?
.ok_or_else(HashCommitError::decode)?;
output.push(PlaintextHash {
direction,
idx,
direction: hash_range.direction,
idx: hash_range.range,
hash: TypedHash {
alg,
alg: hash_range.id,
value: Hash::try_from(hash).map_err(HashCommitError::convert)?,
},
});
@@ -57,133 +194,21 @@ impl HashCommitFuture {
}
}
/// Prove plaintext hash commitments.
pub(crate) fn prove_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<(HashCommitFuture, Vec<PlaintextHashSecret>), HashCommitError> {
let mut futs = Vec::new();
let mut secrets = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Prover, refs, idxs)?
{
let blinder: Blinder = rand::random();
#[derive(Debug, Clone)]
struct HashRange {
direction: Direction,
range: RangeSet<usize>,
id: HashAlgId,
}
vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx.clone(), alg, hash_fut));
secrets.push(PlaintextHashSecret {
impl HashRange {
fn new(direction: Direction, range: RangeSet<usize>, id: HashAlgId) -> Self {
Self {
direction,
idx,
blinder,
alg,
});
}
Ok((HashCommitFuture { futs }, secrets))
}
/// Verify plaintext hash commitments.
pub(crate) fn verify_hash(
vm: &mut dyn Vm<Binary>,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<HashCommitFuture, HashCommitError> {
let mut futs = Vec::new();
for (direction, idx, alg, hash_ref, blinder_ref) in
hash_commit_inner(vm, Role::Verifier, refs, idxs)?
{
vm.commit(blinder_ref)?;
let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;
futs.push((direction, idx, alg, hash_fut));
}
Ok(HashCommitFuture { futs })
}
#[derive(Clone)]
enum Hasher {
Sha256(Sha256),
Blake3(Blake3),
}
/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn hash_commit_inner(
vm: &mut dyn Vm<Binary>,
role: Role,
refs: &TranscriptRefs,
idxs: impl IntoIterator<Item = (Direction, RangeSet<usize>, HashAlgId)>,
) -> Result<
Vec<(
Direction,
RangeSet<usize>,
HashAlgId,
Array<U8, 32>,
Vector<U8>,
)>,
HashCommitError,
> {
let mut output = Vec::new();
let mut hashers = HashMap::new();
for (direction, idx, alg) in idxs {
let blinder = vm.alloc_vec::<U8>(16)?;
match role {
Role::Prover => vm.mark_private(blinder)?,
Role::Verifier => vm.mark_blind(blinder)?,
range,
id,
}
let hash = match alg {
HashAlgId::SHA256 => {
let mut hasher = if let Some(Hasher::Sha256(hasher)) = hashers.get(&alg).cloned() {
hasher
} else {
let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, Hasher::Sha256(hasher.clone()));
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
hasher.update(&plaintext);
}
hasher.update(&blinder);
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
HashAlgId::BLAKE3 => {
let mut hasher = if let Some(Hasher::Blake3(hasher)) = hashers.get(&alg).cloned() {
hasher
} else {
let hasher = Blake3::new(vm).map_err(HashCommitError::hasher)?;
hashers.insert(alg, Hasher::Blake3(hasher.clone()));
hasher
};
for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
hasher
.update(vm, &plaintext)
.map_err(HashCommitError::hasher)?;
}
hasher
.update(vm, &blinder)
.map_err(HashCommitError::hasher)?;
hasher.finalize(vm).map_err(HashCommitError::hasher)?
}
alg => {
return Err(HashCommitError::unsupported_alg(alg));
}
};
output.push((direction, idx, alg, hash, blinder));
}
Ok(output)
}
/// Error type for hash commitments.
@@ -232,3 +257,148 @@ impl From<VmError> for HashCommitError {
Self(ErrorRepr::Vm(value))
}
}
#[cfg(test)]
mod test {
use crate::{
Role,
commit::{hash::PlaintextHasher, transcript::TranscriptRefs},
};
use mpz_common::context::test_st_context;
use mpz_garble_core::Delta;
use mpz_memory_core::{
MemoryExt, Vector, ViewExt,
binary::{Binary, U8},
};
use mpz_ot::ideal::rcot::{IdealRCOTReceiver, IdealRCOTSender, ideal_rcot};
use mpz_vm_core::{Execute, Vm};
use mpz_zk::{Prover, ProverConfig, Verifier, VerifierConfig};
use rand::{Rng, SeedableRng, rngs::StdRng};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use sha2::Digest;
use tlsn_core::{hash::HashAlgId, transcript::Direction};
#[rstest]
#[tokio::test]
async fn test_hasher() {
let mut sent1 = RangeSet::default();
sent1.union_mut(&(1..6));
sent1.union_mut(&(11..16));
let mut sent2 = RangeSet::default();
sent2.union_mut(&(22..25));
let mut recv = RangeSet::default();
recv.union_mut(&(20..25));
let hash_ranges = [
(Direction::Sent, sent1, HashAlgId::SHA256),
(Direction::Sent, sent2, HashAlgId::SHA256),
(Direction::Received, recv, HashAlgId::SHA256),
];
let mut refs_prover = TranscriptRefs::new(1000, 1000);
let mut refs_verifier = TranscriptRefs::new(1000, 1000);
let values = [
b"abcde".to_vec(),
b"vwxyz".to_vec(),
b"xxx".to_vec(),
b"12345".to_vec(),
];
let (mut ctx_p, mut ctx_v) = test_st_context(8);
let (mut prover, mut verifier) = vms();
let mut values_iter = values.iter();
for (direction, idx, _) in hash_ranges.iter() {
for range in idx.iter_ranges() {
let value = values_iter.next().unwrap();
let ref_prover = assign(Role::Prover, &mut prover, value.clone());
refs_prover.add(*direction, &range, ref_prover);
let ref_verifier = assign(Role::Verifier, &mut verifier, value.clone());
refs_verifier.add(*direction, &range, ref_verifier);
}
}
let hasher_prover = PlaintextHasher::new(hash_ranges.iter());
let hasher_verifier = PlaintextHasher::new(hash_ranges.iter());
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let (prover_hashes, prover_secrets) =
hasher_prover.prove(&mut prover, &refs_prover).unwrap();
let verifier_hashes = hasher_verifier
.verify(&mut verifier, &refs_verifier)
.unwrap();
tokio::try_join!(
prover.execute_all(&mut ctx_p),
verifier.execute_all(&mut ctx_v)
)
.unwrap();
let prover_hashes = prover_hashes.try_recv().unwrap();
let verifier_hashes = verifier_hashes.try_recv().unwrap();
assert_eq!(prover_hashes, verifier_hashes);
let values_per_commitment = [b"abcdevwxyz".to_vec(), b"xxx".to_vec(), b"12345".to_vec()];
for ((value, hash), secret) in values_per_commitment
.iter()
.zip(prover_hashes)
.zip(prover_secrets)
{
let blinder = secret.blinder.as_bytes();
let mut blinded_value = value.clone();
blinded_value.extend_from_slice(blinder);
let expected_hash = sha256(&blinded_value);
let hash: Vec<u8> = hash.hash.value.into();
assert_eq!(expected_hash, hash);
}
}
fn assign(role: Role, vm: &mut dyn Vm<Binary>, value: Vec<u8>) -> Vector<U8> {
let reference: Vector<U8> = vm.alloc_vec(value.len()).unwrap();
if let Role::Prover = role {
vm.mark_private(reference).unwrap();
vm.assign(reference, value).unwrap();
} else {
vm.mark_blind(reference).unwrap();
}
vm.commit(reference).unwrap();
reference
}
fn sha256(data: &[u8]) -> Vec<u8> {
let mut hasher = sha2::Sha256::default();
hasher.update(data);
hasher.finalize().as_slice().to_vec()
}
#[fixture]
fn vms() -> (Prover<IdealRCOTReceiver>, Verifier<IdealRCOTSender>) {
let mut rng = StdRng::seed_from_u64(0);
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_rcot(rng.random(), delta.into_inner());
let prover = Prover::new(ProverConfig::default(), ot_recv);
let verifier = Verifier::new(VerifierConfig::default(), delta, ot_send);
(prover, verifier)
}
}

View File

@@ -1,211 +1,473 @@
use mpz_memory_core::{
MemoryExt, Vector,
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError};
use rangeset::{Intersection, RangeSet};
use tlsn_core::transcript::{Direction, PartialTranscript};
//! Transcript reference storage.
use std::ops::Range;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{Difference, Disjoint, RangeSet, Subset, UnionMut};
use tlsn_core::transcript::Direction;
/// References to the application plaintext in the transcript.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub(crate) struct TranscriptRefs {
sent: Vec<Vector<U8>>,
recv: Vec<Vector<U8>>,
sent: RefStorage,
recv: RefStorage,
}
impl TranscriptRefs {
pub(crate) fn new(sent: Vec<Vector<U8>>, recv: Vec<Vector<U8>>) -> Self {
/// Creates a new instance.
///
/// # Arguments
///
/// `sent_max_len` - The maximum length of the sent transcript in bytes.
/// `recv_max_len` - The maximum length of the received transcript in bytes.
pub(crate) fn new(sent_max_len: usize, recv_max_len: usize) -> Self {
let sent = RefStorage::new(sent_max_len);
let recv = RefStorage::new(recv_max_len);
Self { sent, recv }
}
/// Returns the sent plaintext references.
pub(crate) fn sent(&self) -> &[Vector<U8>] {
&self.sent
/// Adds new references to the transcript refs.
///
/// New transcript references are only added if none of them are already
/// present.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
/// * `refs` - The new transcript refs.
pub(crate) fn add(&mut self, direction: Direction, index: &Range<usize>, refs: Vector<U8>) {
match direction {
Direction::Sent => self.sent.add(index, refs),
Direction::Received => self.recv.add(index, refs),
}
}
/// Returns the received plaintext references.
pub(crate) fn recv(&self) -> &[Vector<U8>] {
&self.recv
/// Marks references of the transcript as decoded.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn mark_decoded(&mut self, direction: Direction, index: &RangeSet<usize>) {
match direction {
Direction::Sent => self.sent.mark_decoded(index),
Direction::Received => self.recv.mark_decoded(index),
}
}
/// Returns the transcript lengths.
pub(crate) fn len(&self) -> (usize, usize) {
let sent = self.sent.iter().map(|v| v.len()).sum();
let recv = self.recv.iter().map(|v| v.len()).sum();
(sent, recv)
/// Returns plaintext references for some index.
///
/// Queries that cannot or only partially be satisfied will return an empty
/// vector.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn get(&self, direction: Direction, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
match direction {
Direction::Sent => self.sent.get(index),
Direction::Received => self.recv.get(index),
}
}
/// Returns VM references for the given direction and index, otherwise
/// `None` if the index is out of bounds.
pub(crate) fn get(
/// Computes the subset of `index` which is missing.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
/// * `index` - The index of the transcript references.
pub(crate) fn compute_missing(
&self,
direction: Direction,
idx: &RangeSet<usize>,
) -> Option<Vec<Vector<U8>>> {
if idx.is_empty() {
return Some(Vec::new());
index: &RangeSet<usize>,
) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.compute_missing(index),
Direction::Received => self.recv.compute_missing(index),
}
}
/// Returns the maximum length of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn max_len(&self, direction: Direction) -> usize {
match direction {
Direction::Sent => self.sent.max_len(),
Direction::Received => self.recv.max_len(),
}
}
/// Returns the decoded ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
pub(crate) fn decoded(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.decoded(),
Direction::Received => self.recv.decoded(),
}
}
/// Returns the set ranges of the transcript.
///
/// # Arguments
///
/// * `direction` - The direction of the transcript.
#[cfg(test)]
pub(crate) fn index(&self, direction: Direction) -> RangeSet<usize> {
match direction {
Direction::Sent => self.sent.index(),
Direction::Received => self.recv.index(),
}
}
}
/// Inner storage for transcript references.
///
/// Saves transcript references by maintaining an `index` and an `offset`. The
/// offset translates from `index` to some memory location and contains
/// information about possibly non-contigious memory locations. The storage is
/// bit-addressed but the API works with ranges over bytes.
#[derive(Debug, Clone)]
struct RefStorage {
index: RangeSet<usize>,
decoded: RangeSet<usize>,
offset: Vec<isize>,
max_len: usize,
}
impl RefStorage {
fn new(max_len: usize) -> Self {
Self {
index: RangeSet::default(),
decoded: RangeSet::default(),
offset: Vec::default(),
max_len: 8 * max_len,
}
}
fn add(&mut self, index: &Range<usize>, data: Vector<U8>) {
assert!(
index.start < index.end,
"Range should be valid for adding to reference storage"
);
assert_eq!(
index.len(),
data.len(),
"Provided index and vm references should have the same length"
);
let bit_index = 8 * index.start..8 * index.end;
assert!(
bit_index.is_disjoint(&self.index),
"Parts of the provided index have already been computed"
);
assert!(
bit_index.end <= self.max_len,
"Provided index should be smaller than max_len"
);
if bit_index.end > self.offset.len() {
self.offset.resize(bit_index.end, 0);
}
let refs = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let mem_address = data.to_raw().ptr().as_usize() as isize;
let offset = mem_address - bit_index.start as isize;
// Computes the transcript range for each reference.
let mut start = 0;
let mut slice_iter = refs.iter().map(move |slice| {
let out = (slice, start..start + slice.len());
start += slice.len();
out
});
self.index.union_mut(&bit_index);
self.offset[bit_index].fill(offset);
}
let mut slices = Vec::new();
let (mut slice, mut slice_range) = slice_iter.next()?;
for range in idx.iter_ranges() {
loop {
if let Some(intersection) = slice_range.intersection(&range) {
let start = intersection.start - slice_range.start;
let end = intersection.end - slice_range.start;
slices.push(slice.get(start..end).expect("range should be in bounds"));
fn mark_decoded(&mut self, index: &RangeSet<usize>) {
let bit_index = to_bit_index(index);
self.decoded.union_mut(&bit_index);
}
fn get(&self, index: &RangeSet<usize>) -> Vec<Vector<U8>> {
let bit_index = to_bit_index(index);
if bit_index.is_empty() || !bit_index.is_subset(&self.index) {
return Vec::new();
}
// Partition rangeset into ranges mapping to possibly disjunct memory locations.
//
// If the offset changes during iteration of a single range, it means that the
// backing memory is non-contigious and we need to split that range.
let mut transcript_refs = Vec::new();
for idx in bit_index.iter_ranges() {
let mut start = idx.start;
let mut end = idx.start;
let mut offset = self.offset[start];
for k in idx {
let next_offset = self.offset[k];
if next_offset == offset {
end += 1;
continue;
}
// Proceed to next range if the current slice extends beyond. Otherwise, proceed
// to the next slice.
if range.end <= slice_range.end {
break;
} else {
(slice, slice_range) = slice_iter.next()?;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
start = k;
end = k + 1;
offset = next_offset;
}
let len = end - start;
let ptr = (start as isize + offset) as usize;
let mem_ref = Slice::from_range_unchecked(ptr..ptr + len);
transcript_refs.push(Vector::from_raw(mem_ref));
}
Some(slices)
transcript_refs
}
fn compute_missing(&self, index: &RangeSet<usize>) -> RangeSet<usize> {
let byte_index = to_byte_index(&self.index);
index.difference(&byte_index)
}
fn decoded(&self) -> RangeSet<usize> {
to_byte_index(&self.decoded)
}
fn max_len(&self) -> usize {
self.max_len / 8
}
#[cfg(test)]
fn index(&self) -> RangeSet<usize> {
to_byte_index(&self.index)
}
}
/// Decodes the transcript.
pub(crate) fn decode_transcript(
vm: &mut dyn Vm<Binary>,
sent: &RangeSet<usize>,
recv: &RangeSet<usize>,
refs: &TranscriptRefs,
) -> Result<(), VmError> {
let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, recv)
.expect("index is in bounds");
fn to_bit_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut bit_index = RangeSet::default();
for slice in sent_refs.into_iter().chain(recv_refs) {
// Drop the future, we don't need it.
drop(vm.decode(slice)?);
for r in index.iter_ranges() {
bit_index.union_mut(&(8 * r.start..8 * r.end));
}
Ok(())
bit_index
}
/// Verifies a partial transcript.
pub(crate) fn verify_transcript(
vm: &mut dyn Vm<Binary>,
transcript: &PartialTranscript,
refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
let sent_refs = refs
.get(Direction::Sent, transcript.sent_authed())
.expect("index is in bounds");
let recv_refs = refs
.get(Direction::Received, transcript.received_authed())
.expect("index is in bounds");
fn to_byte_index(index: &RangeSet<usize>) -> RangeSet<usize> {
let mut byte_index = RangeSet::default();
let mut authenticated_data = Vec::new();
for data in sent_refs.into_iter().chain(recv_refs) {
let plaintext = vm
.get(data)
.expect("reference is valid")
.expect("plaintext is decoded");
authenticated_data.extend_from_slice(&plaintext);
for r in index.iter_ranges() {
let start = r.start;
let end = r.end;
assert!(
start.trailing_zeros() >= 3,
"start range should be divisible by 8"
);
assert!(
end.trailing_zeros() >= 3,
"end range should be divisible by 8"
);
let start = start >> 3;
let end = end >> 3;
byte_index.union_mut(&(start..end));
}
let mut purported_data = Vec::with_capacity(authenticated_data.len());
for range in transcript.sent_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
}
for range in transcript.received_authed().iter_ranges() {
purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
}
if purported_data != authenticated_data {
return Err(InconsistentTranscript {});
}
Ok(())
byte_index
}
/// Error for [`verify_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent transcript")]
pub(crate) struct InconsistentTranscript {}
#[cfg(test)]
mod tests {
use super::TranscriptRefs;
use mpz_memory_core::{FromRaw, Slice, Vector, binary::U8};
use rangeset::RangeSet;
use crate::commit::transcript::RefStorage;
use mpz_memory_core::{FromRaw, Slice, ToRaw, Vector, binary::U8};
use rangeset::{RangeSet, UnionMut};
use rstest::{fixture, rstest};
use std::ops::Range;
use tlsn_core::transcript::Direction;
// TRANSCRIPT_REFS:
//
// 48..96 -> 6 slots
// 112..176 -> 8 slots
// 240..288 -> 6 slots
// 352..392 -> 5 slots
// 440..480 -> 5 slots
const TRANSCRIPT_REFS: &[Range<usize>] = &[48..96, 112..176, 240..288, 352..392, 440..480];
#[rstest]
fn test_storage_add(
max_len: usize,
ranges: [Range<usize>; 6],
offsets: [isize; 6],
storage: RefStorage,
) {
let bit_ranges: Vec<Range<usize>> = ranges.iter().map(|r| 8 * r.start..8 * r.end).collect();
let bit_offsets: Vec<isize> = offsets.iter().map(|o| 8 * o).collect();
const IDXS: &[Range<usize>] = &[0..4, 5..10, 14..16, 16..28];
let mut expected_index: RangeSet<usize> = RangeSet::default();
// 1. Take slots 0..4, 4 slots -> 48..80 (4)
// 2. Take slots 5..10, 5 slots -> 88..96 (1) + 112..144 (4)
// 3. Take slots 14..16, 2 slots -> 240..256 (2)
// 4. Take slots 16..28, 12 slots -> 256..288 (4) + 352..392 (5) + 440..464 (3)
//
// 5. Merge slots 240..256 and 256..288 => 240..288 and get EXPECTED_REFS
const EXPECTED_REFS: &[Range<usize>] =
&[48..80, 88..96, 112..144, 240..288, 352..392, 440..464];
expected_index.union_mut(&bit_ranges[0]);
expected_index.union_mut(&bit_ranges[1]);
#[test]
fn test_transcript_refs_get() {
let transcript_refs: Vec<Vector<U8>> = TRANSCRIPT_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_index.union_mut(&bit_ranges[2]);
expected_index.union_mut(&bit_ranges[3]);
let transcript_refs = TranscriptRefs {
sent: transcript_refs.clone(),
recv: transcript_refs,
};
expected_index.union_mut(&bit_ranges[4]);
expected_index.union_mut(&bit_ranges[5]);
assert_eq!(storage.index, expected_index);
let vm_refs = transcript_refs
.get(Direction::Sent, &RangeSet::from(IDXS))
.unwrap();
let end = expected_index.end().unwrap();
let mut expected_offset = vec![0_isize; end];
let expected_refs: Vec<Vector<U8>> = EXPECTED_REFS
.iter()
.cloned()
.map(|range| Vector::from_raw(Slice::from_range_unchecked(range)))
.collect();
expected_offset[bit_ranges[0].clone()].fill(bit_offsets[0]);
expected_offset[bit_ranges[1].clone()].fill(bit_offsets[1]);
assert_eq!(
vm_refs.len(),
expected_refs.len(),
"Length of actual and expected refs are not equal"
);
expected_offset[bit_ranges[2].clone()].fill(bit_offsets[2]);
expected_offset[bit_ranges[3].clone()].fill(bit_offsets[3]);
for (&expected, actual) in expected_refs.iter().zip(vm_refs) {
assert_eq!(expected, actual);
expected_offset[bit_ranges[4].clone()].fill(bit_offsets[4]);
expected_offset[bit_ranges[5].clone()].fill(bit_offsets[5]);
assert_eq!(storage.offset, expected_offset);
assert_eq!(storage.decoded, RangeSet::default());
assert_eq!(storage.max_len, 8 * max_len);
}
#[rstest]
fn test_storage_get(ranges: [Range<usize>; 6], offsets: [isize; 6], storage: RefStorage) {
let mut index = RangeSet::default();
ranges.iter().for_each(|r| index.union_mut(r));
let data = storage.get(&index);
let mut data_recovered = Vec::new();
for (r, o) in ranges.iter().zip(offsets) {
data_recovered.push(vec(r.start as isize + o..r.end as isize + o));
}
// Merge possibly adjacent vectors.
//
// Two vectors are adjacent if
//
// - vectors are adjacent in memory.
// - transcript ranges of those vectors are adjacent, too.
let mut range_iter = ranges.iter();
let mut vec_iter = data_recovered.iter();
let mut data_expected = Vec::new();
let mut current_vec = vec_iter.next().unwrap().to_raw().to_range();
let mut current_range = range_iter.next().unwrap();
for (r, v) in range_iter.zip(vec_iter) {
let v_range = v.to_raw().to_range();
let start = v_range.start;
let end = v_range.end;
if current_vec.end == start && current_range.end == r.start {
current_vec.end = end;
} else {
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
current_vec = start..end;
current_range = r;
}
}
let v = Vector::<U8>::from_raw(Slice::from_range_unchecked(current_vec));
data_expected.push(v);
assert_eq!(data, data_expected);
}
#[rstest]
fn test_storage_compute_missing(storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(6..12));
range.union_mut(&(18..21));
range.union_mut(&(22..25));
range.union_mut(&(50..60));
let missing = storage.compute_missing(&range);
let mut missing_expected = RangeSet::default();
missing_expected.union_mut(&(8..12));
missing_expected.union_mut(&(20..21));
missing_expected.union_mut(&(50..60));
assert_eq!(missing, missing_expected);
}
#[rstest]
fn test_mark_decoded(mut storage: RefStorage) {
let mut range = RangeSet::default();
range.union_mut(&(14..17));
range.union_mut(&(30..37));
storage.mark_decoded(&range);
let decoded = storage.decoded();
assert_eq!(range, decoded);
}
#[fixture]
fn max_len() -> usize {
1000
}
#[fixture]
fn ranges() -> [Range<usize>; 6] {
let r1 = 0..5;
let r2 = 5..8;
let r3 = 12..20;
let r4 = 22..26;
let r5 = 30..35;
let r6 = 35..38;
[r1, r2, r3, r4, r5, r6]
}
#[fixture]
fn offsets() -> [isize; 6] {
[7, 9, 20, 18, 30, 30]
}
// expected memory ranges: 8 * ranges + 8 * offsets
// 1. 56..96 do not merge with next one, because not adjacent in memory
// 2. 112..136
// 3. 256..320 do not merge with next one, adjacent in memory, but the ranges
// itself are not
// 4. 320..352
// 5. 480..520 merge with next one
// 6 520..544
//
//
// 1. 56..96, length: 5
// 2. 112..136, length: 3
// 3. 256..320, length: 8
// 4. 320..352, length: 4
// 5. 480..544, length: 8
#[fixture]
fn storage(max_len: usize, ranges: [Range<usize>; 6], offsets: [isize; 6]) -> RefStorage {
let [r1, r2, r3, r4, r5, r6] = ranges;
let [o1, o2, o3, o4, o5, o6] = offsets;
let mut storage = RefStorage::new(max_len);
storage.add(&r1, vec(r1.start as isize + o1..r1.end as isize + o1));
storage.add(&r2, vec(r2.start as isize + o2..r2.end as isize + o2));
storage.add(&r3, vec(r3.start as isize + o3..r3.end as isize + o3));
storage.add(&r4, vec(r4.start as isize + o4..r4.end as isize + o4));
storage.add(&r5, vec(r5.start as isize + o5..r5.end as isize + o5));
storage.add(&r6, vec(r6.start as isize + o6..r6.end as isize + o6));
storage
}
fn vec(range: Range<isize>) -> Vector<U8> {
let range = 8 * range.start as usize..8 * range.end as usize;
Vector::from_raw(Slice::from_range_unchecked(range))
}
}

View File

@@ -1,249 +0,0 @@
//! Encoding commitment protocol.
use std::ops::Range;
use mpz_common::Context;
use mpz_memory_core::{
Vector,
binary::U8,
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{
hash::HashAlgorithm,
transcript::{
Direction,
encoding::{
Encoder, EncoderSecret, EncodingCommitment, EncodingProvider, EncodingProviderError,
EncodingTree, EncodingTreeError, new_encoder,
},
},
};
use crate::commit::transcript::TranscriptRefs;
/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;
#[derive(Debug, Serialize, Deserialize)]
struct Encodings {
sent: Vec<u8>,
recv: Vec<u8>,
}
/// Transfers the encodings using the provided seed and keys.
///
/// The keys must be consistent with the global delta used in the encodings.
pub(crate) async fn transfer<'a>(
ctx: &mut Context,
refs: &TranscriptRefs,
delta: &Delta,
f: impl Fn(Vector<U8>) -> &'a [Key],
) -> Result<EncodingCommitment, EncodingError> {
let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
let encoder = new_encoder(&secret);
let sent_keys: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
let recv_keys: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|key| key.as_block().as_bytes())
.copied()
.collect();
assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);
let mut sent_encoding = Vec::with_capacity(sent_keys.len());
let mut recv_encoding = Vec::with_capacity(recv_keys.len());
encoder.encode_range(
Direction::Sent,
0..sent_keys.len() / ENCODING_SIZE,
&mut sent_encoding,
);
encoder.encode_range(
Direction::Received,
0..recv_keys.len() / ENCODING_SIZE,
&mut recv_encoding,
);
sent_encoding
.iter_mut()
.zip(sent_keys)
.for_each(|(enc, key)| *enc ^= key);
recv_encoding
.iter_mut()
.zip(recv_keys)
.for_each(|(enc, key)| *enc ^= key);
// Set frame limit and add some extra bytes cushion room.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
ctx.io_mut()
.with_limit(frame_limit)
.send(Encodings {
sent: sent_encoding,
recv: recv_encoding,
})
.await?;
let root = ctx.io_mut().expect_next().await?;
ctx.io_mut().send(secret.clone()).await?;
Ok(EncodingCommitment {
root,
secret: secret.clone(),
})
}
/// Receives the encodings using the provided MACs.
///
/// The MACs must be consistent with the global delta used in the encodings.
pub(crate) async fn receive<'a>(
ctx: &mut Context,
hasher: &(dyn HashAlgorithm + Send + Sync),
refs: &TranscriptRefs,
f: impl Fn(Vector<U8>) -> &'a [Mac],
idxs: impl IntoIterator<Item = &(Direction, RangeSet<usize>)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
// Set frame limit and add some extra bytes cushion room.
let (sent, recv) = refs.len();
let frame_limit = ENCODING_SIZE * (sent + recv) + ctx.io().limit();
let Encodings { mut sent, mut recv } =
ctx.io_mut().with_limit(frame_limit).expect_next().await?;
let sent_macs: Vec<u8> = refs
.sent()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
let recv_macs: Vec<u8> = refs
.recv()
.iter()
.copied()
.flat_map(&f)
.flat_map(|mac| mac.as_bytes())
.copied()
.collect();
assert_eq!(sent_macs.len() % ENCODING_SIZE, 0);
assert_eq!(recv_macs.len() % ENCODING_SIZE, 0);
if sent.len() != sent_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Sent,
expected: sent_macs.len(),
got: sent.len(),
}
.into());
}
if recv.len() != recv_macs.len() {
return Err(ErrorRepr::IncorrectMacCount {
direction: Direction::Received,
expected: recv_macs.len(),
got: recv.len(),
}
.into());
}
sent.iter_mut()
.zip(sent_macs)
.for_each(|(enc, mac)| *enc ^= mac);
recv.iter_mut()
.zip(recv_macs)
.for_each(|(enc, mac)| *enc ^= mac);
let provider = Provider { sent, recv };
let tree = EncodingTree::new(hasher, idxs, &provider)?;
let root = tree.root();
ctx.io_mut().send(root.clone()).await?;
let secret = ctx.io_mut().expect_next().await?;
let commitment = EncodingCommitment { root, secret };
Ok((commitment, tree))
}
#[derive(Debug)]
struct Provider {
sent: Vec<u8>,
recv: Vec<u8>,
}
impl EncodingProvider for Provider {
fn provide_encoding(
&self,
direction: Direction,
range: Range<usize>,
dest: &mut Vec<u8>,
) -> Result<(), EncodingProviderError> {
let encodings = match direction {
Direction::Sent => &self.sent,
Direction::Received => &self.recv,
};
let start = range.start * ENCODING_SIZE;
let end = range.end * ENCODING_SIZE;
if end > encodings.len() {
return Err(EncodingProviderError);
}
dest.extend_from_slice(&encodings[start..end]);
Ok(())
}
}
/// Encoding protocol error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct EncodingError(#[from] ErrorRepr);
#[derive(Debug, thiserror::Error)]
#[error("encoding protocol error: {0}")]
enum ErrorRepr {
#[error("I/O error: {0}")]
Io(std::io::Error),
#[error("incorrect MAC count for {direction}: expected {expected}, got {got}")]
IncorrectMacCount {
direction: Direction,
expected: usize,
got: usize,
},
#[error("encoding tree error: {0}")]
EncodingTree(EncodingTreeError),
}
impl From<std::io::Error> for EncodingError {
fn from(value: std::io::Error) -> Self {
Self(ErrorRepr::Io(value))
}
}
impl From<EncodingTreeError> for EncodingError {
fn from(value: EncodingTreeError) -> Self {
Self(ErrorRepr::EncodingTree(value))
}
}

View File

@@ -23,11 +23,11 @@ pub(crate) fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec
let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);
// Pad data to be a multiple of 16 bytes.
let aad_padded_block_count = (aad.len() / 16) + !aad.len().is_multiple_of(16) as usize;
let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
aad.resize(aad_padded_block_count * 16, 0);
let ciphertext_padded_block_count =
(ciphertext.len() / 16) + !ciphertext.len().is_multiple_of(16) as usize;
(ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
ciphertext.resize(ciphertext_padded_block_count * 16, 0);
let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);

View File

@@ -7,7 +7,6 @@
pub(crate) mod commit;
pub mod config;
pub(crate) mod context;
pub(crate) mod encoding;
pub(crate) mod ghash;
pub(crate) mod msg;
pub(crate) mod mux;

View File

@@ -6,7 +6,6 @@ use tlsn_core::connection::{HandshakeData, ServerName};
/// Message sent from Prover to Verifier to prove the server identity.
#[derive(Debug, Serialize, Deserialize)]
#[allow(dead_code)]
pub(crate) struct ServerIdentityProof {
/// Server name.
pub name: ServerName,

View File

@@ -8,49 +8,41 @@ pub mod state;
pub use config::{ProverConfig, ProverConfigBuilder, TlsConfig, TlsConfigBuilder};
pub use error::ProverError;
pub use future::ProverFuture;
use rustls_pki_types::CertificateDer;
pub use tlsn_core::{ProveConfig, ProveConfigBuilder, ProveConfigBuilderError, ProverOutput};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use mpz_common::Context;
use mpz_core::Block;
use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::ProverConfig as ZkProverConfig;
use rand::Rng;
use rustls_pki_types::CertificateDer;
use serio::SinkExt;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tlsn_core::{
ProveRequest,
connection::{HandshakeData, ServerName},
transcript::{TlsTranscript, Transcript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
use webpki::anchor_from_trusted_cert;
use crate::{
Role,
commit::{
commit_records,
hash::prove_hash,
transcript::{TranscriptRefs, decode_transcript},
},
commit::{ProvingState, TranscriptRefs},
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{LeaderCtrl, MpcTlsLeader, SessionKeys};
use rand::Rng;
use serio::SinkExt;
use std::sync::Arc;
use tls_client::{ClientConnection, ServerName as TlsServerName};
use tls_client_async::{TlsConnection, bind_client};
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
connection::{HandshakeData, ServerName},
hash::{Blake3, HashAlgId, HashAlgorithm, Keccak256, Sha256},
transcript::{TlsTranscript, Transcript, TranscriptCommitment, TranscriptSecret},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Instrument, Span, debug, info, info_span, instrument};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>,
mpz_core::Block,
@@ -173,8 +165,8 @@ impl Prover<state::Setup> {
mux_ctrl,
mut mux_fut,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
vm,
..
@@ -281,28 +273,6 @@ impl Prover<state::Setup> {
)
.map_err(ProverError::zk)?;
// Prove received plaintext. Prover drops the proof output, as
// they trust themselves.
let (sent_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
let (recv_refs, _) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(ProverError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(ProverError::zk))
.await?;
@@ -310,7 +280,9 @@ impl Prover<state::Setup> {
let transcript = tls_transcript
.to_transcript()
.expect("transcript is complete");
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let (sent_len, recv_len) = transcript.len();
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Prover {
config: self.config,
@@ -323,6 +295,10 @@ impl Prover<state::Setup> {
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred: false,
},
})
}
@@ -356,7 +332,7 @@ impl Prover<state::Committed> {
///
/// * `config` - The disclosure configuration.
#[instrument(parent = &self.span, level = "info", skip_all, err)]
pub async fn prove(&mut self, config: &ProveConfig) -> Result<ProverOutput, ProverError> {
pub async fn prove(&mut self, config: ProveConfig) -> Result<ProverOutput, ProverError> {
let state::Committed {
mux_fut,
ctx,
@@ -364,114 +340,48 @@ impl Prover<state::Committed> {
tls_transcript,
transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
encodings_transferred,
..
} = &mut self.state;
let mut output = ProverOutput {
transcript_commitments: Vec::new(),
transcript_secrets: Vec::new(),
};
// Create and send prove payload.
let server_name = self.config.server_name();
let handshake = config
.server_identity()
.then(|| (server_name.clone(), HandshakeData::new(tls_transcript)));
let partial_transcript = if let Some((sent, recv)) = config.reveal() {
decode_transcript(vm, sent, recv, transcript_refs).map_err(ProverError::zk)?;
Some(transcript.to_partial(sent.clone(), recv.clone()))
let partial = if let Some((reveal_sent, reveal_recv)) = config.reveal() {
Some(transcript.to_partial(reveal_sent.clone(), reveal_recv.clone()))
} else {
None
};
let payload = ProvePayload {
handshake: config.server_identity().then(|| {
(
self.config.server_name().clone(),
HandshakeData {
certs: tls_transcript
.server_cert_chain()
.expect("server cert chain is present")
.to_vec(),
sig: tls_transcript
.server_signature()
.expect("server signature is present")
.clone(),
binding: tls_transcript.certificate_binding().clone(),
},
)
}),
transcript: partial_transcript,
transcript_commit: config.transcript_commit().map(|config| config.to_request()),
};
let payload = ProveRequest::new(&config, partial, handshake);
// Send payload.
mux_fut
.poll_with(ctx.io_mut().send(payload).map_err(ProverError::from))
.await?;
let mut hash_commitments = None;
if let Some(commit_config) = config.transcript_commit() {
if commit_config.has_encoding() {
let hasher: &(dyn HashAlgorithm + Send + Sync) =
match *commit_config.encoding_hash_alg() {
HashAlgId::SHA256 => &Sha256::default(),
HashAlgId::KECCAK256 => &Keccak256::default(),
HashAlgId::BLAKE3 => &Blake3::default(),
alg => {
return Err(ProverError::config(format!(
"unsupported hash algorithm for encoding commitment: {alg}"
)));
}
};
let proving_state = ProvingState::for_prover(
config,
tls_transcript,
transcript,
transcript_refs,
*encodings_transferred,
);
let (commitment, tree) = mux_fut
.poll_with(
encoding::receive(
ctx,
hasher,
transcript_refs,
|plaintext| vm.get_macs(plaintext).expect("reference is valid"),
commit_config.iter_encoding(),
)
.map_err(ProverError::commit),
)
.await?;
output
.transcript_commitments
.push(TranscriptCommitment::Encoding(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Encoding(tree));
}
if commit_config.has_hash() {
hash_commitments = Some(
prove_hash(
vm,
transcript_refs,
commit_config
.iter_hash()
.map(|((dir, idx), alg)| (*dir, idx.clone(), *alg)),
)
.map_err(ProverError::commit)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(ProverError::zk))
let (output, encodings_executed) = mux_fut
.poll_with(
proving_state
.prove(vm, ctx, zk_aes_ctr_sent, zk_aes_ctr_recv, keys.clone())
.map_err(ProverError::from),
)
.await?;
if let Some((hash_fut, hash_secrets)) = hash_commitments {
let hash_commitments = hash_fut.try_recv().map_err(ProverError::commit)?;
for (commitment, secret) in hash_commitments.into_iter().zip(hash_secrets) {
output
.transcript_commitments
.push(TranscriptCommitment::Hash(commitment));
output
.transcript_secrets
.push(TranscriptSecret::Hash(secret));
}
}
*encodings_transferred = encodings_executed;
Ok(output)
}

View File

@@ -1,8 +1,6 @@
use std::{error::Error, fmt};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use std::{error::Error, fmt};
/// Error for [`Prover`](crate::Prover).
#[derive(Debug, thiserror::Error)]
@@ -42,13 +40,6 @@ impl ProverError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn commit<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Commit, source)
}
}
#[derive(Debug)]
@@ -116,8 +107,8 @@ impl From<ZkAesCtrError> for ProverError {
}
}
impl From<EncodingError> for ProverError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for ProverError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -9,7 +9,7 @@ use tlsn_deap::Deap;
use tokio::sync::Mutex;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
prover::{Mpc, Zk},
zk_aes_ctr::ZkAesCtr,
@@ -42,6 +42,10 @@ pub struct Committed {
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript: Transcript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -1,11 +1,9 @@
//! Verifier.
pub(crate) mod config;
mod config;
mod error;
pub mod state;
use std::sync::Arc;
pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError};
pub use error::VerifierError;
pub use tlsn_core::{
@@ -13,20 +11,8 @@ pub use tlsn_core::{
webpki::ServerCertVerifier,
};
use crate::{
Role,
commit::{
commit_records,
hash::verify_hash,
transcript::{TranscriptRefs, decode_transcript, verify_transcript},
},
config::ProtocolConfig,
context::build_mt_context,
encoding,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
use std::sync::Arc;
use futures::{AsyncRead, AsyncWrite, TryFutureExt};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
@@ -35,17 +21,25 @@ use mpz_garble_core::Delta;
use mpz_vm_core::prelude::*;
use mpz_zk::VerifierConfig as ZkVerifierConfig;
use serio::stream::IoStreamExt;
use tls_core::msgs::enums::ContentType;
use tlsn_core::{
ProvePayload,
ProveRequest,
connection::{ConnectionInfo, ServerName},
transcript::{TlsTranscript, TranscriptCommitment},
transcript::{ContentType, TlsTranscript},
};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
use tracing::{Span, debug, info, info_span, instrument};
use crate::{
Role,
commit::{ProvingState, TranscriptRefs},
config::ProtocolConfig,
context::build_mt_context,
mux::attach_mux,
tag::verify_tags,
zk_aes_ctr::ZkAesCtr,
};
pub(crate) type RCOTSender = mpz_ot::rcot::shared::SharedRCOTSender<
mpz_ot::ferret::Sender<mpz_ot::kos::Sender<mpz_ot::chou_orlandi::Receiver>>,
mpz_core::Block,
@@ -188,8 +182,8 @@ impl Verifier<state::Setup> {
mut mux_fut,
delta,
mpc_tls,
mut zk_aes_ctr_sent,
mut zk_aes_ctr_recv,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
vm,
keys,
} = self.state;
@@ -230,27 +224,6 @@ impl Verifier<state::Setup> {
)
.map_err(VerifierError::zk)?;
// Prepare for the prover to prove received plaintext.
let (sent_refs, sent_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_sent,
tls_transcript
.sent()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
let (recv_refs, recv_proof) = commit_records(
&mut vm,
&mut zk_aes_ctr_recv,
tls_transcript
.recv()
.iter()
.filter(|record| record.typ == ContentType::ApplicationData),
)
.map_err(VerifierError::zk)?;
mux_fut
.poll_with(vm.execute_all(&mut ctx).map_err(VerifierError::zk))
.await?;
@@ -260,11 +233,30 @@ impl Verifier<state::Setup> {
// authenticated from the verifier's perspective.
tag_proof.verify().map_err(VerifierError::zk)?;
// Verify the plaintext proofs.
sent_proof.verify().map_err(VerifierError::zk)?;
recv_proof.verify().map_err(VerifierError::zk)?;
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if matches!(record.typ, ContentType::ApplicationData) {
Some(record.ciphertext.len())
} else {
None
}
})
.sum();
let transcript_refs = TranscriptRefs::new(sent_refs, recv_refs);
let transcript_refs = TranscriptRefs::new(sent_len, recv_len);
Ok(Verifier {
config: self.config,
@@ -277,6 +269,11 @@ impl Verifier<state::Setup> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name: None,
encodings_transferred: false,
},
})
}
@@ -305,126 +302,42 @@ impl Verifier<state::Committed> {
vm,
tls_transcript,
transcript_refs,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys,
verified_server_name,
encodings_transferred,
..
} = &mut self.state;
let ProvePayload {
handshake,
transcript,
transcript_commit,
} = mux_fut
let payload: ProveRequest = mux_fut
.poll_with(ctx.io_mut().expect_next().map_err(VerifierError::from))
.await?;
let verifier = if let Some(root_store) = self.config.root_store() {
ServerCertVerifier::new(root_store).map_err(VerifierError::config)?
} else {
ServerCertVerifier::mozilla()
};
let proving_state = ProvingState::for_verifier(
payload,
tls_transcript,
transcript_refs,
verified_server_name.clone(),
*encodings_transferred,
);
let server_name = if let Some((name, cert_data)) = handshake {
cert_data
.verify(
&verifier,
tls_transcript.time(),
tls_transcript.server_ephemeral_key(),
&name,
)
.map_err(VerifierError::verify)?;
Some(name)
} else {
None
};
if let Some(partial_transcript) = &transcript {
let sent_len = tls_transcript
.sent()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
let recv_len = tls_transcript
.recv()
.iter()
.filter_map(|record| {
if let ContentType::ApplicationData = record.typ {
Some(record.ciphertext.len())
} else {
None
}
})
.sum::<usize>();
// Check ranges.
if partial_transcript.len_sent() != sent_len
|| partial_transcript.len_received() != recv_len
{
return Err(VerifierError::verify(
"prover sent transcript with incorrect length",
));
}
decode_transcript(
let (output, encodings_executed) = mux_fut
.poll_with(proving_state.verify(
vm,
partial_transcript.sent_authed(),
partial_transcript.received_authed(),
transcript_refs,
)
.map_err(VerifierError::zk)?;
}
let mut transcript_commitments = Vec::new();
let mut hash_commitments = None;
if let Some(commit_config) = transcript_commit {
if commit_config.encoding() {
let commitment = mux_fut
.poll_with(encoding::transfer(
ctx,
transcript_refs,
delta,
|plaintext| vm.get_keys(plaintext).expect("reference is valid"),
))
.await?;
transcript_commitments.push(TranscriptCommitment::Encoding(commitment));
}
if commit_config.has_hash() {
hash_commitments = Some(
verify_hash(vm, transcript_refs, commit_config.iter_hash().cloned())
.map_err(VerifierError::verify)?,
);
}
}
mux_fut
.poll_with(vm.execute_all(ctx).map_err(VerifierError::zk))
ctx,
zk_aes_ctr_sent,
zk_aes_ctr_recv,
keys.clone(),
*delta,
self.config.root_store(),
))
.await?;
// Verify revealed data.
if let Some(partial_transcript) = &transcript {
verify_transcript(vm, partial_transcript, transcript_refs)
.map_err(VerifierError::verify)?;
}
*verified_server_name = output.server_name.clone();
*encodings_transferred = encodings_executed;
if let Some(hash_commitments) = hash_commitments {
for commitment in hash_commitments.try_recv().map_err(VerifierError::verify)? {
transcript_commitments.push(TranscriptCommitment::Hash(commitment));
}
}
Ok(VerifierOutput {
server_name,
transcript,
transcript_commitments,
})
Ok(output)
}
/// Closes the connection with the prover.

View File

@@ -12,7 +12,7 @@ use crate::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator};
#[builder(pattern = "owned")]
pub struct VerifierConfig {
protocol_config_validator: ProtocolConfigValidator,
#[builder(default, setter(strip_option))]
#[builder(setter(strip_option))]
root_store: Option<RootCertStore>,
}

View File

@@ -1,4 +1,4 @@
use crate::{encoding::EncodingError, zk_aes_ctr::ZkAesCtrError};
use crate::{commit::CommitError, zk_aes_ctr::ZkAesCtrError};
use mpc_tls::MpcTlsError;
use std::{error::Error, fmt};
@@ -20,13 +20,6 @@ impl VerifierError {
}
}
pub(crate) fn config<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Config, source)
}
pub(crate) fn mpc<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
@@ -40,13 +33,6 @@ impl VerifierError {
{
Self::new(ErrorKind::Zk, source)
}
pub(crate) fn verify<E>(source: E) -> Self
where
E: Into<Box<dyn Error + Send + Sync + 'static>>,
{
Self::new(ErrorKind::Verify, source)
}
}
#[derive(Debug)]
@@ -56,7 +42,6 @@ enum ErrorKind {
Mpc,
Zk,
Commit,
Verify,
}
impl fmt::Display for VerifierError {
@@ -69,7 +54,6 @@ impl fmt::Display for VerifierError {
ErrorKind::Mpc => f.write_str("mpc error")?,
ErrorKind::Zk => f.write_str("zk error")?,
ErrorKind::Commit => f.write_str("commit error")?,
ErrorKind::Verify => f.write_str("verification error")?,
}
if let Some(source) = &self.source {
@@ -116,8 +100,8 @@ impl From<ZkAesCtrError> for VerifierError {
}
}
impl From<EncodingError> for VerifierError {
fn from(e: EncodingError) -> Self {
impl From<CommitError> for VerifierError {
fn from(e: CommitError) -> Self {
Self::new(ErrorKind::Commit, e)
}
}

View File

@@ -3,14 +3,14 @@
use std::sync::Arc;
use crate::{
commit::transcript::TranscriptRefs,
commit::TranscriptRefs,
mux::{MuxControl, MuxFuture},
zk_aes_ctr::ZkAesCtr,
};
use mpc_tls::{MpcTlsFollower, SessionKeys};
use mpz_common::Context;
use mpz_memory_core::correlated::Delta;
use tlsn_core::transcript::TlsTranscript;
use tlsn_core::{connection::ServerName, transcript::TlsTranscript};
use tlsn_deap::Deap;
use tokio::sync::Mutex;
@@ -45,6 +45,11 @@ pub struct Committed {
pub(crate) vm: Zk,
pub(crate) tls_transcript: TlsTranscript,
pub(crate) transcript_refs: TranscriptRefs,
pub(crate) zk_aes_ctr_sent: ZkAesCtr,
pub(crate) zk_aes_ctr_recv: ZkAesCtr,
pub(crate) keys: SessionKeys,
pub(crate) verified_server_name: Option<ServerName>,
pub(crate) encodings_transferred: bool,
}
opaque_debug::implement!(Committed);

View File

@@ -37,8 +37,8 @@ impl ZkAesCtr {
}
/// Returns the role.
pub(crate) fn role(&self) -> &Role {
&self.role
pub(crate) fn role(&self) -> Role {
self.role
}
/// Allocates `len` bytes for encryption.

View File

@@ -103,7 +103,7 @@ async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(verifier_soc
let config = builder.build().unwrap();
prover.prove(&config).await.unwrap();
prover.prove(config).await.unwrap();
prover.close().await.unwrap();
}

View File

@@ -126,7 +126,7 @@ impl JsProver {
let config = builder.build()?;
prover.prove(&config).await?;
prover.prove(config).await?;
prover.close().await?;
info!("Finalized");