Mirror of https://github.com/tlsnotary/tlsn.git
Synced 2026-01-12 16:18:43 -05:00

Compare commits: plot_py...feat/integ (4 commits)

- c1ee91e63b
- 57f85dcdf2
- 727dff3ac9
- 830d0ab0d5
.github/workflows/ci.yml (vendored, 2 lines changed)

```diff
@@ -21,7 +21,7 @@ env:
 # - https://github.com/privacy-ethereum/mpz/issues/178
 # 32 seems to be big enough for the foreseeable future
 RAYON_NUM_THREADS: 32
-RUST_VERSION: 1.92.0
+RUST_VERSION: 1.91.1
 GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

 jobs:
```
Cargo.lock (generated, 1748 lines changed): diff suppressed because it is too large.
Cargo.toml (30 lines changed)

```diff
@@ -66,21 +66,21 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
 tlsn-wasm = { path = "crates/wasm" }
 tlsn = { path = "crates/tlsn" }

-mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
-mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
+mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-common = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-predicate = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }
+mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", branch = "feat/mpz-bool-type" }

 rangeset = { version = "0.4" }
 serio = { version = "0.2" }
```
A crate manifest hunk (file header suppressed in this view):

```diff
@@ -27,6 +27,7 @@ tlsn-data-fixtures = { workspace = true, optional = true }
 tlsn-tls-core = { workspace = true, features = ["serde"] }
 tlsn-utils = { workspace = true }
 rangeset = { workspace = true, features = ["serde"] }
+mpz-predicate = { workspace = true }

 aead = { workspace = true, features = ["alloc"], optional = true }
 aes-gcm = { workspace = true, optional = true }
```
Proving configuration module:

```diff
@@ -1,16 +1,119 @@
 //! Proving configuration.

+use mpz_predicate::Pred;
 use rangeset::set::{RangeSet, ToRangeSet};
 use serde::{Deserialize, Serialize};

 use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};

-/// Configuration to prove information to the verifier.
+/// Configuration for a predicate to prove over transcript data.
+///
+/// A predicate is a boolean constraint that operates on transcript bytes.
+/// The prover proves in ZK that the predicate evaluates to true.
+///
+/// The predicate itself encodes which byte indices it operates on via its
+/// atomic comparisons (e.g., `gte(42, threshold)` operates on byte index 42).
+#[derive(Debug, Clone)]
+pub struct PredicateConfig {
+    /// Human-readable name for the predicate (sent to verifier as sanity
+    /// check).
+    name: String,
+    /// Direction of transcript data the predicate operates on.
+    direction: Direction,
+    /// The predicate to prove.
+    predicate: Pred,
+}
+
+impl PredicateConfig {
+    /// Creates a new predicate configuration.
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - Human-readable name for the predicate.
+    /// * `direction` - Whether the predicate operates on sent or received data.
+    /// * `predicate` - The predicate to prove.
+    pub fn new(name: impl Into<String>, direction: Direction, predicate: Pred) -> Self {
+        Self {
+            name: name.into(),
+            direction,
+            predicate,
+        }
+    }
+
+    /// Returns the predicate name.
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Returns the direction of transcript data.
+    pub fn direction(&self) -> Direction {
+        self.direction
+    }
+
+    /// Returns the predicate.
+    pub fn predicate(&self) -> &Pred {
+        &self.predicate
+    }
+
+    /// Returns the transcript byte indices this predicate operates on.
+    pub fn indices(&self) -> Vec<usize> {
+        self.predicate.indices()
+    }
+
+    /// Converts to a request (wire format).
+    pub fn to_request(&self) -> PredicateRequest {
+        let indices: RangeSet<usize> = self
+            .predicate
+            .indices()
+            .into_iter()
+            .map(|idx| idx..idx + 1)
+            .collect();
+        PredicateRequest {
+            name: self.name.clone(),
+            direction: self.direction,
+            indices,
+        }
+    }
+}
+
+/// Wire format for predicate proving request.
+///
+/// Contains only the predicate name and indices - the verifier is expected
+/// to know which predicate corresponds to the name from out-of-band agreement.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredicateRequest {
+    /// Human-readable name for the predicate.
+    name: String,
+    /// Direction of transcript data the predicate operates on.
+    direction: Direction,
+    /// Transcript byte indices the predicate operates on.
+    indices: RangeSet<usize>,
+}
+
+impl PredicateRequest {
+    /// Returns the predicate name.
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Returns the direction of transcript data.
+    pub fn direction(&self) -> Direction {
+        self.direction
+    }
+
+    /// Returns the transcript byte indices as a RangeSet.
+    pub fn indices(&self) -> &RangeSet<usize> {
+        &self.indices
+    }
+}
+
+/// Configuration to prove information to the verifier.
 #[derive(Debug, Clone)]
 pub struct ProveConfig {
     server_identity: bool,
     reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
     transcript_commit: Option<TranscriptCommitConfig>,
+    predicates: Vec<PredicateConfig>,
 }

 impl ProveConfig {
```
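For orientation, a minimal usage sketch of the new types; this is not part of the diff. The `gte` comparison is borrowed from the doc comment's own example, and its exact constructor name and signature in `mpz_predicate` are assumptions here.

```rust
use mpz_predicate::Pred;
use tlsn_core::{config::prove::PredicateConfig, transcript::Direction};

// `predicate` is assumed to have been built as `gte(42, threshold)`, the
// doc comment's example of an atomic comparison on byte index 42.
fn sketch(predicate: Pred) {
    let config = PredicateConfig::new("amount_gte_threshold", Direction::Received, predicate);

    // The predicate itself carries the byte indices it constrains.
    assert_eq!(config.indices(), vec![42]);

    // The wire format carries only name, direction, and indices; the
    // predicate logic itself never crosses the wire.
    let request = config.to_request();
    assert_eq!(request.name(), "amount_gte_threshold");
}
```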
```diff
@@ -35,6 +138,11 @@ impl ProveConfig {
         self.transcript_commit.as_ref()
     }

+    /// Returns the predicate configurations.
+    pub fn predicates(&self) -> &[PredicateConfig] {
+        &self.predicates
+    }
+
     /// Returns a request.
     pub fn to_request(&self) -> ProveRequest {
         ProveRequest {
@@ -44,6 +152,7 @@ impl ProveConfig {
                 .transcript_commit
                 .clone()
                 .map(|config| config.to_request()),
+            predicates: self.predicates.iter().map(|p| p.to_request()).collect(),
         }
     }
 }
```
```diff
@@ -55,6 +164,7 @@ pub struct ProveConfigBuilder<'a> {
     server_identity: bool,
     reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
     transcript_commit: Option<TranscriptCommitConfig>,
+    predicates: Vec<PredicateConfig>,
 }

 impl<'a> ProveConfigBuilder<'a> {
@@ -65,6 +175,7 @@ impl<'a> ProveConfigBuilder<'a> {
             server_identity: false,
             reveal: None,
             transcript_commit: None,
+            predicates: Vec::new(),
         }
     }
```
```diff
@@ -137,12 +248,52 @@ impl<'a> ProveConfigBuilder<'a> {
         Ok(self)
     }

+    /// Adds a predicate to prove over transcript data.
+    ///
+    /// The predicate encodes which byte indices it operates on via its atomic
+    /// comparisons (e.g., `gte(42, threshold)` operates on byte index 42).
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - Human-readable name for the predicate (sent to verifier as
+    ///   sanity check).
+    /// * `direction` - Whether the predicate operates on sent or received data.
+    /// * `predicate` - The predicate to prove.
+    pub fn predicate(
+        &mut self,
+        name: impl Into<String>,
+        direction: Direction,
+        predicate: Pred,
+    ) -> Result<&mut Self, ProveConfigError> {
+        let indices = predicate.indices();
+
+        // Predicate must reference at least one transcript byte.
+        let last_idx = *indices
+            .last()
+            .ok_or(ProveConfigError(ErrorRepr::EmptyPredicate))?;
+
+        // Since indices are sorted, only check the last one for bounds.
+        let transcript_len = self.transcript.len_of_direction(direction);
+        if last_idx >= transcript_len {
+            return Err(ProveConfigError(ErrorRepr::IndexOutOfBounds {
+                direction,
+                actual: last_idx,
+                len: transcript_len,
+            }));
+        }
+
+        self.predicates
+            .push(PredicateConfig::new(name, direction, predicate));
+        Ok(self)
+    }
+
     /// Builds the configuration.
     pub fn build(self) -> Result<ProveConfig, ProveConfigError> {
         Ok(ProveConfig {
             server_identity: self.server_identity,
             reveal: self.reveal,
             transcript_commit: self.transcript_commit,
+            predicates: self.predicates,
         })
     }
 }
```
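A short sketch of how the builder is driven end to end, condensed from the interactive_predicate example later in this diff (`prover` and `predicate` are set up as shown there):

```rust
// Condensed from the interactive_predicate example below.
let mut builder = ProveConfig::builder(prover.transcript());

// Reveal the server identity and the full request.
builder.server_identity();
builder.reveal_sent(&(0..prover.transcript().sent().len()))?;

// Attach the predicate. `predicate()` rejects an empty predicate and
// out-of-bounds indices before anything is sent to the verifier.
builder.predicate(JSON_STRING_PREDICATE, Direction::Received, predicate)?;

let config = builder.build()?;
prover.prove(&config).await?;
```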
```diff
@@ -153,6 +304,7 @@ pub struct ProveRequest {
     server_identity: bool,
     reveal: Option<(RangeSet<usize>, RangeSet<usize>)>,
     transcript_commit: Option<TranscriptCommitRequest>,
+    predicates: Vec<PredicateRequest>,
 }

 impl ProveRequest {
@@ -171,6 +323,11 @@ impl ProveRequest {
     pub fn transcript_commit(&self) -> Option<&TranscriptCommitRequest> {
         self.transcript_commit.as_ref()
     }
+
+    /// Returns the predicate requests.
+    pub fn predicates(&self) -> &[PredicateRequest] {
+        &self.predicates
+    }
 }

 /// Error for [`ProveConfig`].
@@ -180,10 +337,12 @@ pub struct ProveConfigError(#[from] ErrorRepr);

 #[derive(Debug, thiserror::Error)]
 enum ErrorRepr {
-    #[error("range is out of bounds of the transcript ({direction}): {actual} > {len}")]
+    #[error("index out of bounds for {direction} transcript: {actual} >= {len}")]
     IndexOutOfBounds {
         direction: Direction,
         actual: usize,
         len: usize,
     },
+    #[error("predicate must reference at least one transcript byte")]
+    EmptyPredicate,
 }
```
Transcript module (a manual fold is replaced by iterator flattening):

```diff
@@ -110,14 +110,8 @@ impl Transcript {
         }

         Some(
-            Subsequence::new(
-                idx.clone(),
-                data.index(idx).fold(Vec::new(), |mut acc, s| {
-                    acc.extend_from_slice(s);
-                    acc
-                }),
-            )
-            .expect("data is same length as index"),
+            Subsequence::new(idx.clone(), data.index(idx).flatten().copied().collect())
+                .expect("data is same length as index"),
         )
     }

@@ -196,20 +190,18 @@ pub struct CompressedPartialTranscript {
 impl From<PartialTranscript> for CompressedPartialTranscript {
     fn from(uncompressed: PartialTranscript) -> Self {
         Self {
-            sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
-                Vec::new(),
-                |mut acc, s| {
-                    acc.extend_from_slice(s);
-                    acc
-                },
-            ),
+            sent_authed: uncompressed
+                .sent
+                .index(&uncompressed.sent_authed_idx)
+                .flatten()
+                .copied()
+                .collect(),
             received_authed: uncompressed
                 .received
                 .index(&uncompressed.received_authed_idx)
-                .fold(Vec::new(), |mut acc, s| {
-                    acc.extend_from_slice(s);
-                    acc
-                }),
+                .flatten()
+                .copied()
+                .collect(),
             sent_idx: uncompressed.sent_authed_idx,
             recv_idx: uncompressed.received_authed_idx,
             sent_total: uncompressed.sent.len(),
```
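The refactor in these hunks is behavior-preserving: both forms concatenate the indexed byte chunks into one `Vec<u8>`. A standalone sketch of the equivalence (the chunk values are made up for illustration):

```rust
fn main() {
    // Two byte chunks, standing in for the slices yielded by `index(...)`.
    let chunks: Vec<&[u8]> = vec![b"ab", b"cd"];

    // Old style: manual fold that extends an accumulator slice by slice.
    let folded: Vec<u8> = chunks.iter().fold(Vec::new(), |mut acc, s| {
        acc.extend_from_slice(s);
        acc
    });

    // New style: flatten the chunks and copy the bytes out.
    let flattened: Vec<u8> = chunks.iter().copied().flatten().copied().collect();

    assert_eq!(folded, flattened); // both are b"abcd"
}
```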
Examples crate manifest:

```diff
@@ -10,10 +10,13 @@ workspace = true
 [dependencies]
 tlsn = { workspace = true }
 tlsn-formats = { workspace = true }
+tlsn-core = { workspace = true }
 tls-server-fixture = { workspace = true }
 tlsn-server-fixture = { workspace = true }
 tlsn-server-fixture-certs = { workspace = true }
 spansy = { workspace = true }
+mpz-predicate = { workspace = true }
+rangeset = { workspace = true }

 anyhow = { workspace = true }
 bincode = { workspace = true }
@@ -61,3 +64,7 @@ path = "attestation/present.rs"
 [[example]]
 name = "attestation_verify"
 path = "attestation/verify.rs"
+
+[[example]]
+name = "interactive_predicate"
+path = "interactive_predicate/interactive_predicate.rs"
```
Examples README:

```diff
@@ -5,6 +5,7 @@ This folder contains examples demonstrating how to use the TLSNotary protocol.
 * [Interactive](./interactive/README.md): Interactive Prover and Verifier session without a trusted notary.
 * [Attestation](./attestation/README.md): Performing a simple notarization with a trusted notary.
 * [Interactive_zk](./interactive_zk/README.md): Interactive Prover and Verifier session demonstrating zero-knowledge age verification using Noir.
+* [Interactive_predicate](./interactive_predicate/README.md): Interactive session demonstrating predicate proving over transcript data (e.g., proving a JSON field is a valid string without revealing it).

 Refer to <https://tlsnotary.org/docs/quick_start> for a quick start guide to using TLSNotary with these examples.
```
crates/examples/interactive_predicate/README.md (new file, 29 lines):

## Interactive Predicate: Proving Predicates over Transcript Data

This example demonstrates how to use TLSNotary to prove predicates (boolean constraints) over transcript bytes in zero knowledge, without revealing the actual data.

In this example:
- The server returns JSON data containing a "name" field with a string value
- The Prover proves that the name value is a valid JSON string without revealing it
- The Verifier learns that the string is valid JSON, but not the actual content

This uses `mpz_predicate` to build predicates that operate on transcript bytes. The predicate is compiled to a circuit and executed in the ZK VM to prove satisfaction.

### Running the Example

First, start the test server from the root of this repository:
```shell
RUST_LOG=info PORT=4000 cargo run --bin tlsn-server-fixture
```

Next, run the interactive predicate example:
```shell
SERVER_PORT=4000 cargo run --release --example interactive_predicate
```

To view more detailed debug information:
```shell
RUST_LOG=debug,yamux=info,uid_mux=info SERVER_PORT=4000 cargo run --release --example interactive_predicate
```

> Note: In this example, the Prover and Verifier run on the same machine. In real-world scenarios, they would typically operate on separate machines.
crates/examples/interactive_predicate/interactive_predicate.rs (new file, 368 lines):

```rust
//! Example demonstrating predicate proving over transcript data.
//!
//! This example shows how a prover can prove a predicate (boolean constraint)
//! over transcript bytes in zero knowledge, without revealing the actual data.
//!
//! In this example:
//! - The server returns JSON data containing a "name" field with a string value
//! - The prover proves that the name value is a valid JSON string without
//!   revealing it
//! - The verifier learns that the string is valid JSON, but not the actual
//!   content

use std::{
    env,
    net::{IpAddr, SocketAddr},
};

use anyhow::Result;
use http_body_util::Empty;
use hyper::{body::Bytes, Request, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use mpz_predicate::{json::validate_string, Pred};
use rangeset::prelude::RangeSet;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
use tracing::instrument;

use tlsn::{
    config::{
        prove::ProveConfig,
        prover::ProverConfig,
        tls::TlsClientConfig,
        tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig, TlsCommitProtocolConfig},
        verifier::VerifierConfig,
    },
    connection::ServerName,
    prover::Prover,
    transcript::Direction,
    verifier::{Verifier, VerifierOutput},
    webpki::{CertificateDer, RootCertStore},
};
use tlsn_server_fixture::DEFAULT_FIXTURE_PORT;
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};

/// Predicate name for JSON string validation (both parties agree on this
/// out-of-band).
const JSON_STRING_PREDICATE: &str = "valid_json_string";

// Maximum number of bytes that can be sent from prover to server.
const MAX_SENT_DATA: usize = 1 << 12;
// Maximum number of bytes that can be received by prover from server.
const MAX_RECV_DATA: usize = 1 << 14;

/// Builds a predicate that validates a JSON string at the given indices.
///
/// Uses mpz_predicate's `validate_string` to ensure the bytes form a valid
/// JSON string (proper escaping, valid UTF-8, no control characters, etc.).
fn build_json_string_predicate(indices: &RangeSet<usize>) -> Pred {
    validate_string(indices.clone())
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let server_host: String = env::var("SERVER_HOST").unwrap_or("127.0.0.1".into());
    let server_port: u16 = env::var("SERVER_PORT")
        .map(|port| port.parse().expect("port should be valid integer"))
        .unwrap_or(DEFAULT_FIXTURE_PORT);

    // Use the JSON endpoint that returns data.
    let uri = format!("https://{SERVER_DOMAIN}:{server_port}/formats/json");
    let server_ip: IpAddr = server_host.parse().expect("Invalid IP address");
    let server_addr = SocketAddr::from((server_ip, server_port));

    // Connect prover and verifier.
    let (prover_socket, verifier_socket) = tokio::io::duplex(1 << 23);
    let prover = prover(prover_socket, &server_addr, &uri);
    let verifier = verifier(verifier_socket);

    match tokio::try_join!(prover, verifier) {
        Ok(_) => println!("\nSuccess! The prover proved that a JSON field contains a valid string without revealing it."),
        Err(e) => eprintln!("Error: {e}"),
    }
}

/// Finds the value of a JSON field in the response body.
/// Returns (start_index, end_index) of the value (excluding quotes for
/// strings).
fn find_json_string_value(data: &[u8], field_name: &str) -> Option<(usize, usize)> {
    let search_pattern = format!("\"{}\":", field_name);
    let pattern_bytes = search_pattern.as_bytes();

    // Find the field name.
    let field_pos = data
        .windows(pattern_bytes.len())
        .position(|w| w == pattern_bytes)?;

    // Skip past the field name and colon.
    let mut pos = field_pos + pattern_bytes.len();

    // Skip whitespace.
    while pos < data.len() && (data[pos] == b' ' || data[pos] == b'\n' || data[pos] == b'\r') {
        pos += 1;
    }

    // Check if it's a string (starts with quote).
    if pos >= data.len() || data[pos] != b'"' {
        return None;
    }

    // Skip opening quote.
    let start = pos + 1;

    // Find closing quote (handling escapes).
    let mut end = start;
    while end < data.len() {
        if data[end] == b'\\' {
            // Skip escaped character.
            end += 2;
        } else if data[end] == b'"' {
            break;
        } else {
            end += 1;
        }
    }

    Some((start, end))
}

#[instrument(skip(verifier_socket))]
async fn prover<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
    verifier_socket: T,
    server_addr: &SocketAddr,
    uri: &str,
) -> Result<()> {
    let uri = uri.parse::<Uri>().unwrap();
    assert_eq!(uri.scheme().unwrap().as_str(), "https");
    let server_domain = uri.authority().unwrap().host();

    // Create a new prover and perform necessary setup.
    let prover = Prover::new(ProverConfig::builder().build()?)
        .commit(
            TlsCommitConfig::builder()
                .protocol(
                    MpcTlsConfig::builder()
                        .max_sent_data(tlsn_examples::MAX_SENT_DATA)
                        .max_recv_data(tlsn_examples::MAX_RECV_DATA)
                        .build()?,
                )
                .build()?,
            verifier_socket.compat(),
        )
        .await?;

    // Open a TCP connection to the server.
    let client_socket = tokio::net::TcpStream::connect(server_addr).await?;

    // Bind the prover to the server connection.
    let (tls_connection, prover_fut) = prover
        .connect(
            TlsClientConfig::builder()
                .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
                .root_store(RootCertStore {
                    roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
                })
                .build()?,
            client_socket.compat(),
        )
        .await?;
    let tls_connection = TokioIo::new(tls_connection.compat());

    // Spawn the Prover to run in the background.
    let prover_task = tokio::spawn(prover_fut);

    // MPC-TLS Handshake.
    let (mut request_sender, connection) =
        hyper::client::conn::http1::handshake(tls_connection).await?;

    // Spawn the connection to run in the background.
    tokio::spawn(connection);

    // Send request for JSON data.
    let request = Request::builder()
        .uri(uri.clone())
        .header("Host", server_domain)
        .header("Connection", "close")
        .method("GET")
        .body(Empty::<Bytes>::new())?;
    let response = request_sender.send_request(request).await?;

    assert!(response.status() == StatusCode::OK);

    // Create proof for the Verifier.
    let mut prover = prover_task.await??;

    // Find the "name" field value in the JSON response.
    let received = prover.transcript().received();

    // Find the HTTP body (after \r\n\r\n).
    let body_start = received
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|p| p + 4)
        .unwrap_or(0);

    // Find the "name" field's string value.
    let (value_start, value_end) =
        find_json_string_value(&received[body_start..], "name").expect("should find name field");

    // Adjust to absolute positions in transcript.
    let value_start = body_start + value_start;
    let value_end = body_start + value_end;

    let value_bytes = &received[value_start..value_end];
    println!(
        "Prover: Found 'name' field value: \"{}\" at positions {}..{}",
        String::from_utf8_lossy(value_bytes),
        value_start,
        value_end
    );
    println!("Prover: Will prove this is a valid JSON string without revealing the actual content");

    // Build indices for the predicate as a RangeSet.
    let indices: RangeSet<usize> = (value_start..value_end).into();

    // Build the predicate using mpz_predicate.
    let predicate = build_json_string_predicate(&indices);

    let mut builder = ProveConfig::builder(prover.transcript());

    // Reveal the server identity.
    builder.server_identity();

    // Reveal the sent data (the request).
    builder.reveal_sent(&(0..prover.transcript().sent().len()))?;

    // Reveal everything EXCEPT the string value we're proving the predicate over.
    if value_start > 0 {
        builder.reveal_recv(&(0..value_start))?;
    }
    if value_end < prover.transcript().received().len() {
        builder.reveal_recv(&(value_end..prover.transcript().received().len()))?;
    }

    // Add the predicate to prove the string is valid JSON without revealing the
    // value.
    builder.predicate(JSON_STRING_PREDICATE, Direction::Received, predicate)?;

    let config = builder.build()?;

    prover.prove(&config).await?;
    prover.close().await?;

    println!("Prover: Successfully proved the predicate!");

    Ok(())
}

#[instrument(skip(socket))]
async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
    socket: T,
) -> Result<()> {
    let verifier_config = VerifierConfig::builder()
        .root_store(RootCertStore {
            roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
        })
        .build()?;
    let verifier = Verifier::new(verifier_config);

    // Validate the proposed configuration and run the TLS commitment protocol.
    let verifier = verifier.commit(socket.compat()).await?;

    // Validate configuration.
    let reject = if let TlsCommitProtocolConfig::Mpc(mpc_tls_config) = verifier.request().protocol()
    {
        if mpc_tls_config.max_sent_data() > MAX_SENT_DATA {
            Some("max_sent_data is too large")
        } else if mpc_tls_config.max_recv_data() > MAX_RECV_DATA {
            Some("max_recv_data is too large")
        } else {
            None
        }
    } else {
        Some("expecting to use MPC-TLS")
    };

    if reject.is_some() {
        verifier.reject(reject).await?;
        return Err(anyhow::anyhow!("protocol configuration rejected"));
    }

    // Run the TLS commitment protocol.
    let verifier = verifier.accept().await?.run().await?;

    // Validate the proving request.
    let verifier = verifier.verify().await?;

    // Check that server identity is being proven.
    if !verifier.request().server_identity() {
        let verifier = verifier
            .reject(Some("expecting to verify the server name"))
            .await?;
        verifier.close().await?;
        return Err(anyhow::anyhow!("prover did not reveal the server name"));
    }

    // Check if predicates are requested and validate them.
    let predicates = verifier.request().predicates();
    if !predicates.is_empty() {
        println!(
            "Verifier: Prover requested {} predicate(s):",
            predicates.len()
        );
        for pred in predicates {
            println!(
                "  - '{}' on {:?} at {} indices",
                pred.name(),
                pred.direction(),
                pred.indices().len()
            );
        }
    }

    // Define the predicate resolver - this maps predicate names to predicates.
    // The resolver receives the predicate name and the indices from the prover's
    // request.
    let predicate_resolver = |name: &str, indices: &RangeSet<usize>| -> Option<Pred> {
        match name {
            JSON_STRING_PREDICATE => {
                // Build the JSON string validation predicate with the provided indices.
                Some(build_json_string_predicate(indices))
            }
            _ => None,
        }
    };

    // Accept with predicate verification.
    let (
        VerifierOutput {
            server_name,
            transcript,
            ..
        },
        verifier,
    ) = verifier
        .accept_with_predicates(Some(&predicate_resolver))
        .await?;

    verifier.close().await?;

    let server_name = server_name.expect("prover should have revealed server name");
    let transcript = transcript.expect("prover should have revealed transcript data");

    // Verify server name.
    let ServerName::Dns(server_name) = server_name;
    assert_eq!(server_name.as_str(), SERVER_DOMAIN);

    // The verifier can see the response but with the predicated string redacted.
    let received = transcript.received_unsafe();
    let redacted = String::from_utf8_lossy(received).replace('\0', "[REDACTED]");
    println!("Verifier: Received data (string value redacted):\n{redacted}");

    println!("Verifier: Predicate verified successfully!");
    println!("Verifier: The hidden value is proven to be a valid JSON string");

    Ok(())
}
```
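A note on the trust model the example demonstrates: only the predicate's name, direction, and byte indices cross the wire (see `PredicateRequest` above), so the verifier rebuilds the predicate logic locally from the out-of-band-agreed name rather than accepting a circuit from the prover. A prover therefore cannot substitute a weaker predicate: the verifier's resolver returns `None` for any unrecognized name, which presumably causes the request to be rejected.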
Harness `CmdOutput`:

```diff
@@ -22,10 +22,7 @@ pub enum CmdOutput {
     GetTests(Vec<String>),
     Test(TestOutput),
     Bench(BenchOutput),
-    #[cfg(target_arch = "wasm32")]
-    Fail {
-        reason: Option<String>,
-    },
+    Fail { reason: Option<String> },
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
```
Harness plot tool manifest:

```diff
@@ -7,9 +7,10 @@ publish = false
 [dependencies]
 tlsn-harness-core = { workspace = true }
 # tlsn-server-fixture = { workspace = true }
-charming = { version = "0.6.0", features = ["ssr"] }
+charming = { version = "0.5.1", features = ["ssr"] }
+csv = "1.3.0"
 clap = { workspace = true, features = ["derive", "env"] }
-polars = { version = "0.44", features = ["csv", "lazy"] }
+itertools = "0.14.0"
 toml = { workspace = true }
```
Deleted file (111 lines): the tlsn-harness-plot README. Its full content follows.

# TLSNotary Benchmark Plot Tool

Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).

## Usage

```bash
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
```

### Arguments

- `<TOML>` - Path to Bench.toml file defining benchmark structure
- `<CSV>...` - One or more CSV files with benchmark results

### Options

- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
  - If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
  - Number of labels must match number of CSV files
- `--min-max-band` - Add min/max bands to plots showing variance
- `-h, --help` - Print help information

## Examples

### Single Dataset

```bash
tlsn-harness-plot bench.toml results.csv
```

Generates plots from a single benchmark run.

### Compare Two Runs

```bash
tlsn-harness-plot bench.toml before.csv after.csv \
    --labels "Before Optimization" "After Optimization"
```

Overlays two datasets to compare performance improvements.

### Multiple Datasets

```bash
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
    --labels "Native" "Browser" "WASM"
```

Compare three different runtime environments.

### With Min/Max Bands

```bash
tlsn-harness-plot bench.toml run1.csv run2.csv \
    --labels "Config A" "Config B" \
    --min-max-band
```

Shows variance ranges for each dataset.

## Output Files

The tool generates two files per benchmark group:

- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
- `<output>.svg` - Static SVG image for documentation

Default output filenames:
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group

## Plot Format

Each dataset displays:
- **Solid line** - Total runtime (preprocessing + online phase)
- **Dashed line** - Online phase only
- **Shaded area** (optional) - Min/max variance bands

Different datasets automatically use distinct colors for easy comparison.

## CSV Format

Expected columns in each CSV file:
- `group` - Benchmark group name (must match TOML)
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
- `latency` - Network latency in ms (for latency plots)
- `time_preprocess` - Preprocessing time in ms
- `time_online` - Online phase time in ms
- `time_total` - Total runtime in ms

## TOML Format

The benchmark TOML file defines groups with either:

```toml
[[group]]
name = "my_benchmark"
protocol_latency = 50 # Fixed latency for bandwidth plots
# OR
bandwidth = 10000 # Fixed bandwidth for latency plots
```

All datasets must use the same TOML file to ensure consistent benchmark structure.

## Tips

- Use descriptive labels to make plots self-documenting
- Keep CSV files from the same benchmark configuration for valid comparisons
- Min/max bands are useful for showing stability but can clutter plots with many datasets
- Interactive HTML plots support zooming and hovering for detailed values
Harness plot tool `main.rs` (rewritten from a polars/DataFrame pipeline to plain csv + itertools):

```diff
@@ -1,18 +1,17 @@
 use std::f32;

 use charming::{
-    Chart, HtmlRenderer, ImageRenderer,
+    Chart, HtmlRenderer,
     component::{Axis, Legend, Title},
-    element::{
-        AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
-        Trigger,
-    },
+    element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
     series::Line,
     theme::Theme,
 };
 use clap::Parser;
-use harness_core::bench::BenchItems;
-use polars::prelude::*;
+use harness_core::bench::{BenchItems, Measurement};
+use itertools::Itertools;

+const THEME: Theme = Theme::Default;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about)]
```
```diff
@@ -20,131 +19,72 @@ struct Cli {
     /// Path to the Bench.toml file with benchmark spec
     toml: String,

-    /// Paths to CSV files with benchmark results (one or more)
-    csv: Vec<String>,
+    /// Path to the CSV file with benchmark results
+    csv: String,

-    /// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
-    #[arg(short, long, num_args = 0..)]
-    labels: Vec<String>,
+    /// Prover kind: native or browser
+    #[arg(short, long, value_enum, default_value = "native")]
+    prover_kind: ProverKind,

     /// Add min/max bands to plots
     #[arg(long, default_value_t = false)]
     min_max_band: bool,
 }

+#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum ProverKind {
+    Native,
+    Browser,
+}
+
+impl std::fmt::Display for ProverKind {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ProverKind::Native => write!(f, "Native"),
+            ProverKind::Browser => write!(f, "Browser"),
+        }
+    }
+}
+
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     let cli = Cli::parse();

-    if cli.csv.is_empty() {
-        return Err("At least one CSV file must be provided".into());
-    }
-
-    // Generate labels if not provided
-    let labels: Vec<String> = if cli.labels.is_empty() {
-        cli.csv
-            .iter()
-            .enumerate()
-            .map(|(i, _)| format!("Dataset {}", i + 1))
-            .collect()
-    } else if cli.labels.len() != cli.csv.len() {
-        return Err(format!(
-            "Number of labels ({}) must match number of CSV files ({})",
-            cli.labels.len(),
-            cli.csv.len()
-        )
-        .into());
-    } else {
-        cli.labels.clone()
-    };
-
-    // Load all CSVs and add dataset label
-    let mut dfs = Vec::new();
-    for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
-        let mut df = CsvReadOptions::default()
-            .try_into_reader_with_file_path(Some(csv_path.clone().into()))?
-            .finish()?;
-
-        let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
-        df.with_column(label_series)?;
-        dfs.push(df);
-    }
-
-    // Combine all dataframes
-    let df = dfs
-        .into_iter()
-        .reduce(|acc, df| acc.vstack(&df).unwrap())
-        .unwrap();
+    let mut rdr = csv::Reader::from_path(&cli.csv)?;

     let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
     let groups = items.group;

-    for group in groups {
-        // Determine which field varies in benches for this group
-        let benches_in_group: Vec<_> = items
-            .bench
-            .iter()
-            .filter(|b| b.group.as_deref() == Some(&group.name))
-            .collect();
-
-        if benches_in_group.is_empty() {
-            continue;
-        }
-
-        // Check which field has varying values
-        let bandwidth_varies = benches_in_group
-            .windows(2)
-            .any(|w| w[0].bandwidth != w[1].bandwidth);
-        let latency_varies = benches_in_group
-            .windows(2)
-            .any(|w| w[0].protocol_latency != w[1].protocol_latency);
-        let download_size_varies = benches_in_group
-            .windows(2)
-            .any(|w| w[0].download_size != w[1].download_size);
-
-        if download_size_varies {
-            let upload_size = group.upload_size.unwrap_or(1024);
-            plot_runtime_vs(
-                &df,
-                &labels,
-                cli.min_max_band,
-                &group.name,
-                "download_size",
-                1.0 / 1024.0, // bytes to KB
-                "Runtime vs Response Size",
-                format!("{} bytes upload size", upload_size),
-                "runtime_vs_download_size",
-                "Response Size (KB)",
-                true, // legend on left
-            )?;
-        } else if bandwidth_varies {
-            let latency = group.protocol_latency.unwrap_or(50);
-            plot_runtime_vs(
-                &df,
-                &labels,
-                cli.min_max_band,
-                &group.name,
-                "bandwidth",
-                1.0 / 1000.0, // Kbps to Mbps
-                "Runtime vs Bandwidth",
-                format!("{} ms Latency", latency),
-                "runtime_vs_bandwidth",
-                "Bandwidth (Mbps)",
-                false, // legend on right
-            )?;
-        } else if latency_varies {
-            let bandwidth = group.bandwidth.unwrap_or(1000);
-            plot_runtime_vs(
-                &df,
-                &labels,
-                cli.min_max_band,
-                &group.name,
-                "latency",
-                1.0,
-                "Runtime vs Latency",
-                format!("{} bps bandwidth", bandwidth),
-                "runtime_vs_latency",
-                "Latency (ms)",
-                true, // legend on left
-            )?;
-        }
-    }
+    // Prepare data for plotting.
+    let all_data: Vec<Measurement> = rdr
+        .deserialize::<Measurement>()
+        .collect::<Result<Vec<_>, _>>()?;
+
+    for group in groups {
+        if group.protocol_latency.is_some() {
+            let latency = group.protocol_latency.unwrap();
+            plot_runtime_vs(
+                &all_data,
+                cli.min_max_band,
+                &group.name,
+                |r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
+                "Runtime vs Bandwidth",
+                format!("{} ms Latency, {} mode", latency, cli.prover_kind),
+                "runtime_vs_bandwidth.html",
+                "Bandwidth (Mbps)",
+            )?;
+        }
+
+        if group.bandwidth.is_some() {
+            let bandwidth = group.bandwidth.unwrap();
+            plot_runtime_vs(
+                &all_data,
+                cli.min_max_band,
+                &group.name,
+                |r| r.latency as f32,
+                "Runtime vs Latency",
+                format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
+                "runtime_vs_latency.html",
+                "Latency (ms)",
+            )?;
+        }
+    }
```
```diff
@@ -152,52 +92,84 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     Ok(())
 }

+struct DataPoint {
+    min: f32,
+    mean: f32,
+    max: f32,
+}
+
+struct Points {
+    preprocess: DataPoint,
+    online: DataPoint,
+    total: DataPoint,
+}
+
 #[allow(clippy::too_many_arguments)]
-fn plot_runtime_vs(
-    df: &DataFrame,
-    labels: &[String],
+fn plot_runtime_vs<Fx>(
+    all_data: &[Measurement],
     show_min_max: bool,
     group: &str,
-    x_col: &str,
-    x_scale: f32,
+    x_value: Fx,
     title: &str,
     subtitle: String,
     output_file: &str,
     x_axis_label: &str,
-    legend_left: bool,
-) -> Result<Chart, Box<dyn std::error::Error>> {
-    let stats_df = df
-        .clone()
-        .lazy()
-        .filter(col("group").eq(lit(group)))
-        .with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
-        .with_columns([
-            (col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
-            (col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
-            (col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
-        ])
-        .group_by([col("x"), col("dataset_label")])
-        .agg([
-            col("preprocess").min().alias("preprocess_min"),
-            col("preprocess").mean().alias("preprocess_mean"),
-            col("preprocess").max().alias("preprocess_max"),
-            col("online").min().alias("online_min"),
-            col("online").mean().alias("online_mean"),
-            col("online").max().alias("online_max"),
-            col("total").min().alias("total_min"),
-            col("total").mean().alias("total_mean"),
-            col("total").max().alias("total_max"),
-        ])
-        .sort(["dataset_label", "x"], Default::default())
-        .collect()?;
-
-    // Build legend entries
-    let mut legend_data = Vec::new();
-    for label in labels {
-        legend_data.push(format!("Total Mean ({})", label));
-        legend_data.push(format!("Online Mean ({})", label));
-    }
+) -> Result<Chart, Box<dyn std::error::Error>>
+where
+    Fx: Fn(&Measurement) -> f32,
+{
+    fn data_point(values: &[f32]) -> DataPoint {
+        let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
+        let max = values.iter().copied().reduce(f32::max).unwrap_or_default();
+        let min = values.iter().copied().reduce(f32::min).unwrap_or_default();
+        DataPoint { min, mean, max }
+    }
+
+    let stats: Vec<(f32, Points)> = all_data
+        .iter()
+        .filter(|r| r.group.as_deref() == Some(group))
+        .map(|r| {
+            (
+                x_value(r),
+                r.time_preprocess as f32 / 1000.0, // ms to s
+                r.time_online as f32 / 1000.0,
+                r.time_total as f32 / 1000.0,
+            )
+        })
+        .sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
+        .chunk_by(|entry| entry.0)
+        .into_iter()
+        .map(|(x, group)| {
+            let group_vec: Vec<_> = group.collect();
+            let preprocess = data_point(
+                &group_vec
+                    .iter()
+                    .map(|(_, t, _, _)| *t)
+                    .collect::<Vec<f32>>(),
+            );
+            let online = data_point(
+                &group_vec
+                    .iter()
+                    .map(|(_, _, t, _)| *t)
+                    .collect::<Vec<f32>>(),
+            );
+            let total = data_point(
+                &group_vec
+                    .iter()
+                    .map(|(_, _, _, t)| *t)
+                    .collect::<Vec<f32>>(),
+            );
+            (
+                x,
+                Points {
+                    preprocess,
+                    online,
+                    total,
+                },
+            )
+        })
+        .collect();

     let mut chart = Chart::new()
         .title(
             Title::new()
```
```diff
@@ -207,6 +179,14 @@ fn plot_runtime_vs(
                 .subtext_style(TextStyle::new().font_size(16)),
         )
         .tooltip(Tooltip::new().trigger(Trigger::Axis))
+        .legend(
+            Legend::new()
+                .data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
+                .top("80")
+                .right("110")
+                .orient(Orient::Vertical)
+                .item_gap(10),
+        )
         .x_axis(
             Axis::new()
                 .name(x_axis_label)
```
```diff
@@ -225,156 +205,73 @@ fn plot_runtime_vs(
                 .name_text_style(TextStyle::new().font_size(21)),
     );

-    // Add legend with conditional positioning
-    let legend = Legend::new()
-        .data(legend_data)
-        .top("80")
-        .orient(Orient::Vertical)
-        .item_gap(10);
-
-    let legend = if legend_left {
-        legend.left("110")
-    } else {
-        legend.right("110")
-    };
-
-    chart = chart.legend(legend);
-
-    // Define colors for each dataset
-    let colors = vec![
-        "#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
-    ];
-
-    for (idx, label) in labels.iter().enumerate() {
-        let color = colors.get(idx % colors.len()).unwrap();
-
-        // Total time - solid line
-        chart = add_dataset_series(
-            &chart,
-            &stats_df,
-            label,
-            &format!("Total Mean ({})", label),
-            "total_mean",
-            false,
-            color,
-        )?;
-
-        // Online time - dashed line (same color as total)
-        chart = add_dataset_series(
-            &chart,
-            &stats_df,
-            label,
-            &format!("Online Mean ({})", label),
-            "online_mean",
-            true,
-            color,
-        )?;
-
-        if show_min_max {
-            chart = add_dataset_min_max_band(
-                &chart,
-                &stats_df,
-                label,
-                &format!("Total Min/Max ({})", label),
-                "total",
-                color,
-            )?;
-        }
-    }
-
-    // Save the chart as HTML file (no theme)
+    chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
+    chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
+    chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
+
+    if show_min_max {
+        chart = add_min_max_band(
+            chart,
+            &stats,
+            "Preprocess Min/Max",
+            |p| &p.preprocess,
+            "#ccc",
+        );
+        chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
+        chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
+    }
+
+    // Save the chart as HTML file.
     HtmlRenderer::new(title, 1000, 800)
-        .save(&chart, &format!("{}.html", output_file))
-        .unwrap();
-
-    // Save SVG with default theme
-    ImageRenderer::new(1000, 800)
-        .theme(Theme::Default)
-        .save(&chart, &format!("{}.svg", output_file))
-        .unwrap();
-
-    // Save SVG with dark theme
-    ImageRenderer::new(1000, 800)
-        .theme(Theme::Dark)
-        .save(&chart, &format!("{}_dark.svg", output_file))
+        .theme(THEME)
+        .save(&chart, output_file)
         .unwrap();

     Ok(chart)
 }

-fn add_dataset_series(
-    chart: &Chart,
-    df: &DataFrame,
-    dataset_label: &str,
-    series_name: &str,
-    col_name: &str,
-    dashed: bool,
-    color: &str,
-) -> Result<Chart, Box<dyn std::error::Error>> {
-    // Filter for specific dataset
-    let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
-    let filtered = df.filter(&mask)?;
-
-    let x = filtered.column("x")?.f32()?;
-    let y = filtered.column(col_name)?.f32()?;
-
-    let data: Vec<Vec<f32>> = x
-        .into_iter()
-        .zip(y.into_iter())
-        .filter_map(|(x, y)| Some(vec![x?, y?]))
-        .collect();
-
-    let mut line = Line::new()
-        .name(series_name)
-        .data(data)
-        .symbol_size(6)
-        .item_style(ItemStyle::new().color(color));
-
-    let mut line_style = LineStyle::new();
-    if dashed {
-        line_style = line_style.type_(LineStyleType::Dashed);
-    }
-    line = line.line_style(line_style.color(color));
-
-    Ok(chart.clone().series(line))
-}
-
-fn add_dataset_min_max_band(
-    chart: &Chart,
-    df: &DataFrame,
-    dataset_label: &str,
-    name: &str,
-    col_prefix: &str,
-    color: &str,
-) -> Result<Chart, Box<dyn std::error::Error>> {
-    // Filter for specific dataset
-    let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
-    let filtered = df.filter(&mask)?;
-
-    let x = filtered.column("x")?.f32()?;
-    let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
-    let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
-
-    let max_data: Vec<Vec<f32>> = x
-        .into_iter()
-        .zip(max_col.into_iter())
-        .filter_map(|(x, y)| Some(vec![x?, y?]))
-        .collect();
-
-    let min_data: Vec<Vec<f32>> = x
-        .into_iter()
-        .zip(min_col.into_iter())
-        .filter_map(|(x, y)| Some(vec![x?, y?]))
-        .rev()
-        .collect();
-
-    let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
-
-    Ok(chart.clone().series(
-        Line::new()
-            .name(name)
-            .data(data)
-            .show_symbol(false)
-            .line_style(LineStyle::new().opacity(0.0))
-            .area_style(AreaStyle::new().opacity(0.3).color(color)),
-    ))
-}
+fn add_mean_series(
+    chart: Chart,
+    stats: &[(f32, Points)],
+    name: &str,
+    extract: impl Fn(&Points) -> f32,
+) -> Chart {
+    chart.series(
+        Line::new()
+            .name(name)
+            .data(
+                stats
+                    .iter()
+                    .map(|(x, points)| vec![*x, extract(points)])
+                    .collect(),
+            )
+            .symbol_size(6),
+    )
+}
+
+fn add_min_max_band(
+    chart: Chart,
+    stats: &[(f32, Points)],
+    name: &str,
+    extract: impl Fn(&Points) -> &DataPoint,
+    color: &str,
+) -> Chart {
+    chart.series(
+        Line::new()
+            .name(name)
+            .data(
+                stats
+                    .iter()
+                    .map(|(x, points)| vec![*x, extract(points).max])
+                    .chain(
+                        stats
+                            .iter()
+                            .rev()
+                            .map(|(x, points)| vec![*x, extract(points).min]),
+                    )
+                    .collect(),
+            )
+            .show_symbol(false)
+            .line_style(LineStyle::new().opacity(0.0))
+            .area_style(AreaStyle::new().opacity(0.3).color(color)),
+    )
+}
```
Three further file diffs suppressed because one or more lines are too long.
Deleted benchmark spec: bandwidth (25 lines).

```toml
#### Bandwidth ####

[[group]]
name = "bandwidth"
protocol_latency = 25

[[bench]]
group = "bandwidth"
bandwidth = 10

[[bench]]
group = "bandwidth"
bandwidth = 50

[[bench]]
group = "bandwidth"
bandwidth = 100

[[bench]]
group = "bandwidth"
bandwidth = 250

[[bench]]
group = "bandwidth"
bandwidth = 1000
```
Deleted benchmark spec: download_size (37 lines).

```toml
[[group]]
name = "download_size"
protocol_latency = 10
bandwidth = 200
upload-size = 2048

[[bench]]
group = "download_size"
download-size = 1024

[[bench]]
group = "download_size"
download-size = 2048

[[bench]]
group = "download_size"
download-size = 4096

[[bench]]
group = "download_size"
download-size = 8192

[[bench]]
group = "download_size"
download-size = 16384

[[bench]]
group = "download_size"
download-size = 32768

[[bench]]
group = "download_size"
download-size = 65536

[[bench]]
group = "download_size"
download-size = 131072
```
Deleted benchmark spec: latency (25 lines).

```toml
#### Latency ####

[[group]]
name = "latency"
bandwidth = 1000

[[bench]]
group = "latency"
protocol_latency = 10

[[bench]]
group = "latency"
protocol_latency = 25

[[bench]]
group = "latency"
protocol_latency = 50

[[bench]]
group = "latency"
protocol_latency = 100

[[bench]]
group = "latency"
protocol_latency = 200
```
TLS client backend:

```diff
@@ -24,7 +24,7 @@ use std::{
 };

 #[cfg(feature = "tracing")]
-use tracing::{debug, debug_span, trace, warn, Instrument};
+use tracing::{debug, debug_span, error, trace, warn, Instrument};

 use tls_client::ClientConnection;
```
Another manifest hunk:

```diff
@@ -33,6 +33,7 @@ web-spawn = { workspace = true, optional = true }
 mpz-circuits = { workspace = true, features = ["aes"] }
 mpz-common = { workspace = true }
 mpz-core = { workspace = true }
+mpz-predicate = { workspace = true }
 mpz-garble = { workspace = true }
 mpz-garble-core = { workspace = true }
 mpz-hash = { workspace = true }
```
Prover error type:

```diff
@@ -49,6 +49,13 @@ impl ProverError {
     {
         Self::new(ErrorKind::Commit, source)
     }
+
+    pub(crate) fn predicate<E>(source: E) -> Self
+    where
+        E: Into<Box<dyn Error + Send + Sync + 'static>>,
+    {
+        Self::new(ErrorKind::Predicate, source)
+    }
 }

 #[derive(Debug)]
@@ -58,6 +65,7 @@ enum ErrorKind {
     Zk,
     Config,
     Commit,
+    Predicate,
 }

 impl fmt::Display for ProverError {
@@ -70,6 +78,7 @@ impl fmt::Display for ProverError {
             ErrorKind::Zk => f.write_str("zk error")?,
             ErrorKind::Config => f.write_str("config error")?,
             ErrorKind::Commit => f.write_str("commit error")?,
+            ErrorKind::Predicate => f.write_str("predicate error")?,
         }

         if let Some(source) = &self.source {
```
@@ -20,6 +20,7 @@ use crate::{
|
||||
encoding::{self, MacStore},
|
||||
hash::prove_hash,
|
||||
},
|
||||
predicate::prove_predicates,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -54,6 +55,20 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
|
||||
});
|
||||
}
|
||||
|
||||
// Build predicate ranges from config
|
||||
let (mut predicate_sent, mut predicate_recv) = (RangeSet::default(), RangeSet::default());
|
||||
for predicate in config.predicates() {
|
||||
let indices: RangeSet<usize> = predicate
|
||||
.indices()
|
||||
.into_iter()
|
||||
.map(|idx| idx..idx + 1)
|
||||
.collect();
|
||||
match predicate.direction() {
|
||||
Direction::Sent => predicate_sent.union_mut(&indices),
|
||||
Direction::Received => predicate_recv.union_mut(&indices),
|
||||
}
|
||||
}
|
||||
|
||||
let transcript_refs = TranscriptRefs {
|
||||
sent: prove_plaintext(
|
||||
vm,
|
||||
@@ -66,6 +81,7 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
&reveal_sent,
|
||||
&commit_sent,
|
||||
&predicate_sent,
|
||||
)
|
||||
.map_err(ProverError::commit)?,
|
||||
recv: prove_plaintext(
|
||||
@@ -79,6 +95,7 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
|
||||
.filter(|record| record.typ == ContentType::ApplicationData),
|
||||
&reveal_recv,
|
||||
&commit_recv,
|
||||
&predicate_recv,
|
||||
)
|
||||
.map_err(ProverError::commit)?,
|
||||
};
|
||||
@@ -100,6 +117,12 @@ pub(crate) async fn prove<T: Vm<Binary> + MacStore + Send + Sync>(
        None
    };

    // Prove predicates over transcript data
    if !config.predicates().is_empty() {
        prove_predicates(vm, &transcript_refs, config.predicates())
            .map_err(ProverError::predicate)?;
    }

    vm.execute_all(ctx).await.map_err(ProverError::zk)?;

    if let Some(commit_config) = config.transcript_commit()

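The prover drives the path above through `ProveConfig`; the integration test at the bottom of this diff exercises it end to end. A condensed sketch of the prover-side call sequence (builder methods taken from that test; the surrounding `prover` setup and module paths are assumptions):

use mpz_predicate::eq;
use tlsn_core::transcript::Direction;

// Sketch only: `prover` is assumed to have finished its MPC-TLS session,
// as in `prover_with_predicate` in the tests below.
let mut builder = ProveConfig::builder(prover.transcript());
builder.server_identity();
builder.reveal_sent(&(0..10))?;
// Prove in ZK that sent byte 10 equals b'/' without revealing it; the name
// is an out-of-band label the verifier's resolver must recognize.
builder.predicate("test_first_byte", Direction::Sent, eq(10, b'/'))?;
prover.prove(&builder.build()?).await?;
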
@@ -1,5 +1,6 @@
pub(crate) mod auth;
pub(crate) mod commit;
pub(crate) mod predicate;

use mpz_memory_core::{Vector, binary::U8};

@@ -25,14 +25,15 @@ pub(crate) fn prove_plaintext<'a>(
    records: impl IntoIterator<Item = &'a Record>,
    reveal: &RangeSet<usize>,
    commit: &RangeSet<usize>,
    predicate: &RangeSet<usize>,
) -> Result<ReferenceMap, PlaintextAuthError> {
    let is_reveal_all = reveal == (0..plaintext.len());
    let is_reveal_all = reveal == (0..plaintext.len()) && predicate.is_empty();

    let alloc_ranges = if is_reveal_all {
        commit.clone()
    } else {
        // The plaintext is only partially revealed, so we need to authenticate in ZK.
        commit.union(reveal).into_set()
        commit.union(reveal).union(predicate).into_set()
    };

    let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +50,8 @@ pub(crate) fn prove_plaintext<'a>(
            vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
        }
    } else {
        let private = commit.difference(reveal).into_set();
        // Private ranges: committed but not revealed, plus predicate ranges
        let private = commit.difference(reveal).union(predicate).into_set();
        for (_, slice) in plaintext_refs
            .index(&private)
            .expect("all ranges are allocated")

@@ -91,14 +93,15 @@ pub(crate) fn verify_plaintext<'a>(
    records: impl IntoIterator<Item = &'a Record>,
    reveal: &RangeSet<usize>,
    commit: &RangeSet<usize>,
    predicate: &RangeSet<usize>,
) -> Result<(ReferenceMap, PlaintextProof<'a>), PlaintextAuthError> {
    let is_reveal_all = reveal == (0..plaintext.len());
    let is_reveal_all = reveal == (0..plaintext.len()) && predicate.is_empty();

    let alloc_ranges = if is_reveal_all {
        commit.clone()
    } else {
        // The plaintext is only partially revealed, so we need to authenticate in ZK.
        commit.union(reveal).into_set()
        commit.union(reveal).union(predicate).into_set()
    };

    let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,9 +126,10 @@ pub(crate) fn verify_plaintext<'a>(
            ciphertext,
        })
    } else {
        let private = commit.difference(reveal).into_set();
        // Blind ranges: committed but not revealed, plus predicate ranges
        let blind = commit.difference(reveal).union(predicate).into_set();
        for (_, slice) in plaintext_refs
            .index(&private)
            .index(&blind)
            .expect("all ranges are allocated")
            .iter()
        {

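The range bookkeeping above is the heart of the change: predicate bytes are allocated and authenticated in ZK like committed-but-unrevealed bytes, but they are never decoded to the verifier. A standalone sketch of the set arithmetic (the method chains mirror this diff; exact `rangeset` signatures are assumptions):

use rangeset::set::RangeSet;

fn main() {
    // Hypothetical 20-byte transcript: commit to all of it, reveal the
    // first 10 bytes, and run a predicate over byte 10.
    let commit: RangeSet<usize> = [0usize..20].into_iter().collect();
    let reveal: RangeSet<usize> = [0usize..10].into_iter().collect();
    let predicate: RangeSet<usize> = [10usize..11].into_iter().collect();

    // Everything that must be authenticated in ZK.
    let alloc = commit.union(&reveal).union(&predicate).into_set();
    assert_eq!(alloc, [0usize..20].into_iter().collect::<RangeSet<usize>>());

    // Bytes that stay hidden: committed but not revealed, plus predicate bytes.
    let private = commit.difference(&reveal).union(&predicate).into_set();
    assert_eq!(private, [10usize..20].into_iter().collect::<RangeSet<usize>>());
}
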
175
crates/tlsn/src/transcript_internal/predicate.rs
Normal file
@@ -0,0 +1,175 @@
//! Predicate proving and verification over transcript data.

use std::sync::Arc;

use mpz_circuits::Circuit;
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
    DecodeFutureTyped, MemoryExt,
    binary::{Binary, Bool},
};
use mpz_predicate::{Pred, compiler::Compiler};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::set::RangeSet;
use tlsn_core::{config::prove::PredicateConfig, transcript::Direction};

use super::{ReferenceMap, TranscriptRefs};

/// Error during predicate proving/verification.
#[derive(Debug, thiserror::Error)]
pub(crate) enum PredicateError {
    /// Indices not found in transcript references.
    #[error("predicate indices {0:?} not found in transcript references")]
    IndicesNotFound(RangeSet<usize>),
    /// VM error.
    #[error("VM error: {0}")]
    Vm(#[from] mpz_vm_core::VmError),
    /// Circuit call error.
    #[error("circuit call error: {0}")]
    Call(#[from] mpz_vm_core::CallError),
    /// Decode error.
    #[error("decode error: {0}")]
    Decode(#[from] mpz_memory_core::DecodeError),
    /// Missing decoding.
    #[error("missing decoding")]
    MissingDecoding,
    /// Predicate not satisfied.
    #[error("predicate evaluated to false")]
    PredicateNotSatisfied,
}

/// Converts a slice of indices to a RangeSet (each index becomes a single-byte
/// range).
fn indices_to_rangeset(indices: &[usize]) -> RangeSet<usize> {
    indices.iter().map(|&idx| idx..idx + 1).collect()
}

/// Proves predicates over transcript data (prover side).
///
/// Each predicate is compiled to a circuit and executed with the corresponding
/// transcript bytes as input. The circuit outputs a single bit that must be
/// true.
pub(crate) fn prove_predicates<T: Vm<Binary>>(
    vm: &mut T,
    transcript_refs: &TranscriptRefs,
    predicates: &[PredicateConfig],
) -> Result<(), PredicateError> {
    let mut compiler = Compiler::new();

    for predicate in predicates {
        let refs = match predicate.direction() {
            Direction::Sent => &transcript_refs.sent,
            Direction::Received => &transcript_refs.recv,
        };

        // Compile predicate to circuit
        let circuit = compiler.compile(predicate.predicate());

        // Get indices from the predicate and convert to RangeSet
        let indices = indices_to_rangeset(&predicate.indices());

        // Prover doesn't need to verify the output - they know their data satisfies the predicate
        let _ = execute_predicate(vm, refs, &indices, &circuit)?;
    }

    Ok(())
}

/// Proof that predicates were satisfied.
///
/// Must be verified after `vm.execute_all()` completes.
#[must_use]
pub(crate) struct PredicateProof {
    /// Decode futures for each predicate output.
    outputs: Vec<DecodeFutureTyped<BitVec, bool>>,
}

impl PredicateProof {
    /// Verifies that all predicates evaluated to true.
    ///
    /// Must be called after `vm.execute_all()` completes.
    pub(crate) fn verify(self) -> Result<(), PredicateError> {
        for mut output in self.outputs {
            let result = output
                .try_recv()
                .map_err(PredicateError::Decode)?
                .ok_or(PredicateError::MissingDecoding)?;

            if !result {
                return Err(PredicateError::PredicateNotSatisfied);
            }
        }
        Ok(())
    }
}

/// Verifies predicates over transcript data (verifier side).
///
/// The verifier must provide the same predicates that the prover used,
/// looked up by predicate name from out-of-band agreement.
///
/// Returns a [`PredicateProof`] that must be verified after `vm.execute_all()`.
///
/// # Arguments
///
/// * `vm` - The zkVM.
/// * `transcript_refs` - References to transcript data in the VM.
/// * `predicates` - Iterator of (direction, indices, predicate) tuples.
pub(crate) fn verify_predicates<T: Vm<Binary>>(
    vm: &mut T,
    transcript_refs: &TranscriptRefs,
    predicates: impl IntoIterator<Item = (Direction, RangeSet<usize>, Pred)>,
) -> Result<PredicateProof, PredicateError> {
    let mut compiler = Compiler::new();
    let mut outputs = Vec::new();

    for (direction, indices, predicate) in predicates {
        let refs = match direction {
            Direction::Sent => &transcript_refs.sent,
            Direction::Received => &transcript_refs.recv,
        };

        // Compile predicate to circuit
        let circuit = compiler.compile(&predicate);

        let output_fut = execute_predicate(vm, refs, &indices, &circuit)?;
        outputs.push(output_fut);
    }

    Ok(PredicateProof { outputs })
}

/// Executes a predicate circuit with transcript bytes as input.
///
/// Returns a decode future for the circuit output.
fn execute_predicate<T: Vm<Binary>>(
    vm: &mut T,
    refs: &ReferenceMap,
    indices: &RangeSet<usize>,
    circuit: &Circuit,
) -> Result<DecodeFutureTyped<BitVec, bool>, PredicateError> {
    // Get the transcript bytes for the predicate indices
    let indexed_refs = refs
        .index(indices)
        .ok_or_else(|| PredicateError::IndicesNotFound(indices.clone()))?;

    // Build the circuit call with transcript bytes as inputs
    let circuit = Arc::new(circuit.clone());
    let mut call_builder = Call::builder(circuit);

    // Add each byte in the range as an input to the circuit.
    // The predicate circuit expects bytes in order, so we iterate through
    // the indexed refs, which maintain ordering.
    for (_range, vector) in indexed_refs.iter() {
        call_builder = call_builder.arg(*vector);
    }

    let call = call_builder.build()?;

    // Execute the circuit - output is a single bit (true/false).
    // Both parties must call decode() on the output to reveal it.
    let output: Bool = vm.call(call)?;

    // Return decode future - caller must verify output == true after execute_all
    Ok(vm.decode(output)?)
}

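The module above splits predicate checking into two phases: circuit calls are queued while building the proof, the VM executes everything in one pass, and only then are the decoded outputs checked. A minimal sketch of the compile step on its own (the `mpz_predicate` API is assumed from its usage in this file and in the tests below):

use mpz_predicate::{compiler::Compiler, eq};

fn main() {
    // Build "byte at index 10 == b'/'" and compile it to a boolean circuit,
    // the same way prove_predicates and verify_predicates do above.
    let mut compiler = Compiler::new();
    let pred = eq(10, b'/');
    let circuit = compiler.compile(&pred);

    // The circuit consumes the selected transcript bytes and yields a single
    // Bool; the verifier checks that it decodes to `true` only after
    // `vm.execute_all()` has run.
    let _ = circuit;
}
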
@@ -8,6 +8,7 @@ use std::sync::Arc;

pub use error::VerifierError;
pub use tlsn_core::{VerifierOutput, webpki::ServerCertVerifier};
pub use verify::PredicateResolver;

use crate::{
    Role,
@@ -323,8 +324,24 @@ impl Verifier<state::Verify> {
    }

    /// Accepts the proving request.
    ///
    /// Note: If the prover requests predicate verification, use
    /// [`accept_with_predicates`](Self::accept_with_predicates) instead.
    pub async fn accept(
        self,
    ) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
        self.accept_with_predicates(None).await
    }

    /// Accepts the proving request with predicate verification support.
    ///
    /// # Arguments
    ///
    /// * `predicate_resolver` - A function that resolves predicate names to
    ///   predicates. Required if the prover requests any predicates.
    pub async fn accept_with_predicates(
        self,
        predicate_resolver: Option<&verify::PredicateResolver>,
    ) -> Result<(VerifierOutput, Verifier<state::Committed>), VerifierError> {
        let state::Verify {
            mux_ctrl,
@@ -353,6 +370,7 @@ impl Verifier<state::Verify> {
            request,
            handshake,
            transcript,
            predicate_resolver,
        ))
        .await?;

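Together with the `PredicateResolver` defined in verify.rs below, a condensed sketch of verifier-side usage (mirroring `verifier_with_predicate` in the tests at the end of this diff; the surrounding `verifier` setup is elided):

use mpz_predicate::{Pred, eq};
use rangeset::set::RangeSet;

// Sketch only: `verifier` is assumed to be in the `Verify` state. The
// resolver rebuilds the predicate on every lookup because `Pred` is
// Rc-based and cannot be shared across calls.
let resolver = |name: &str, _indices: &RangeSet<usize>| -> Option<Pred> {
    (name == "test_first_byte").then(|| eq(10, b'/'))
};

let (output, verifier) = verifier
    .accept_with_predicates(Some(&resolver))
    .await?;
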
@@ -49,6 +49,13 @@ impl VerifierError {
    {
        Self::new(ErrorKind::Verify, source)
    }

    pub(crate) fn predicate<E>(source: E) -> Self
    where
        E: Into<Box<dyn Error + Send + Sync + 'static>>,
    {
        Self::new(ErrorKind::Predicate, source)
    }
}

#[derive(Debug)]
@@ -59,6 +66,7 @@ enum ErrorKind {
    Zk,
    Commit,
    Verify,
    Predicate,
}

impl fmt::Display for VerifierError {
@@ -72,6 +80,7 @@ impl fmt::Display for VerifierError {
            ErrorKind::Zk => f.write_str("zk error")?,
            ErrorKind::Commit => f.write_str("commit error")?,
            ErrorKind::Verify => f.write_str("verification error")?,
            ErrorKind::Predicate => f.write_str("predicate error")?,
        }

        if let Some(source) = &self.source {

@@ -1,6 +1,7 @@
use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_predicate::Pred;
use mpz_vm_core::Vm;
use rangeset::set::RangeSet;
use tlsn_core::{
@@ -21,10 +22,24 @@ use crate::{
            encoding::{self, KeyStore},
            hash::verify_hash,
        },
        predicate::verify_predicates,
    },
    verifier::VerifierError,
};

/// A function that resolves predicate names to predicates.
///
/// The verifier must provide this to look up predicates by name,
/// based on out-of-band agreement with the prover.
///
/// The function receives:
/// - The predicate name
/// - The byte indices the predicate operates on (from the prover's request)
///
/// The verifier should validate that the indices make sense for the predicate
/// and return the appropriate predicate built with those indices.
pub type PredicateResolver = dyn Fn(&str, &RangeSet<usize>) -> Option<Pred> + Send + Sync;

#[allow(clippy::too_many_arguments)]
pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
    ctx: &mut Context,

@@ -35,6 +50,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
    request: ProveRequest,
    handshake: Option<(ServerName, HandshakeData)>,
    transcript: Option<PartialTranscript>,
    predicate_resolver: Option<&PredicateResolver>,
) -> Result<VerifierOutput, VerifierError> {
    let ciphertext_sent = collect_ciphertext(tls_transcript.sent());
    let ciphertext_recv = collect_ciphertext(tls_transcript.recv());

@@ -101,6 +117,15 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
        }
    }

    // Build predicate ranges from request
    let (mut predicate_sent, mut predicate_recv) = (RangeSet::default(), RangeSet::default());
    for predicate_req in request.predicates() {
        match predicate_req.direction() {
            Direction::Sent => predicate_sent.union_mut(predicate_req.indices()),
            Direction::Received => predicate_recv.union_mut(predicate_req.indices()),
        }
    }

    let (sent_refs, sent_proof) = verify_plaintext(
        vm,
        keys.client_write_key,
@@ -113,6 +138,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
            .filter(|record| record.typ == ContentType::ApplicationData),
        transcript.sent_authed(),
        &commit_sent,
        &predicate_sent,
    )
    .map_err(VerifierError::zk)?;
    let (recv_refs, recv_proof) = verify_plaintext(
@@ -127,6 +153,7 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
            .filter(|record| record.typ == ContentType::ApplicationData),
        transcript.received_authed(),
        &commit_recv,
        &predicate_recv,
    )
    .map_err(VerifierError::zk)?;

@@ -146,11 +173,38 @@ pub(crate) async fn verify<T: Vm<Binary> + KeyStore + Send + Sync>(
        );
    }

    // Verify predicates if any were requested
    let predicate_proof = if !request.predicates().is_empty() {
        let resolver = predicate_resolver.ok_or_else(|| {
            VerifierError::predicate("predicates requested but no resolver provided")
        })?;

        let predicates = request
            .predicates()
            .iter()
            .map(|req| {
                let predicate = resolver(req.name(), req.indices()).ok_or_else(|| {
                    VerifierError::predicate(format!("unknown predicate: {}", req.name()))
                })?;
                Ok((req.direction(), req.indices().clone(), predicate))
            })
            .collect::<Result<Vec<_>, VerifierError>>()?;

        Some(verify_predicates(vm, &transcript_refs, predicates).map_err(VerifierError::predicate)?)
    } else {
        None
    };

    vm.execute_all(ctx).await.map_err(VerifierError::zk)?;

    sent_proof.verify().map_err(VerifierError::verify)?;
    recv_proof.verify().map_err(VerifierError::verify)?;

    // Verify predicate outputs after ZK execution
    if let Some(proof) = predicate_proof {
        proof.verify().map_err(VerifierError::predicate)?;
    }

    let mut encoder_secret = None;
    if let Some(commit_config) = request.transcript_commit()
        && let Some((sent, recv)) = commit_config.encoding()

@@ -1,4 +1,5 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use mpz_predicate::{Pred, eq};
use rangeset::set::RangeSet;
use tlsn::{
    config::{
@@ -231,3 +232,184 @@ async fn verifier<T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(

    output
}

// Predicate name for testing
const TEST_PREDICATE: &str = "test_first_byte";

/// Test that a correct predicate passes verification.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_predicate_passes() {
    let (socket_0, socket_1) = tokio::io::duplex(2 << 23);

    // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/' (in "HTTP/1.1").
    // Using index 10 to avoid overlap with the revealed range (0..10).
    // Verifier uses the same predicate - should pass.
    let prover_predicate = eq(10, b'/');

    let (prover_result, verifier_result) = tokio::join!(
        prover_with_predicate(socket_0, prover_predicate),
        verifier_with_predicate(socket_1, || eq(10, b'/'))
    );

    prover_result.expect("prover should succeed");
    verifier_result.expect("verifier should succeed with correct predicate");
}

/// Test that a wrong predicate is rejected by the verifier.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_wrong_predicate_rejected() {
    let (socket_0, socket_1) = tokio::io::duplex(2 << 23);

    // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/'.
    // Verifier uses a DIFFERENT predicate that checks for 'X' - should fail.
    let prover_predicate = eq(10, b'/');

    let (prover_result, verifier_result) = tokio::join!(
        prover_with_predicate(socket_0, prover_predicate),
        verifier_with_predicate(socket_1, || eq(10, b'X'))
    );

    // Prover may succeed or fail depending on when the verifier rejects
    let _ = prover_result;

    // Verifier should fail because the predicate evaluates to false
    assert!(
        verifier_result.is_err(),
        "verifier should reject wrong predicate"
    );
}

/// Test that the prover can't prove a predicate their data doesn't satisfy.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
async fn test_unsatisfied_predicate_rejected() {
    let (socket_0, socket_1) = tokio::io::duplex(2 << 23);

    // Request is "GET / HTTP/1.1\r\n..." - index 10 is '/'.
    // Both parties use eq(10, b'X') but the prover's data has '/' at index 10.
    // This tests that a prover can't cheat - the predicate must actually be satisfied.
    let prover_predicate = eq(10, b'X');

    let (prover_result, verifier_result) = tokio::join!(
        prover_with_predicate(socket_0, prover_predicate),
        verifier_with_predicate(socket_1, || eq(10, b'X'))
    );

    // Prover may succeed or fail depending on when the verifier rejects
    let _ = prover_result;

    // Verifier should fail because the prover's data doesn't satisfy the predicate
    assert!(
        verifier_result.is_err(),
        "verifier should reject unsatisfied predicate"
    );
}

#[instrument(skip(verifier_socket, predicate))]
async fn prover_with_predicate<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
    verifier_socket: T,
    predicate: Pred,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let (client_socket, server_socket) = tokio::io::duplex(2 << 16);

    let server_task = tokio::spawn(bind(server_socket.compat()));

    let prover = Prover::new(ProverConfig::builder().build()?)
        .commit(
            TlsCommitConfig::builder()
                .protocol(
                    MpcTlsConfig::builder()
                        .max_sent_data(MAX_SENT_DATA)
                        .max_sent_records(MAX_SENT_RECORDS)
                        .max_recv_data(MAX_RECV_DATA)
                        .max_recv_records_online(MAX_RECV_RECORDS)
                        .build()?,
                )
                .build()?,
            verifier_socket.compat(),
        )
        .await?;

    let (mut tls_connection, prover_fut) = prover
        .connect(
            TlsClientConfig::builder()
                .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?))
                .root_store(RootCertStore {
                    roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
                })
                .build()?,
            client_socket.compat(),
        )
        .await?;
    let prover_task = tokio::spawn(prover_fut);

    tls_connection
        .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        .await?;
    tls_connection.close().await?;

    // read_to_end appends, so start from an empty buffer.
    let mut response = Vec::with_capacity(1024);
    tls_connection.read_to_end(&mut response).await?;

    let _ = server_task.await?;

    let mut prover = prover_task.await??;

    let mut builder = ProveConfig::builder(prover.transcript());
    builder.server_identity();
    builder.reveal_sent(&(0..10))?;
    builder.reveal_recv(&(0..10))?;
    builder.predicate(TEST_PREDICATE, Direction::Sent, predicate)?;

    let config = builder.build()?;
    prover.prove(&config).await?;
    prover.close().await?;

    Ok(())
}

async fn verifier_with_predicate<T, F>(
    socket: T,
    make_predicate: F,
) -> Result<VerifierOutput, Box<dyn std::error::Error + Send + Sync>>
where
    T: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static,
    F: Fn() -> Pred + Send + Sync + 'static,
{
    let verifier = Verifier::new(
        VerifierConfig::builder()
            .root_store(RootCertStore {
                roots: vec![CertificateDer(CA_CERT_DER.to_vec())],
            })
            .build()?,
    );

    let verifier = verifier
        .commit(socket.compat())
        .await?
        .accept()
        .await?
        .run()
        .await?;

    let verifier = verifier.verify().await?;

    // Resolver that builds the predicate fresh (Pred uses Rc, so it can't be shared)
    let predicate_resolver = move |name: &str, _indices: &RangeSet<usize>| -> Option<Pred> {
        if name == TEST_PREDICATE {
            Some(make_predicate())
        } else {
            None
        }
    };

    let (output, verifier) = verifier
        .accept_with_predicates(Some(&predicate_resolver))
        .await?;

    verifier.close().await?;

    Ok(output)
}