Compare commits

..

2 Commits

Author SHA1 Message Date
dan
a34355ef77 fix test 2025-11-19 09:17:38 +02:00
dan
bf6a23a61e convert pest ast to predicate ast 2025-11-19 08:54:21 +02:00
47 changed files with 2610 additions and 3032 deletions

View File

@@ -21,8 +21,7 @@ env:
# - https://github.com/privacy-ethereum/mpz/issues/178
# 32 seems to be big enough for the foreseeable future
RAYON_NUM_THREADS: 32
RUST_VERSION: 1.92.0
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RUST_VERSION: 1.90.0
jobs:
clippy:

2523
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -66,27 +66,26 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
rangeset = { version = "0.4" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
uid-mux = { version = "0.2" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
aead = { version = "0.4" }
aes = { version = "0.8" }

View File

@@ -15,7 +15,7 @@ use mpz_vm_core::{
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
Call, Callable, Execute, Vm, VmError,
};
use rangeset::{ops::Set, set::RangeSet};
use rangeset::{Difference, RangeSet, UnionMut};
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
type Error = DeapError;
@@ -210,12 +210,10 @@ where
}
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let slice_range = slice.to_range();
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower {
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
@@ -268,7 +266,7 @@ where
mpc.mark_private_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
}
@@ -284,7 +282,7 @@ where
mpc.mark_blind_raw(slice)?;
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(slice.to_range());
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_vm_core::{memory::Slice, VmError};
use rangeset::ops::Set;
use rangeset::Subset;
/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]

View File

@@ -27,6 +27,12 @@ tlsn-data-fixtures = { workspace = true, optional = true }
tlsn-tls-core = { workspace = true, features = ["serde"] }
tlsn-utils = { workspace = true }
rangeset = { workspace = true, features = ["serde"] }
mpz-circuits = { workspace = true }
pest = "*"
pest_derive = "*"
pest_meta = "*"
aead = { workspace = true, features = ["alloc"], optional = true }
aes-gcm = { workspace = true, optional = true }
@@ -59,7 +65,5 @@ generic-array = { workspace = true }
bincode = { workspace = true }
hex = { workspace = true }
rstest = { workspace = true }
tlsn-core = { workspace = true, features = ["fixtures"] }
tlsn-attestation = { workspace = true, features = ["fixtures"] }
tlsn-data-fixtures = { workspace = true }
webpki-root-certs = { workspace = true }

View File

@@ -1,6 +1,6 @@
//! Proving configuration.
use rangeset::set::{RangeSet, ToRangeSet};
use rangeset::{RangeSet, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};

View File

@@ -185,17 +185,22 @@ impl MpcTlsConfigBuilder {
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
/// Reduces network round-trips at the expense of consuming more network
/// bandwidth.
Bandwidth,
/// Reduces network bandwidth utilization at the expense of more network
/// round-trips.
#[default]
Latency,
}
impl Default for NetworkSetting {
fn default() -> Self {
Self::Latency
}
}
/// Error for [`MpcTlsConfig`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]

View File

@@ -1,11 +1,11 @@
use rangeset::set::RangeSet;
use rangeset::RangeSet;
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("{")?;
for range in self.0.iter() {
for range in self.0.iter_ranges() {
write!(f, "{}..{}", range.start, range.end)?;
if range.end < self.0.end().unwrap_or(0) {
f.write_str(", ")?;

View File

@@ -2,7 +2,6 @@
use aead::Payload as AeadPayload;
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
#[allow(deprecated)]
use generic_array::GenericArray;
use rand::{rngs::StdRng, Rng, SeedableRng};
use tls_core::msgs::{
@@ -181,7 +180,6 @@ fn aes_gcm_encrypt(
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&iv);
nonce[4..].copy_from_slice(&explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();

159
crates/core/src/grammar.rs Normal file
View File

@@ -0,0 +1,159 @@
use crate::predicates::Pred;
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pest_derive::Parser;
// Parser type generated by pest_derive from the grammar file `expr.pest`.
// The `Rule` enum used throughout this module is generated alongside it.
#[derive(Parser)]
#[grammar = "expr.pest"]
struct ExprParser;
/// Parses `input` into a predicate AST, returning a pest error on invalid
/// syntax.
fn parse_expr(input: &str) -> Result<Pred, pest::error::Error<Rule>> {
    ExprParser::parse(Rule::expr, input)
        .map(|mut pairs| build_expr(pairs.next().expect("expr rule yields a pair")))
}
/// Converts a pest parse-tree pair into a predicate AST node.
///
/// `expr`/`or_expr`/`and_expr` are folded left-associatively; wrapper rules
/// (`primary`, `paren`) are unwrapped recursively.
fn build_expr(pair: Pair<Rule>) -> Pred {
    match pair.as_rule() {
        Rule::expr | Rule::or_expr => build_left_assoc(pair.into_inner(), Rule::and_expr, Pred::Or),
        Rule::and_expr => build_left_assoc(pair.into_inner(), Rule::not_expr, Pred::And),
        Rule::not_expr => {
            // Grammar shape: NOT* cmp. Track negation parity and keep the
            // single comparison instead of collecting every pair into a Vec.
            let mut negate = false;
            let mut cmp = None;
            for p in pair.into_inner() {
                match p.as_rule() {
                    Rule::NOT => negate = !negate,
                    _ if cmp.is_none() => cmp = Some(p),
                    _ => {}
                }
            }
            let node = build_cmp(cmp.expect("cmp missing"));
            // An even number of NOTs cancels out.
            if negate {
                Pred::Not(Box::new(node))
            } else {
                node
            }
        }
        Rule::cmp => build_cmp(pair),
        Rule::primary => build_expr(pair.into_inner().next().unwrap()),
        Rule::paren => build_expr(pair.into_inner().next().unwrap()),
        _ => unreachable!("unexpected rule: {:?}", pair.as_rule()),
    }
}
/// Folds a left-associative `unit (OP unit)*` production into a single
/// predicate node.
///
/// Every operand must match `unit_rule`; the collected operands are wrapped
/// with `mk_node` (e.g. `Pred::And` / `Pred::Or`) unless there is only one,
/// in which case it is returned unwrapped.
fn build_left_assoc(
    mut inner: Pairs<Rule>,
    unit_rule: Rule,
    mk_node: impl Fn(Vec<Pred>) -> Pred,
) -> Pred {
    let mut nodes = Vec::new();
    // First unit.
    if let Some(first) = inner.next() {
        assert_eq!(first.as_rule(), unit_rule);
        nodes.push(build_expr(first));
    }
    // Remaining pairs alternate operator/operand; the operator token is
    // discarded (its identity is already fixed by `mk_node`).
    while inner.next().is_some() {
        if let Some(unit) = inner.next() {
            assert_eq!(unit.as_rule(), unit_rule);
            nodes.push(build_expr(unit));
        }
    }
    if nodes.len() == 1 {
        nodes.pop().unwrap()
    } else {
        mk_node(nodes)
    }
}
fn build_cmp(pair: Pair<Rule>) -> Pred {
// cmp: primary (cmp_op primary)?
let mut inner = pair.into_inner();
let lhs = inner.next().unwrap();
let lhs_term = parse_term(lhs);
if let Some(op_pair) = inner.next() {
let op = match op_pair.as_str() {
"==" => CmpOp::Eq,
"!=" => CmpOp::Ne,
"<" => CmpOp::Lt,
"<=" => CmpOp::Lte,
">" => CmpOp::Gt,
">=" => CmpOp::Gte,
_ => unreachable!(),
};
let rhs = parse_term(inner.next().unwrap());
// Map to your Atom constraint form (LHS must be x[idx]):
let (index, rhs_val) = match (lhs_term, rhs) {
(Term::Idx(i), Term::Const(c)) => (i, Rhs::Const(c)),
(Term::Idx(i1), Term::Idx(i2)) => (i1, Rhs::Idx(i2)),
// If you want to allow const OP idx or const OP const, handle here (flip, etc.)
other => panic!("unsupported comparison pattern: {:?}", other),
};
Pred::Atom(Atom {
index,
op,
rhs: rhs_val,
})
} else {
// A bare primary is treated as a boolean atom; you can decide policy.
// Here we treat "x[i]" as (x[i] != 0) and const as (const != 0).
match lhs_term {
Term::Idx(i) => Pred::Atom(Atom {
index: i,
op: CmpOp::Ne,
rhs: Rhs::Const(0),
}),
Term::Const(c) => {
if c != 0 {
Pred::Or(vec![])
} else {
Pred::And(vec![])
} // true/false constants if you add Const
}
}
}
}
// A parsed comparison operand: either a transcript byte reference `x[i]`
// or a byte-sized integer literal.
#[derive(Debug, Clone, Copy)]
enum Term {
    // `x[i]` — the byte at transcript index `i`.
    Idx(usize),
    // A literal byte value.
    Const(u8),
}
/// Converts a pest pair into a comparison operand.
///
/// Wrapper rules (`atom`, `paren`, `primary`) are unwrapped recursively;
/// `byte_idx` and `byte_const` terminate the recursion.
fn parse_term(pair: Pair<Rule>) -> Term {
    let rule = pair.as_rule();
    match rule {
        // Transparent wrappers: descend into the single inner pair.
        Rule::atom | Rule::paren | Rule::primary => {
            parse_term(pair.into_inner().next().unwrap())
        }
        // "x" "[" number "]" — locate the number token among the children.
        Rule::byte_idx => {
            let number = pair
                .into_inner()
                .find(|p| p.as_rule() == Rule::number)
                .unwrap();
            Term::Idx(number.as_str().parse::<usize>().unwrap())
        }
        // A literal byte: the single inner pair is the number.
        Rule::byte_const => {
            let digits = pair.into_inner().next().unwrap();
            Term::Const(digits.as_str().parse::<u8>().unwrap())
        }
        other => unreachable!("term {:?}", other),
    }
}
#[cfg(test)]
mod test {
    use super::*;
    // Smoke test: the expression parses without error. The expected AST
    // shape is described below but not asserted — NOTE(review): consider
    // adding structural assertions once Pred implements PartialEq.
    #[test]
    fn test_and() {
        let pred = parse_expr("x[100] < x[300] && x[200] == 2 || ! (x[5] >= 57)").unwrap();
        // `pred` is a Pred::Or with an And on the left and a Not on the right,
        // with Atoms inside.
    }
}

View File

@@ -296,14 +296,14 @@ mod sha2 {
fn hash(&self, data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(data);
super::Hash::new(hasher.finalize().as_ref())
super::Hash::new(hasher.finalize().as_slice())
}
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
let mut hasher = ::sha2::Sha256::default();
hasher.update(prefix);
hasher.update(data);
super::Hash::new(hasher.finalize().as_ref())
super::Hash::new(hasher.finalize().as_slice())
}
}
}

41
crates/core/src/json.pest Normal file
View File

@@ -0,0 +1,41 @@
// pest. The Elegant Parser
// Copyright (c) 2018 Dragoș Tiselice
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
//! A parser for JSON files.
//!
//! This is an example JSON parser.
json = _{ SOI ~ value ~ eoi }
eoi = _{ !ANY }
/// Matches object, e.g.: `{ "foo": "bar" }`
/// Foobar
object = { "{" ~ pair ~ (pair)* ~ "}" | "{" ~ "}" }
pair = { quoted_string ~ ":" ~ value ~ (",")? }
array = { "[" ~ value ~ ("," ~ value)* ~ "]" | "[" ~ "]" }
//////////////////////
/// Matches value, e.g.: `"foo"`, `42`, `true`, `null`, `[]`, `{}`.
//////////////////////
value = _{ quoted_string | number | object | array | bool | null }
quoted_string = _{ "\"" ~ string ~ "\"" }
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ string)? }
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }
number = @{ "-"? ~ int ~ ("." ~ ASCII_DIGIT+ ~ exp? | exp)? }
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }
exp = @{ ("E" | "e") ~ ("+" | "-")? ~ ASCII_DIGIT+ }
bool = { "true" | "false" }
null = { "null" }
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }

760
crates/core/src/json.rs Normal file
View File

@@ -0,0 +1,760 @@
//!
use crate::predicates::Pred;
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pest_derive::Parser;
use pest_meta::{ast, parser as meta, parser::consume_rules, validator};
#[cfg(test)]
mod test {
use core::panic;
use std::cmp::{max, min};
use crate::{
config::prove::ProveConfig,
predicates::{eval_pred, is_unicode, Atom, CmpOp, Compiler, Rhs},
};
use super::*;
use mpz_circuits::{
evaluate,
ops::{all, any},
};
use pest_meta::ast::Expr;
use rangeset::RangeSet;
const MAX_LEN: usize = 999_999;
    // A simplified, owned mirror of the pest grammar AST, used to lower
    // grammar expressions into predicates over transcript bytes.
    #[derive(Debug, Clone)]
    enum Ex {
        // Zero-or-more repetition (`e*`).
        RepEx(Rep),
        // Exact repetition (`e{n}`).
        RepExactEx(RepExact),
        // Sequence (`a ~ b`).
        SeqEx(Seq),
        // Literal string match.
        StrEx(Str),
        // Ordered choice (`a | b`).
        ChoiceEx(Choice),
        // Negative lookahead (`!e`), consumes no input.
        NegPredEx(NegPred),
        // Optional match (`e?`).
        OptEx(Opt),
        // An expression which must be replaced with a copy of the rule.
        NestedEx,
        #[allow(non_camel_case_types)]
        ASCII_NONZERO_DIGIT,
        #[allow(non_camel_case_types)]
        ASCII_DIGIT,
        #[allow(non_camel_case_types)]
        // A single Unicode character
        ANY,
        #[allow(non_camel_case_types)]
        ASCII_HEX_DIGIT,
    }
impl Ex {
fn min_len(&self) -> usize {
match self {
Ex::RepEx(e) => 0,
Ex::RepExactEx(e) => e.0 .0.min_len() * e.0 .1 as usize,
Ex::StrEx(e) => e.0.len(),
Ex::SeqEx(e) => e.0.min_len() + e.1.min_len(),
Ex::ChoiceEx(e) => min(e.0.min_len(), e.1.min_len()),
Ex::NegPredEx(e) => 0,
Ex::ASCII_NONZERO_DIGIT => 1,
Ex::ASCII_DIGIT => 1,
Ex::ANY => 1,
Ex::ASCII_HEX_DIGIT => 1,
Ex::OptEx(e) => 0,
Ex::NestedEx => 0,
_ => unimplemented!(),
}
}
fn max_len(&self) -> usize {
match self {
Ex::RepEx(e) => MAX_LEN,
Ex::RepExactEx(e) => e.0 .0.max_len() * e.0 .1 as usize,
Ex::StrEx(e) => e.0.len(),
Ex::SeqEx(e) => e.0.max_len() + e.1.max_len(),
Ex::ChoiceEx(e) => max(e.0.max_len(), e.1.max_len()),
Ex::NegPredEx(e) => 0,
Ex::ASCII_NONZERO_DIGIT => 1,
Ex::ASCII_DIGIT => 1,
Ex::ANY => 4,
Ex::ASCII_HEX_DIGIT => 1,
Ex::OptEx(e) => e.0.max_len(),
Ex::NestedEx => 0,
_ => unimplemented!(),
}
}
}
    // Zero-or-more repetition: `e*`.
    #[derive(Debug, Clone)]
    struct Rep(Box<Ex>);
    // Exact repetition: `e{n}`.
    #[derive(Debug, Clone)]
    struct RepExact((Box<Ex>, u32));
    // A literal string match.
    #[derive(Debug, Clone)]
    struct Str(String);
    // Two expressions matched in sequence: `a ~ b`.
    #[derive(Debug, Clone)]
    struct Seq(Box<Ex>, Box<Ex>);
    // Ordered choice: `a | b`.
    #[derive(Debug, Clone)]
    struct Choice(Box<Ex>, Box<Ex>);
    // Negative lookahead: `!e` (consumes no input).
    #[derive(Debug, Clone)]
    struct NegPred(Box<Ex>);
    // Optional match: `e?`.
    #[derive(Debug, Clone)]
    struct Opt(Box<Ex>);
    // A named grammar rule together with its built expression.
    struct Rule {
        name: String,
        pub ex: Ex,
    }
    /// Builds the rules, returning the final expression.
    fn build_rules(ast_rules: &[ast::Rule]) -> Ex {
        let mut rules = Vec::new();
        // build from the bottom up: rules declared later are built first so
        // that a rule's references can be resolved against `rules`.
        let iter = ast_rules.iter().rev();
        for r in iter {
            println!("building rule with name {:?}", r.name);
            let ex = build_expr(&r.expr, &rules, &r.name, false);
            // TODO deal with recursive rules
            rules.push(Rule {
                name: r.name.clone(),
                ex,
            });
        }
        // The rule built last corresponds to the first declared rule, i.e.
        // the grammar's entry point.
        let ex = rules.last().unwrap().ex.clone();
        ex
    }
/// Builds expression from pest expression.
/// passes in current rule's name to deal with recursion.
/// depth is used to prevent infinite recursion.
fn build_expr(exp: &Expr, rules: &[Rule], this_name: &String, is_nested: bool) -> Ex {
match exp {
Expr::Rep(exp) => {
Ex::RepEx(Rep(Box::new(build_expr(exp, rules, this_name, is_nested))))
}
Expr::RepExact(exp, count) => Ex::RepExactEx(RepExact((
Box::new(build_expr(exp, rules, this_name, is_nested)),
*count,
))),
Expr::Str(str) => Ex::StrEx(Str(str.clone())),
Expr::NegPred(exp) => Ex::NegPredEx(NegPred(Box::new(build_expr(
exp, rules, this_name, is_nested,
)))),
Expr::Seq(a, b) => {
//
let a = build_expr(a, rules, this_name, is_nested);
Ex::SeqEx(Seq(
Box::new(a),
Box::new(build_expr(b, rules, this_name, is_nested)),
))
}
Expr::Choice(a, b) => Ex::ChoiceEx(Choice(
Box::new(build_expr(a, rules, this_name, is_nested)),
Box::new(build_expr(b, rules, this_name, is_nested)),
)),
Expr::Opt(exp) => {
Ex::OptEx(Opt(Box::new(build_expr(exp, rules, this_name, is_nested))))
}
Expr::Ident(ident) => {
let ex = match ident.as_str() {
"ASCII_NONZERO_DIGIT" => Ex::ASCII_NONZERO_DIGIT,
"ASCII_DIGIT" => Ex::ASCII_DIGIT,
"ANY" => Ex::ANY,
"ASCII_HEX_DIGIT" => Ex::ASCII_HEX_DIGIT,
_ => {
if *ident == *this_name {
return Ex::NestedEx;
}
for rule in rules {
return rule.ex.clone();
}
panic!("couldnt find rule {:?}", ident);
}
};
ex
}
_ => unimplemented!(),
}
}
    // This function must be called only when enough data remains, starting
    // from the offset, to match the expression at least once.
    //
    // Returns the predicate and the offset from which the next expression
    // should be matched.
    // Returns multiple predicates if the expression produced multiple branches.
    // A top-level expression always returns a single predicate, in which all
    // branches are coalesced.
fn expr_to_pred(
exp: &Ex,
offset: usize,
data_len: usize,
is_top_level: bool,
) -> Vec<(Pred, usize)> {
// if is_top_level {
// println!("top level exps {:?}", exp);
// } else {
// println!("Non-top level exps {:?}", exp);
// }
match exp {
Ex::SeqEx(s) => {
let a = &s.0;
let b = &s.1;
if is_top_level && (offset + a.max_len() + b.max_len() < data_len) {
panic!();
}
if offset + a.min_len() + b.min_len() > data_len {
panic!();
}
// The first expression must not try to match in the
// data of the next expression
let pred1 = expr_to_pred(a, offset, data_len - b.min_len(), false);
// interlace all branches
let mut interlaced = Vec::new();
for (p1, offset) in pred1.iter() {
// if the seq expr was top-level, the 2nd expr becomes top-level
let mut pred2 = expr_to_pred(b, *offset, data_len, is_top_level);
for (p2, offser_inner) in pred2.iter() {
let pred = Pred::And(vec![p1.clone(), p2.clone()]);
interlaced.push((pred, *offser_inner));
}
}
if is_top_level {
// coalesce all branches
let preds: Vec<Pred> = interlaced.into_iter().map(|(a, _b)| a).collect();
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
vec![(Pred::Or(preds), 0)]
}
} else {
interlaced
}
}
Ex::ChoiceEx(s) => {
let a = &s.0;
let b = &s.1;
let mut skip_a = false;
let mut skip_b = false;
if is_top_level {
if offset + a.max_len() != data_len {
skip_a = true
}
if offset + b.max_len() != data_len {
skip_b = true;
}
} else {
// if not top level, we may skip an expression when it will
// overflow the data len
if offset + a.min_len() > data_len {
skip_a = true
}
if offset + b.min_len() > data_len {
skip_b = true
}
}
if skip_a && skip_b {
panic!();
}
let mut preds_a = Vec::new();
let mut preds_b = Vec::new();
if !skip_a {
preds_a = expr_to_pred(a, offset, data_len, is_top_level);
}
if !skip_b {
preds_b = expr_to_pred(b, offset, data_len, is_top_level);
}
// combine all branches
let mut combined = Vec::new();
if preds_a.is_empty() {
combined = preds_b.clone();
} else if preds_b.is_empty() {
combined = preds_a.clone();
} else {
assert!(!(preds_a.is_empty() && preds_b.is_empty()));
combined.append(&mut preds_a);
combined.append(&mut preds_b);
}
if is_top_level {
// coalesce all branches
let preds: Vec<Pred> = combined.into_iter().map(|(a, _b)| a).collect();
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
vec![(Pred::Or(preds), 0)]
}
} else {
combined
}
}
Ex::RepEx(r) => {
let e = &r.0;
if offset + e.min_len() > data_len {
if is_top_level {
panic!();
}
// zero matches
return vec![];
}
let mut interlaced = Vec::new();
let mut preds = expr_to_pred(&e, offset, data_len, false);
// for (i, (pred, depth)) in preds.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
// Append single matches.
interlaced.append(&mut preds.clone());
let mut was_found = true;
while was_found {
was_found = false;
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
if offset_outer + e.min_len() > data_len {
// cannot match any more
continue;
}
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
// for (i, (pred, depth)) in preds_inner.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
for (pred_inner, offset_inner) in preds_inner {
let pred = (
Pred::And(vec![pred_outer.clone(), pred_inner]),
offset_inner,
);
preds.push(pred);
was_found = true;
}
}
interlaced.append(&mut preds.clone());
}
// for (i, (pred, depth)) in interlaced.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
if is_top_level {
// drop all branches which do not match exactly at the data length
// border and coalesce the rest
let preds: Vec<Pred> = interlaced
.into_iter()
.filter(|(_a, b)| *b == data_len)
.map(|(a, _b)| a)
.collect();
if preds.is_empty() {
panic!()
}
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
// coalesce all branches
vec![(Pred::Or(preds), 0)]
}
} else {
interlaced
}
}
Ex::RepExactEx(r) => {
let e = &r.0 .0;
let count = r.0 .1;
assert!(count > 0);
if is_top_level && (offset + e.max_len() * count as usize <= data_len) {
panic!();
}
let mut preds = expr_to_pred(&e, offset, data_len, false);
for i in 1..count {
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
if offset_outer + e.min_len() > data_len {
// cannot match any more
continue;
}
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
for (pred_inner, offset_inner) in preds_inner {
let pred = (
Pred::And(vec![pred_outer.clone(), pred_inner]),
offset_inner,
);
preds.push(pred);
}
}
}
if is_top_level {
// drop all branches which do not match exactly at the data length
// border and coalesce the rest
let preds: Vec<Pred> = preds
.into_iter()
.filter(|(_a, b)| *b != data_len)
.map(|(a, _b)| a)
.collect();
if preds.is_empty() {
panic!()
}
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
// coalesce all branches
vec![(Pred::Or(preds), 0)]
}
} else {
preds
}
}
Ex::NegPredEx(e) => {
assert!(offset <= data_len);
if offset == data_len {
// the internal expression cannot be match since there is no data left,
// this means that the negative expression matched
if is_top_level {
panic!("always true predicate doesnt make sense")
}
// TODO this is hacky.
return vec![(Pred::True, offset)];
}
let e = &e.0;
let preds = expr_to_pred(&e, offset, data_len, is_top_level);
let preds: Vec<Pred> = preds.into_iter().map(|(a, _b)| a).collect();
let len = preds.len();
// coalesce all branches, offset doesnt matter since those
// offset will never be used anymore.
let pred = if preds.len() == 1 {
Pred::Not(Box::new(preds[0].clone()))
} else {
Pred::Not(Box::new(Pred::Or(preds)))
};
if is_top_level && len == 0 {
panic!()
}
// all offset if negative predicate are ignored since no matching
// will be done from those offsets.
vec![(pred, offset)]
}
Ex::OptEx(e) => {
let e = &e.0;
if is_top_level {
return vec![(Pred::True, 0)];
}
// add an always-matching branch
let mut preds = vec![(Pred::True, offset)];
if e.min_len() + offset <= data_len {
// try to match only if there is enough data
let mut p = expr_to_pred(&e, offset, data_len, is_top_level);
preds.append(&mut p);
}
preds
}
Ex::StrEx(s) => {
if is_top_level && offset + s.0.len() != data_len {
panic!();
}
let mut preds = Vec::new();
for (idx, byte) in s.0.clone().into_bytes().iter().enumerate() {
let a = Atom {
index: offset + idx,
op: CmpOp::Eq,
rhs: Rhs::Const(*byte),
};
preds.push(Pred::Atom(a));
}
if preds.len() == 1 {
vec![(preds[0].clone(), offset + s.0.len())]
} else {
vec![(Pred::And(preds), offset + s.0.len())]
}
}
Ex::ASCII_NONZERO_DIGIT => {
if is_top_level && (offset + 1 != data_len) {
panic!();
}
let gte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Gte,
rhs: Rhs::Const(49u8),
});
let lte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
vec![(Pred::And(vec![gte, lte]), offset + 1)]
}
Ex::ASCII_DIGIT => {
if is_top_level && (offset + 1 != data_len) {
panic!();
}
let gte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Gte,
rhs: Rhs::Const(48u8),
});
let lte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
vec![(Pred::And(vec![gte, lte]), offset + 1)]
}
Ex::ANY => {
if is_top_level && (offset + 1 > data_len) {
panic!();
}
let start = offset;
let end = min(offset + 4, data_len);
let mut branches = Vec::new();
for branch_end in start + 1..end {
branches.push((is_unicode(RangeSet::from(start..branch_end)), branch_end))
}
if is_top_level {
assert!(branches.len() == 1);
}
branches
}
_ => unimplemented!(),
}
}
#[test]
fn test_json_int() {
use rand::{distr::Alphanumeric, rng, Rng};
let grammar = include_str!("json_int.pest");
// Parse the grammar file into Pairs (the grammars own parse tree)
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
// 4) Convert the parsed pairs into the stable AST representation
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
let exp = build_rules(&rules_ast);
// 5) Inspect the AST however you like For a quick look, the Debug print is the
// safest (works across versions)
for rule in &rules_ast {
println!("{:#?}", rule);
}
const LENGTH: usize = 7; // Adjustable constant
let pred = expr_to_pred(&exp, 0, LENGTH, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
let circ = Compiler::new().compile(&pred);
println!("{:?} and gates", circ.and_count());
for i in 0..1000000 {
let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(LENGTH)
.map(char::from)
.collect();
let out = eval_pred(pred, s.as_bytes());
let is_int = s.chars().all(|c| c.is_ascii_digit()) && !s.starts_with('0');
if out != is_int {
println!("failed at index {:?} with {:?}", i, s);
}
assert_eq!(out, is_int)
}
}
#[test]
fn test_json_str() {
use rand::{distr::Alphanumeric, rng, Rng};
const LENGTH: usize = 10; // Adjustable constant
let grammar = include_str!("json_str.pest");
// Parse the grammar file into Pairs (the grammars own parse tree)
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
// 4) Convert the parsed pairs into the stable AST representation
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
for rule in &rules_ast {
println!("{:#?}", rule);
}
let exp = build_rules(&rules_ast);
for len in LENGTH..LENGTH + 7 {
let pred = expr_to_pred(&exp, 0, len, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
let circ = Compiler::new().compile(pred);
println!(
"JSON string length: {:?}; circuit AND gate count {:?}",
len,
circ.and_count()
);
}
}
#[test]
fn test_choice() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Choice(Box::new(a), Box::new(b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 1, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_seq() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Seq(Box::new(a), Box::new(b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 2, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_rep() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Rep(Box::new(a)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 3, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_rep_choice() {
const LENGTH: usize = 5; // Adjustable constant
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
// Number of predicates needed to represent the expressions.
let a_weight = 2usize;
let b_weight = 2usize;
let rep_a = Expr::Rep(Box::new(a));
let rep_b = Expr::Rep(Box::new(b));
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Choice(Box::new(rep_a), Box::new(rep_b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, LENGTH, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {}", pred);
// This is for sanity that no extra predicates are being added.
assert_eq!(pred.leaves(), a_weight * LENGTH + b_weight * LENGTH);
}
#[test]
fn test_neg_choice() {
let a = Expr::Str("4".to_string());
let b = Expr::Str("5".to_string());
let choice = Expr::Choice(Box::new(a), Box::new(b));
let neg_choice = Expr::NegPred(Box::new(choice));
let c = Expr::Str("a".to_string());
let d = Expr::Str("BC".to_string());
let choice2 = Expr::Choice(Box::new(c), Box::new(d));
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Seq(Box::new(neg_choice), Box::new(choice2)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 2, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
assert_eq!(pred.leaves(), 4);
}
}

View File

@@ -0,0 +1,3 @@
// Copied from pest.json
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }

View File

@@ -0,0 +1,6 @@
// Copied from pest.json
// Replaced "string" with "X" to avoid recursion.
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ "X")? }
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }

View File

@@ -14,6 +14,9 @@ pub mod webpki;
pub use rangeset;
pub mod config;
pub(crate) mod display;
//pub mod grammar;
pub mod json;
pub mod predicates;
use serde::{Deserialize, Serialize};

View File

@@ -0,0 +1,790 @@
//! Predicate and compiler.
use std::{collections::HashMap, fmt};
use mpz_circuits::{itybity::ToBits, ops, Circuit, CircuitBuilder, Feed, Node};
use rangeset::RangeSet;
/// A boolean predicate over individual bytes of an input.
///
/// A predicate is a tree of `And`/`Or`/`Not` combinators whose leaves are
/// [`Atom`]s comparing the byte at a fixed index against a constant or
/// against another byte.
#[derive(Debug, Clone)]
pub(crate) enum Pred {
    /// Conjunction: true iff all child predicates are true.
    And(Vec<Pred>),
    /// Disjunction: true iff at least one child predicate is true.
    Or(Vec<Pred>),
    /// Logical negation of the inner predicate.
    Not(Box<Pred>),
    /// Leaf comparison of a single byte.
    Atom(Atom),
    // An always-true predicate.
    True,
    // An always-false predicate.
    False,
}
impl Pred {
    /// Returns sorted unique byte indices referenced by this predicate.
    pub(crate) fn indices(&self) -> Vec<usize> {
        let mut indices = self.indices_internal(self);
        indices.sort_unstable();
        indices.dedup();
        indices
    }

    // Returns the number of leaves (i.e. atoms) the AST of this predicate has.
    pub(crate) fn leaves(&self) -> usize {
        match self {
            Pred::And(vec) => vec.iter().map(|p| p.leaves()).sum(),
            Pred::Or(vec) => vec.iter().map(|p| p.leaves()).sum(),
            Pred::Not(p) => p.leaves(),
            // An atom is a leaf by definition; its contents are irrelevant
            // here (was `Pred::Atom(atom)`, an unused binding).
            Pred::Atom(_) => 1,
            Pred::True => 0,
            Pred::False => 0,
        }
    }

    /// Returns all byte indices of the given `pred`icate, unsorted and
    /// possibly containing duplicates.
    fn indices_internal(&self, pred: &Pred) -> Vec<usize> {
        match pred {
            // And/Or collect indices from all children identically.
            Pred::And(vec) | Pred::Or(vec) => vec
                .iter()
                .flat_map(|p| self.indices_internal(p))
                .collect::<Vec<_>>(),
            Pred::Not(p) => self.indices_internal(p),
            Pred::Atom(atom) => {
                let mut indices = vec![atom.index];
                // An RHS referencing another byte contributes its index too.
                if let Rhs::Idx(idx) = atom.rhs {
                    indices.push(idx);
                }
                indices
            }
            Pred::True | Pred::False => vec![],
        }
    }
}
impl fmt::Display for Pred {
    // Pretty-prints the predicate as an indented tree, delegating to
    // `fmt_with_indent` starting at nesting level 0.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_with_indent(f, 0)
    }
}
impl Pred {
    /// Writes the predicate as an indented tree: one node per line,
    /// two spaces of indentation per nesting level.
    fn fmt_with_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
        // Leading whitespace for the current nesting level.
        let pad = "  ".repeat(indent);
        match self {
            Pred::And(children) => {
                writeln!(f, "{pad}And(")?;
                for child in children {
                    child.fmt_with_indent(f, indent + 1)?;
                }
                writeln!(f, "{pad})")
            }
            Pred::Or(children) => {
                writeln!(f, "{pad}Or(")?;
                for child in children {
                    child.fmt_with_indent(f, indent + 1)?;
                }
                writeln!(f, "{pad})")
            }
            Pred::Not(inner) => {
                writeln!(f, "{pad}Not(")?;
                inner.fmt_with_indent(f, indent + 1)?;
                writeln!(f, "{pad})")
            }
            Pred::Atom(atom) => writeln!(f, "{pad}Atom({:?})", atom),
            Pred::True => writeln!(f, "{pad}True"),
            Pred::False => writeln!(f, "{pad}False"),
        }
    }
}
/// Atomic predicate of the form:
/// x[index] (op) rhs
///
/// All comparisons are over unsigned byte (`u8`) values.
#[derive(Debug, Clone)]
pub struct Atom {
    /// Left-hand side byte index `i` (x_i).
    pub index: usize,
    /// Comparison operator.
    pub op: CmpOp,
    /// Right-hand side operand (constant or x_j).
    pub rhs: Rhs,
}
/// Comparison operator of an atomic predicate.
#[derive(Debug, Clone)]
pub(crate) enum CmpOp {
    Eq,  // ==
    Ne,  // !=
    Gt,  // >
    Gte, // >=
    Lt,  // <
    Lte, // <=
}
/// RHS of a comparison.
#[derive(Debug, Clone)]
pub enum Rhs {
    /// Byte at the given index of the input.
    Idx(usize),
    /// Literal constant.
    Const(u8),
}
/// Compiles a predicate into a circuit.
pub struct Compiler {
    /// A <byte index, circuit feeds> map.
    ///
    /// Each referenced byte is represented by 8 input feeds.
    // NOTE(review): the bit order of these input feeds must match what the
    // `ops::*` comparison gadgets expect — presumably LSB0, matching
    // `const_feeds`; confirm against mpz-circuits.
    map: HashMap<usize, [Node<Feed>; 8]>,
}
impl Compiler {
    /// Creates a new compiler with an empty index-to-feeds map.
    pub(crate) fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }
    /// Compiles the given predicate into a circuit.
    ///
    /// Allocates 8 input feeds for every distinct byte index referenced by
    /// the predicate (in ascending index order, per `Pred::indices`) and
    /// exposes a single output feed carrying the truth value.
    pub(crate) fn compile(&mut self, pred: &Pred) -> Circuit {
        let mut builder = CircuitBuilder::new();
        for idx in pred.indices() {
            let feeds = (0..8).map(|_| builder.add_input()).collect::<Vec<_>>();
            self.map.insert(idx, feeds.try_into().unwrap());
        }
        let out = self.process(&mut builder, pred);
        builder.add_output(out);
        builder.build().unwrap()
    }
    // Recursively lowers a single predicate node into circuit gates,
    // returning the feed that carries the node's truth value.
    fn process(&mut self, builder: &mut CircuitBuilder, pred: &Pred) -> Node<Feed> {
        match pred {
            Pred::And(vec) => {
                let out = vec
                    .iter()
                    .map(|p| self.process(builder, p))
                    .collect::<Vec<_>>();
                ops::all(builder, &out)
            }
            Pred::Or(vec) => {
                let out = vec
                    .iter()
                    .map(|p| self.process(builder, p))
                    .collect::<Vec<_>>();
                ops::any(builder, &out)
            }
            Pred::Not(p) => {
                let pred_out = self.process(builder, p);
                let inv = ops::inv(builder, [pred_out]);
                inv[0]
            }
            Pred::Atom(atom) => {
                // Feeds were allocated in `compile` for every index returned
                // by `Pred::indices`, so these lookups cannot fail.
                let lhs = self.map.get(&atom.index).unwrap().clone();
                let rhs = match atom.rhs {
                    Rhs::Const(c) => const_feeds(builder, c),
                    Rhs::Idx(s) => self.map.get(&s).unwrap().clone(),
                };
                match atom.op {
                    CmpOp::Eq => ops::eq(builder, lhs, rhs),
                    CmpOp::Ne => ops::neq(builder, lhs, rhs),
                    CmpOp::Lt => ops::lt(builder, lhs, rhs),
                    CmpOp::Lte => ops::lte(builder, lhs, rhs),
                    CmpOp::Gt => ops::gt(builder, lhs, rhs),
                    CmpOp::Gte => ops::gte(builder, lhs, rhs),
                }
            }
            Pred::True => builder.get_const_one(),
            Pred::False => builder.get_const_zero(),
        }
    }
}
// Returns circuit feeds for the given constant u8 value, one constant
// feed per bit in LSB0 order.
fn const_feeds(builder: &CircuitBuilder, cnst: u8) -> [Node<Feed>; 8] {
    let mut feeds = Vec::with_capacity(8);
    for bit in cnst.iter_lsb0() {
        let node = match bit {
            true => builder.get_const_one(),
            false => builder.get_const_zero(),
        };
        feeds.push(node);
    }
    feeds.try_into().expect("u8 has 8 feeds")
}
// Evaluates the predicate on the input `data`.
//
// Panics if an atom references an index outside of `data`.
pub(crate) fn eval_pred(pred: &Pred, data: &[u8]) -> bool {
    match pred {
        // `all`/`any` short-circuit; an empty And is true, an empty Or false.
        Pred::And(vec) => vec.iter().all(|p| eval_pred(p, data)),
        Pred::Or(vec) => vec.iter().any(|p| eval_pred(p, data)),
        Pred::Not(p) => !eval_pred(p, data),
        Pred::Atom(atom) => {
            let lhs = data[atom.index];
            let rhs = match atom.rhs {
                Rhs::Const(c) => c,
                Rhs::Idx(s) => data[s],
            };
            match atom.op {
                CmpOp::Eq => lhs == rhs,
                CmpOp::Ne => lhs != rhs,
                CmpOp::Lt => lhs < rhs,
                CmpOp::Lte => lhs <= rhs,
                CmpOp::Gt => lhs > rhs,
                CmpOp::Gte => lhs >= rhs,
            }
        }
        Pred::True => true,
        // Bug fix: `False` previously evaluated to `true`, contradicting
        // both its name and the circuit compiler (`get_const_zero`).
        Pred::False => false,
    }
}
/// Builds a predicate that an ascii integer is contained in the ranges.
fn is_ascii_integer(range: RangeSet<usize>) -> Pred {
    // Every byte must be an ASCII digit: '0' (48) ..= '9' (57).
    let digit_preds = range
        .iter()
        .map(|idx| {
            let at_most_nine = Pred::Atom(Atom {
                index: idx,
                op: CmpOp::Lte,
                rhs: Rhs::Const(57u8),
            });
            let at_least_zero = Pred::Atom(Atom {
                index: idx,
                op: CmpOp::Gte,
                rhs: Rhs::Const(48u8),
            });
            Pred::And(vec![at_most_nine, at_least_zero])
        })
        .collect::<Vec<_>>();
    Pred::And(digit_preds)
}
/// Builds a predicate that a valid HTTP header value is contained in the
/// ranges.
fn is_valid_http_header_value(range: RangeSet<usize>) -> Pred {
    // A byte is acceptable iff it is not a carriage return.
    let not_cr = |idx: usize| {
        Pred::Atom(Atom {
            index: idx,
            op: CmpOp::Ne,
            // ascii code for carriage return \r
            rhs: Rhs::Const(13u8),
        })
    };
    Pred::And(range.iter().map(not_cr).collect::<Vec<_>>())
}
/// Builds a predicate that a valid JSON string is contained in the
/// ranges.
///
/// NOTE(review): this function looks incomplete — `preds` is never pushed
/// to, so the returned predicate is an empty `And`, i.e. vacuously true,
/// and `is_backslash` is computed but dropped. Presumably the intent is to
/// validate escape sequences (cf. the `escape`/`unicode` grammar rules);
/// confirm and finish before use.
fn is_valid_json_string(range: RangeSet<usize>) -> Pred {
    assert!(
        range.len_ranges() == 1,
        "only a contiguous range is allowed"
    );
    const BACKSLASH: u8 = 92;
    // check if all unicode chars are allowed
    let mut preds = Vec::new();
    // Find all /u sequences
    for (i, idx) in range.iter().enumerate() {
        if i == range.len() - 1 {
            // if this is a last char, skip it
            continue;
        }
        // TODO(review): `is_backslash` should presumably be combined with a
        // check on the following byte and pushed into `preds`.
        let is_backslash = Pred::Atom(Atom {
            index: idx,
            op: CmpOp::Eq,
            rhs: Rhs::Const(BACKSLASH),
        });
    }
    Pred::And(preds)
}
// Returns a predicate that a unicode char is contained in the range.
//
// The range length determines the expected UTF-8 encoded length of the
// char (1 to 4 bytes). Panics via `unimplemented!` for an empty range.
pub(crate) fn is_unicode(range: RangeSet<usize>) -> Pred {
    assert!(range.len() <= 4);
    match range.len() {
        1 => is_1_byte_unicode(range.max().unwrap()),
        2 => is_2_byte_unicode(range),
        3 => is_3_byte_unicode(range),
        4 => is_4_byte_unicode(range),
        _ => unimplemented!(),
    }
}
// A single-byte UTF-8 char is any byte in the ASCII range 0x00..=0x7F.
fn is_1_byte_unicode(pos: usize) -> Pred {
    let ascii = Atom {
        index: pos,
        op: CmpOp::Lte,
        rhs: Rhs::Const(127u8),
    };
    Pred::Atom(ascii)
}
// Returns a predicate that the range holds a 2-byte UTF-8 sequence:
// a lead byte 110xxxxx followed by one continuation byte.
//
// NOTE(review): accepts lead bytes 0xC0 and 0xC1, which only occur in
// overlong (ill-formed) encodings; strict UTF-8 requires 0xC2..=0xDF.
// Confirm whether lax matching is intended here.
fn is_2_byte_unicode(range: RangeSet<usize>) -> Pred {
    assert!(range.len() == 2);
    let mut iter = range.iter();
    // should be 110xxxxx
    let first = iter.next().unwrap();
    let gte = Pred::Atom(Atom {
        index: first,
        op: CmpOp::Gte,
        rhs: Rhs::Const(0xC0),
    });
    let lte = Pred::Atom(Atom {
        index: first,
        op: CmpOp::Lte,
        rhs: Rhs::Const(0xDF),
    });
    let second = iter.next().unwrap();
    Pred::And(vec![lte, gte, is_unicode_continuation(second)])
}
// Returns a predicate that the range holds a 3-byte UTF-8 sequence:
// a lead byte 1110xxxx followed by two continuation bytes.
fn is_3_byte_unicode(range: RangeSet<usize>) -> Pred {
    assert!(range.len() == 3);
    let mut positions = range.iter();
    let lead = positions.next().unwrap();
    // Lead byte must lie in 0xE0..=0xEF (1110xxxx).
    let lower = Pred::Atom(Atom {
        index: lead,
        op: CmpOp::Gte,
        rhs: Rhs::Const(0xE0),
    });
    let upper = Pred::Atom(Atom {
        index: lead,
        op: CmpOp::Lte,
        rhs: Rhs::Const(0xEF),
    });
    let mut parts = vec![upper, lower];
    // The remaining two bytes are continuation bytes (10xxxxxx).
    for _ in 0..2 {
        parts.push(is_unicode_continuation(positions.next().unwrap()));
    }
    Pred::And(parts)
}
// Returns a predicate that the range holds a 4-byte UTF-8 sequence:
// a lead byte 11110xxx followed by three continuation bytes.
fn is_4_byte_unicode(range: RangeSet<usize>) -> Pred {
    assert!(range.len() == 4);
    let mut positions = range.iter();
    let lead = positions.next().unwrap();
    // Lead byte must lie in 0xF0..=0xF7 (11110xxx).
    let lower = Pred::Atom(Atom {
        index: lead,
        op: CmpOp::Gte,
        rhs: Rhs::Const(0xF0),
    });
    let upper = Pred::Atom(Atom {
        index: lead,
        op: CmpOp::Lte,
        rhs: Rhs::Const(0xF7),
    });
    let mut parts = vec![upper, lower];
    // The remaining three bytes are continuation bytes (10xxxxxx).
    for _ in 0..3 {
        parts.push(is_unicode_continuation(positions.next().unwrap()));
    }
    Pred::And(parts)
}
// Returns a predicate that the byte at `pos` is a UTF-8 continuation
// byte, i.e. of the form 10xxxxxx (0x80..=0xBF).
fn is_unicode_continuation(pos: usize) -> Pred {
    let lower_bound = Pred::Atom(Atom {
        index: pos,
        op: CmpOp::Gte,
        rhs: Rhs::Const(0x80),
    });
    let upper_bound = Pred::Atom(Atom {
        index: pos,
        op: CmpOp::Lte,
        rhs: Rhs::Const(0xBF),
    });
    Pred::And(vec![upper_bound, lower_bound])
}
// Returns a predicate that the byte at `pos` is an ASCII hex digit:
// '0'..='9', 'a'..='f' or 'A'..='F'.
fn is_ascii_hex_digit(pos: usize) -> Pred {
    // Inclusive byte-range check for the position.
    let in_range = |lo: u8, hi: u8| {
        let gte = Pred::Atom(Atom {
            index: pos,
            op: CmpOp::Gte,
            rhs: Rhs::Const(lo),
        });
        let lte = Pred::Atom(Atom {
            index: pos,
            op: CmpOp::Lte,
            rhs: Rhs::Const(hi),
        });
        Pred::And(vec![lte, gte])
    };
    Pred::Or(vec![
        in_range(48u8, 57u8),  // '0'..='9'
        in_range(97u8, 102u8), // 'a'..='f'
        in_range(65u8, 70u8),  // 'A'..='F'
    ])
}
// Returns a predicate that the byte at `pos` is an ASCII lowercase
// letter: 'a' (97) ..= 'z' (122).
//
// Bug fix: previously checked the ASCII digit range 48..=57 ('0'..='9'),
// contradicting the function's name (apparent copy-paste from
// `is_ascii_integer`).
fn is_ascii_lowercase(pos: usize) -> Pred {
    let gte = Pred::Atom(Atom {
        index: pos,
        op: CmpOp::Gte,
        rhs: Rhs::Const(97u8),
    });
    let lte = Pred::Atom(Atom {
        index: pos,
        op: CmpOp::Lte,
        rhs: Rhs::Const(122u8),
    });
    Pred::And(vec![lte, gte])
}
#[cfg(test)]
mod test {
    use super::*;
    use mpz_circuits::evaluate;

    // Changes vs. the original module:
    // - removed the unused module-level `use rand::rng;` (shadowed by the
    //   local import in `test_is_unicode`) and the unused `Alphanumeric`;
    // - `test_is_unicode` runs 1000 iterations instead of 1,000,000 so the
    //   suite stays fast, and reuses the already-encoded `encoded` slice
    //   instead of allocating `c.to_string()`;
    // - `assert_eq!(x, true/false)` replaced with `assert!`/`assert!(!..)`.

    #[test]
    fn test_and() {
        let pred = Pred::And(vec![
            Pred::Atom(Atom {
                index: 100,
                op: CmpOp::Lt,
                rhs: Rhs::Idx(300),
            }),
            Pred::Atom(Atom {
                index: 200,
                op: CmpOp::Eq,
                rhs: Rhs::Const(2u8),
            }),
        ]);

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [1u8, 2, 3]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 3, 3]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_or() {
        let pred = Pred::Or(vec![
            Pred::Atom(Atom {
                index: 100,
                op: CmpOp::Lt,
                rhs: Rhs::Idx(300),
            }),
            Pred::Atom(Atom {
                index: 200,
                op: CmpOp::Eq,
                rhs: Rhs::Const(2u8),
            }),
        ]);

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [1u8, 0, 3]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 3, 0]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_not() {
        let pred = Pred::Not(Box::new(Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lt,
            rhs: Rhs::Idx(300),
        })));

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [5u8, 3]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
        assert!(!out);
    }

    // Tests when RHS is a const.
    #[test]
    fn test_rhs_const() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lt,
            rhs: Rhs::Const(22u8),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, 5u8).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, 23u8).unwrap();
        assert!(!out);
    }

    // Tests when RHS is an index.
    #[test]
    fn test_rhs_idx() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lt,
            rhs: Rhs::Idx(200),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, 5u8, 10u8).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, 23u8, 5u8).unwrap();
        assert!(!out);
    }

    // Tests when same index is used in the predicate.
    #[test]
    fn test_same_idx() {
        let pred1 = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Eq,
            rhs: Rhs::Idx(100),
        });
        let pred2 = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lt,
            rhs: Rhs::Idx(100),
        });

        let circ = Compiler::new().compile(&pred1);
        let out: bool = evaluate!(circ, 5u8).unwrap();
        assert!(out);

        let circ = Compiler::new().compile(&pred2);
        let out: bool = evaluate!(circ, 5u8).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_eq() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Eq,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [5u8, 5]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_neq() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Ne,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [5u8, 6]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_gt() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Gt,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [7u8, 6]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_gte() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Gte,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [7u8, 7]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [0u8, 1]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_lt() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lt,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [2u8, 7]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_atom_lte() {
        let pred = Pred::Atom(Atom {
            index: 100,
            op: CmpOp::Lte,
            rhs: Rhs::Idx(300),
        });

        let circ = Compiler::new().compile(&pred);

        let out: bool = evaluate!(circ, [2u8, 2]).unwrap();
        assert!(out);

        let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
        assert!(!out);
    }

    #[test]
    fn test_is_ascii_integer() {
        let text = "text with integers 123456 text";
        let pos = text.find("123456").unwrap();

        let pred = is_ascii_integer(RangeSet::from(pos..pos + 6));
        let bytes: &[u8] = text.as_bytes();

        let out = eval_pred(&pred, bytes);
        assert!(out);

        // Shifting the data by one byte misaligns the range with the digits.
        let out = eval_pred(&pred, &[&[0u8], bytes].concat());
        assert!(!out);
    }

    #[test]
    fn test_is_valid_http_header_value() {
        let valid = "valid header value";
        let invalid = "invalid header \r value";

        let pred = is_valid_http_header_value(RangeSet::from(0..valid.len()));
        let out: bool = eval_pred(&pred, valid.as_bytes());
        assert!(out);

        let pred = is_valid_http_header_value(RangeSet::from(0..invalid.len()));
        let out = eval_pred(&pred, invalid.as_bytes());
        assert!(!out);
    }

    #[test]
    fn test_is_unicode() {
        use rand::{rng, Rng};

        let mut rng = rng();

        for _ in 0..1000 {
            let mut s = String::from("HelloWorld");
            let insert_pos = 5; // logical character index (after "Hello")

            let byte_index = s
                .char_indices()
                .nth(insert_pos)
                .map(|(i, _)| i)
                .unwrap_or_else(|| s.len());

            // Pick a random Unicode scalar value (0x0000..=0x10FFFF),
            // retrying if it's in the surrogate range (U+D800..=U+DFFF).
            let c = loop {
                let code = rng.random_range(0x0000u32..=0x10FFFF);
                if !(0xD800..=0xDFFF).contains(&code) {
                    if let Some(ch) = char::from_u32(code) {
                        break ch;
                    }
                }
            };

            let mut buf = [0u8; 4]; // max UTF-8 length
            let encoded = c.encode_utf8(&mut buf); // returns &str
            let len = encoded.len();

            s.insert_str(byte_index, encoded);

            let pred = is_unicode(RangeSet::from(byte_index..byte_index + len));
            let out = eval_pred(&pred, s.as_bytes());
            assert!(out);
        }

        // 0xFF can never appear in well-formed UTF-8.
        let bad_unicode = 255u8;
        let pred = is_unicode(RangeSet::from(0..1));
        let out = eval_pred(&pred, &[bad_unicode]);
        assert!(!out);
    }
}

View File

@@ -26,11 +26,7 @@ mod tls;
use std::{fmt, ops::Range};
use rangeset::{
iter::RangeIterator,
ops::{Index, Set},
set::RangeSet,
};
use rangeset::{Difference, IndexRanges, RangeSet, Union};
use serde::{Deserialize, Serialize};
use crate::connection::TranscriptLength;
@@ -110,14 +106,8 @@ impl Transcript {
}
Some(
Subsequence::new(
idx.clone(),
data.index(idx).fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
)
.expect("data is same length as index"),
Subsequence::new(idx.clone(), data.index_ranges(idx))
.expect("data is same length as index"),
)
}
@@ -139,11 +129,11 @@ impl Transcript {
let mut sent = vec![0; self.sent.len()];
let mut received = vec![0; self.received.len()];
for range in sent_idx.iter() {
for range in sent_idx.iter_ranges() {
sent[range.clone()].copy_from_slice(&self.sent[range]);
}
for range in recv_idx.iter() {
for range in recv_idx.iter_ranges() {
received[range.clone()].copy_from_slice(&self.received[range]);
}
@@ -196,20 +186,12 @@ pub struct CompressedPartialTranscript {
impl From<PartialTranscript> for CompressedPartialTranscript {
fn from(uncompressed: PartialTranscript) -> Self {
Self {
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
Vec::new(),
|mut acc, s| {
acc.extend_from_slice(s);
acc
},
),
sent_authed: uncompressed
.sent
.index_ranges(&uncompressed.sent_authed_idx),
received_authed: uncompressed
.received
.index(&uncompressed.received_authed_idx)
.fold(Vec::new(), |mut acc, s| {
acc.extend_from_slice(s);
acc
}),
.index_ranges(&uncompressed.received_authed_idx),
sent_idx: uncompressed.sent_authed_idx,
recv_idx: uncompressed.received_authed_idx,
sent_total: uncompressed.sent.len(),
@@ -225,7 +207,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.sent_idx.iter() {
for range in compressed.sent_idx.iter_ranges() {
sent[range.clone()]
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
offset += range.len();
@@ -233,7 +215,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
let mut offset = 0;
for range in compressed.recv_idx.iter() {
for range in compressed.recv_idx.iter_ranges() {
received[range.clone()]
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
offset += range.len();
@@ -322,16 +304,12 @@ impl PartialTranscript {
/// Returns the index of sent data which haven't been authenticated.
pub fn sent_unauthed(&self) -> RangeSet<usize> {
(0..self.sent.len())
.difference(&self.sent_authed_idx)
.into_set()
(0..self.sent.len()).difference(&self.sent_authed_idx)
}
/// Returns the index of received data which haven't been authenticated.
pub fn received_unauthed(&self) -> RangeSet<usize> {
(0..self.received.len())
.difference(&self.received_authed_idx)
.into_set()
(0..self.received.len()).difference(&self.received_authed_idx)
}
/// Returns an iterator over the authenticated data in the transcript.
@@ -341,7 +319,7 @@ impl PartialTranscript {
Direction::Received => (&self.received, &self.received_authed_idx),
};
authed.iter_values().map(move |i| data[i])
authed.iter().map(|i| data[i])
}
/// Unions the authenticated data of this transcript with another.
@@ -361,20 +339,24 @@ impl PartialTranscript {
"received data are not the same length"
);
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
for range in other
.sent_authed_idx
.difference(&self.sent_authed_idx)
.iter_ranges()
{
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
}
for range in other
.received_authed_idx
.difference(&self.received_authed_idx)
.iter_ranges()
{
self.received[range.clone()].copy_from_slice(&other.received[range]);
}
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
self.received_authed_idx
.union_mut(&other.received_authed_idx);
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
}
/// Unions an authenticated subsequence into this transcript.
@@ -386,11 +368,11 @@ impl PartialTranscript {
match direction {
Direction::Sent => {
seq.copy_to(&mut self.sent);
self.sent_authed_idx.union_mut(&seq.idx);
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
}
Direction::Received => {
seq.copy_to(&mut self.received);
self.received_authed_idx.union_mut(&seq.idx);
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
}
}
}
@@ -401,10 +383,10 @@ impl PartialTranscript {
///
/// * `value` - The value to set the unauthenticated bytes to
pub fn set_unauthed(&mut self, value: u8) {
for range in self.sent_unauthed().iter() {
for range in self.sent_unauthed().iter_ranges() {
self.sent[range].fill(value);
}
for range in self.received_unauthed().iter() {
for range in self.received_unauthed().iter_ranges() {
self.received[range].fill(value);
}
}
@@ -419,13 +401,13 @@ impl PartialTranscript {
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
match direction {
Direction::Sent => {
for r in range.difference(&self.sent_authed_idx) {
self.sent[r].fill(value);
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
self.sent[range].fill(value);
}
}
Direction::Received => {
for r in range.difference(&self.received_authed_idx) {
self.received[r].fill(value);
for range in range.difference(&self.received_authed_idx).iter_ranges() {
self.received[range].fill(value);
}
}
}
@@ -503,7 +485,7 @@ impl Subsequence {
/// Panics if the subsequence ranges are out of bounds.
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
let mut offset = 0;
for range in self.idx.iter() {
for range in self.idx.iter_ranges() {
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
offset += range.len();
}
@@ -628,7 +610,12 @@ mod validation {
mut partial_transcript: CompressedPartialTranscriptUnchecked,
) {
// Change the total to be less than the last range's end bound.
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
let end = partial_transcript
.sent_idx
.iter_ranges()
.next_back()
.unwrap()
.end;
partial_transcript.sent_total = end - 1;

View File

@@ -2,7 +2,7 @@
use std::{collections::HashSet, fmt};
use rangeset::set::ToRangeSet;
use rangeset::{ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{

View File

@@ -1,6 +1,6 @@
use std::{collections::HashMap, fmt};
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
@@ -103,7 +103,7 @@ impl EncodingProof {
}
expected_leaf.clear();
for range in idx.iter() {
for range in idx.iter_ranges() {
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
}
expected_leaf.extend_from_slice(blinder.as_bytes());

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use bimap::BiMap;
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use crate::{
@@ -99,7 +99,7 @@ impl EncodingTree {
let blinder: Blinder = rand::random();
encoding.clear();
for range in idx.iter() {
for range in idx.iter_ranges() {
provider
.provide_encoding(direction, range, &mut encoding)
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;

View File

@@ -1,10 +1,6 @@
//! Transcript proofs.
use rangeset::{
iter::RangeIterator,
ops::{Cover, Set},
set::ToRangeSet,
};
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt};
@@ -148,7 +144,7 @@ impl TranscriptProof {
}
buffer.clear();
for range in idx.iter() {
for range in idx.iter_ranges() {
buffer.extend_from_slice(&plaintext[range]);
}
@@ -370,7 +366,7 @@ impl<'a> TranscriptProofBuilder<'a> {
if idx.is_subset(committed) {
self.query_idx.union(&direction, &idx);
} else {
let missing = idx.difference(committed).into_set();
let missing = idx.difference(committed);
return Err(TranscriptProofBuilderError::new(
BuilderErrorKind::MissingCommitment,
format!(
@@ -586,7 +582,7 @@ impl fmt::Display for TranscriptProofBuilderError {
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rangeset::prelude::*;
use rangeset::RangeSet;
use rstest::rstest;
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};

View File

@@ -324,7 +324,7 @@ fn prepare_zk_proof_input(
hasher.update(&blinder);
let computed_hash = hasher.finalize();
if committed_hash != computed_hash.as_ref() as &[u8] {
if committed_hash != computed_hash.as_slice() {
return Err(anyhow::anyhow!(
"Computed hash does not match committed hash"
));

View File

@@ -9,7 +9,6 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
pub const DEFAULT_MEMORY_PROFILE: bool = false;
pub const DEFAULT_REVEAL_ALL: bool = false;
pub const WARM_UP_BENCH: Bench = Bench {
group: None,
@@ -21,7 +20,6 @@ pub const WARM_UP_BENCH: Bench = Bench {
download_size: 4096,
defer_decryption: true,
memory_profile: false,
reveal_all: true,
};
#[derive(Deserialize)]
@@ -81,8 +79,6 @@ pub struct BenchGroupItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -101,8 +97,6 @@ pub struct BenchItem {
pub defer_decryption: Option<bool>,
#[serde(rename = "memory-profile")]
pub memory_profile: Option<bool>,
#[serde(rename = "reveal-all")]
pub reveal_all: Option<bool>,
}
impl BenchItem {
@@ -138,10 +132,6 @@ impl BenchItem {
if self.memory_profile.is_none() {
self.memory_profile = group.memory_profile;
}
if self.reveal_all.is_none() {
self.reveal_all = group.reveal_all;
}
}
pub fn into_bench(&self) -> Bench {
@@ -155,7 +145,6 @@ impl BenchItem {
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
}
}
}
@@ -175,8 +164,6 @@ pub struct Bench {
pub defer_decryption: bool,
#[serde(rename = "memory-profile")]
pub memory_profile: bool,
#[serde(rename = "reveal-all")]
pub reveal_all: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -22,10 +22,7 @@ pub enum CmdOutput {
GetTests(Vec<String>),
Test(TestOutput),
Bench(BenchOutput),
#[cfg(target_arch = "wasm32")]
Fail {
reason: Option<String>,
},
Fail { reason: Option<String> },
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -98,27 +98,14 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
let mut builder = ProveConfig::builder(prover.transcript());
// When reveal_all is false (the default), we exclude 1 byte to avoid the
// reveal-all optimization and benchmark the realistic ZK authentication path.
let reveal_sent_range = if config.reveal_all {
0..sent_len
} else {
0..sent_len.saturating_sub(1)
};
let reveal_recv_range = if config.reveal_all {
0..recv_len
} else {
0..recv_len.saturating_sub(1)
};
builder
.server_identity()
.reveal_sent(&reveal_sent_range)?
.reveal_recv(&reveal_recv_range)?;
.reveal_sent(&(0..sent_len))?
.reveal_recv(&(0..recv_len))?;
let prove_config = builder.build()?;
let config = builder.build()?;
prover.prove(&prove_config).await?;
prover.prove(&config).await?;
prover.close().await?;
let time_total = time_start.elapsed().as_millis();

View File

@@ -7,9 +7,10 @@ publish = false
[dependencies]
tlsn-harness-core = { workspace = true }
# tlsn-server-fixture = { workspace = true }
charming = { version = "0.6.0", features = ["ssr"] }
charming = { version = "0.5.1", features = ["ssr"] }
csv = "1.3.0"
clap = { workspace = true, features = ["derive", "env"] }
polars = { version = "0.44", features = ["csv", "lazy"] }
itertools = "0.14.0"
toml = { workspace = true }

View File

@@ -1,111 +0,0 @@
# TLSNotary Benchmark Plot Tool
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
## Usage
```bash
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
```
### Arguments
- `<TOML>` - Path to Bench.toml file defining benchmark structure
- `<CSV>...` - One or more CSV files with benchmark results
### Options
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
- Number of labels must match number of CSV files
- `--min-max-band` - Add min/max bands to plots showing variance
- `-h, --help` - Print help information
## Examples
### Single Dataset
```bash
tlsn-harness-plot bench.toml results.csv
```
Generates plots from a single benchmark run.
### Compare Two Runs
```bash
tlsn-harness-plot bench.toml before.csv after.csv \
--labels "Before Optimization" "After Optimization"
```
Overlays two datasets to compare performance improvements.
### Multiple Datasets
```bash
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
--labels "Native" "Browser" "WASM"
```
Compare three different runtime environments.
### With Min/Max Bands
```bash
tlsn-harness-plot bench.toml run1.csv run2.csv \
--labels "Config A" "Config B" \
--min-max-band
```
Shows variance ranges for each dataset.
## Output Files
The tool generates two files per benchmark group:
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
- `<output>.svg` - Static SVG image for documentation
Default output filenames:
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
## Plot Format
Each dataset displays:
- **Solid line** - Total runtime (preprocessing + online phase)
- **Dashed line** - Online phase only
- **Shaded area** (optional) - Min/max variance bands
Different datasets automatically use distinct colors for easy comparison.
## CSV Format
Expected columns in each CSV file:
- `group` - Benchmark group name (must match TOML)
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
- `latency` - Network latency in ms (for latency plots)
- `time_preprocess` - Preprocessing time in ms
- `time_online` - Online phase time in ms
- `time_total` - Total runtime in ms
## TOML Format
The benchmark TOML file defines groups with either:
```toml
[[group]]
name = "my_benchmark"
protocol_latency = 50 # Fixed latency for bandwidth plots
# OR
bandwidth = 10000 # Fixed bandwidth for latency plots
```
All datasets must use the same TOML file to ensure consistent benchmark structure.
## Tips
- Use descriptive labels to make plots self-documenting
- Keep CSV files from the same benchmark configuration for valid comparisons
- Min/max bands are useful for showing stability but can clutter plots with many datasets
- Interactive HTML plots support zooming and hovering for detailed values

View File

@@ -1,18 +1,17 @@
use std::f32;
use charming::{
Chart, HtmlRenderer, ImageRenderer,
Chart, HtmlRenderer,
component::{Axis, Legend, Title},
element::{
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
Trigger,
},
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
series::Line,
theme::Theme,
};
use clap::Parser;
use harness_core::bench::BenchItems;
use polars::prelude::*;
use harness_core::bench::{BenchItems, Measurement};
use itertools::Itertools;
const THEME: Theme = Theme::Default;
#[derive(Parser, Debug)]
#[command(author, version, about)]
@@ -20,131 +19,72 @@ struct Cli {
/// Path to the Bench.toml file with benchmark spec
toml: String,
/// Paths to CSV files with benchmark results (one or more)
csv: Vec<String>,
/// Path to the CSV file with benchmark results
csv: String,
/// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
#[arg(short, long, num_args = 0..)]
labels: Vec<String>,
/// Prover kind: native or browser
#[arg(short, long, value_enum, default_value = "native")]
prover_kind: ProverKind,
/// Add min/max bands to plots
#[arg(long, default_value_t = false)]
min_max_band: bool,
}
/// Which prover implementation produced the benchmark results
/// (selected via the `--prover-kind` CLI flag).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ProverKind {
/// Prover running as a native binary.
Native,
/// Prover running in a browser.
Browser,
}
// Human-readable name for the prover kind; used in plot subtitles.
impl std::fmt::Display for ProverKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ProverKind::Native => "Native",
            ProverKind::Browser => "Browser",
        };
        f.write_str(label)
    }
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
if cli.csv.is_empty() {
return Err("At least one CSV file must be provided".into());
}
// Generate labels if not provided
let labels: Vec<String> = if cli.labels.is_empty() {
cli.csv
.iter()
.enumerate()
.map(|(i, _)| format!("Dataset {}", i + 1))
.collect()
} else if cli.labels.len() != cli.csv.len() {
return Err(format!(
"Number of labels ({}) must match number of CSV files ({})",
cli.labels.len(),
cli.csv.len()
)
.into());
} else {
cli.labels.clone()
};
// Load all CSVs and add dataset label
let mut dfs = Vec::new();
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
let mut df = CsvReadOptions::default()
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
.finish()?;
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
df.with_column(label_series)?;
dfs.push(df);
}
// Combine all dataframes
let df = dfs
.into_iter()
.reduce(|acc, df| acc.vstack(&df).unwrap())
.unwrap();
let mut rdr = csv::Reader::from_path(&cli.csv)?;
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
let groups = items.group;
for group in groups {
// Determine which field varies in benches for this group
let benches_in_group: Vec<_> = items
.bench
.iter()
.filter(|b| b.group.as_deref() == Some(&group.name))
.collect();
// Prepare data for plotting.
let all_data: Vec<Measurement> = rdr
.deserialize::<Measurement>()
.collect::<Result<Vec<_>, _>>()?;
if benches_in_group.is_empty() {
continue;
for group in groups {
if group.protocol_latency.is_some() {
let latency = group.protocol_latency.unwrap();
plot_runtime_vs(
&all_data,
cli.min_max_band,
&group.name,
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency, {} mode", latency, cli.prover_kind),
"runtime_vs_bandwidth.html",
"Bandwidth (Mbps)",
)?;
}
// Check which field has varying values
let bandwidth_varies = benches_in_group
.windows(2)
.any(|w| w[0].bandwidth != w[1].bandwidth);
let latency_varies = benches_in_group
.windows(2)
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
let download_size_varies = benches_in_group
.windows(2)
.any(|w| w[0].download_size != w[1].download_size);
if download_size_varies {
let upload_size = group.upload_size.unwrap_or(1024);
if group.bandwidth.is_some() {
let bandwidth = group.bandwidth.unwrap();
plot_runtime_vs(
&df,
&labels,
&all_data,
cli.min_max_band,
&group.name,
"download_size",
1.0 / 1024.0, // bytes to KB
"Runtime vs Response Size",
format!("{} bytes upload size", upload_size),
"runtime_vs_download_size",
"Response Size (KB)",
true, // legend on left
)?;
} else if bandwidth_varies {
let latency = group.protocol_latency.unwrap_or(50);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"bandwidth",
1.0 / 1000.0, // Kbps to Mbps
"Runtime vs Bandwidth",
format!("{} ms Latency", latency),
"runtime_vs_bandwidth",
"Bandwidth (Mbps)",
false, // legend on right
)?;
} else if latency_varies {
let bandwidth = group.bandwidth.unwrap_or(1000);
plot_runtime_vs(
&df,
&labels,
cli.min_max_band,
&group.name,
"latency",
1.0,
|r| r.latency as f32,
"Runtime vs Latency",
format!("{} bps bandwidth", bandwidth),
"runtime_vs_latency",
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
"runtime_vs_latency.html",
"Latency (ms)",
true, // legend on left
)?;
}
}
@@ -152,52 +92,84 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
/// Min/mean/max summary of one timing metric across all runs at a
/// single x-axis value.
struct DataPoint {
min: f32,
mean: f32,
max: f32,
}
/// Aggregated timing statistics for one x-axis value, in seconds
/// (the raw measurements are milliseconds and are divided by 1000).
struct Points {
/// Preprocessing phase time statistics.
preprocess: DataPoint,
/// Online phase time statistics.
online: DataPoint,
/// Total runtime statistics.
total: DataPoint,
}
#[allow(clippy::too_many_arguments)]
fn plot_runtime_vs(
df: &DataFrame,
labels: &[String],
fn plot_runtime_vs<Fx>(
all_data: &[Measurement],
show_min_max: bool,
group: &str,
x_col: &str,
x_scale: f32,
x_value: Fx,
title: &str,
subtitle: String,
output_file: &str,
x_axis_label: &str,
legend_left: bool,
) -> Result<Chart, Box<dyn std::error::Error>> {
let stats_df = df
.clone()
.lazy()
.filter(col("group").eq(lit(group)))
.with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
.with_columns([
(col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
(col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
])
.group_by([col("x"), col("dataset_label")])
.agg([
col("preprocess").min().alias("preprocess_min"),
col("preprocess").mean().alias("preprocess_mean"),
col("preprocess").max().alias("preprocess_max"),
col("online").min().alias("online_min"),
col("online").mean().alias("online_mean"),
col("online").max().alias("online_max"),
col("total").min().alias("total_min"),
col("total").mean().alias("total_mean"),
col("total").max().alias("total_max"),
])
.sort(["dataset_label", "x"], Default::default())
.collect()?;
// Build legend entries
let mut legend_data = Vec::new();
for label in labels {
legend_data.push(format!("Total Mean ({})", label));
legend_data.push(format!("Online Mean ({})", label));
) -> Result<Chart, Box<dyn std::error::Error>>
where
Fx: Fn(&Measurement) -> f32,
{
/// Computes the min/mean/max of `values`.
///
/// An empty slice yields an all-zero `DataPoint`. Previously `min`/`max`
/// fell back to 0.0 via `unwrap_or_default()` while the mean became
/// 0.0 / 0.0 = NaN, leaving the three fields mutually inconsistent.
fn data_point(values: &[f32]) -> DataPoint {
    if values.is_empty() {
        return DataPoint {
            min: 0.0,
            mean: 0.0,
            max: 0.0,
        };
    }
    let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
    // The slice is non-empty here, so `reduce` always returns `Some`.
    let max = values.iter().copied().reduce(f32::max).expect("non-empty");
    let min = values.iter().copied().reduce(f32::min).expect("non-empty");
    DataPoint { min, mean, max }
}
let stats: Vec<(f32, Points)> = all_data
.iter()
.filter(|r| r.group.as_deref() == Some(group))
.map(|r| {
(
x_value(r),
r.time_preprocess as f32 / 1000.0, // ms to s
r.time_online as f32 / 1000.0,
r.time_total as f32 / 1000.0,
)
})
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.chunk_by(|entry| entry.0)
.into_iter()
.map(|(x, group)| {
let group_vec: Vec<_> = group.collect();
let preprocess = data_point(
&group_vec
.iter()
.map(|(_, t, _, _)| *t)
.collect::<Vec<f32>>(),
);
let online = data_point(
&group_vec
.iter()
.map(|(_, _, t, _)| *t)
.collect::<Vec<f32>>(),
);
let total = data_point(
&group_vec
.iter()
.map(|(_, _, _, t)| *t)
.collect::<Vec<f32>>(),
);
(
x,
Points {
preprocess,
online,
total,
},
)
})
.collect();
let mut chart = Chart::new()
.title(
Title::new()
@@ -207,6 +179,14 @@ fn plot_runtime_vs(
.subtext_style(TextStyle::new().font_size(16)),
)
.tooltip(Tooltip::new().trigger(Trigger::Axis))
.legend(
Legend::new()
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
.top("80")
.right("110")
.orient(Orient::Vertical)
.item_gap(10),
)
.x_axis(
Axis::new()
.name(x_axis_label)
@@ -225,156 +205,73 @@ fn plot_runtime_vs(
.name_text_style(TextStyle::new().font_size(21)),
);
// Add legend with conditional positioning
let legend = Legend::new()
.data(legend_data)
.top("80")
.orient(Orient::Vertical)
.item_gap(10);
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
let legend = if legend_left {
legend.left("110")
} else {
legend.right("110")
};
chart = chart.legend(legend);
// Define colors for each dataset
let colors = vec![
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
];
for (idx, label) in labels.iter().enumerate() {
let color = colors.get(idx % colors.len()).unwrap();
// Total time - solid line
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Total Mean ({})", label),
"total_mean",
false,
color,
)?;
// Online time - dashed line (same color as total)
chart = add_dataset_series(
&chart,
&stats_df,
label,
&format!("Online Mean ({})", label),
"online_mean",
true,
color,
)?;
if show_min_max {
chart = add_dataset_min_max_band(
&chart,
&stats_df,
label,
&format!("Total Min/Max ({})", label),
"total",
color,
)?;
}
if show_min_max {
chart = add_min_max_band(
chart,
&stats,
"Preprocess Min/Max",
|p| &p.preprocess,
"#ccc",
);
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
}
// Save the chart as HTML file (no theme)
// Save the chart as HTML file.
HtmlRenderer::new(title, 1000, 800)
.save(&chart, &format!("{}.html", output_file))
.unwrap();
// Save SVG with default theme
ImageRenderer::new(1000, 800)
.theme(Theme::Default)
.save(&chart, &format!("{}.svg", output_file))
.unwrap();
// Save SVG with dark theme
ImageRenderer::new(1000, 800)
.theme(Theme::Dark)
.save(&chart, &format!("{}_dark.svg", output_file))
.theme(THEME)
.save(&chart, output_file)
.unwrap();
Ok(chart)
}
fn add_dataset_series(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
series_name: &str,
col_name: &str,
dashed: bool,
color: &str,
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let y = filtered.column(col_name)?.f32()?;
let data: Vec<Vec<f32>> = x
.into_iter()
.zip(y.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let mut line = Line::new()
.name(series_name)
.data(data)
.symbol_size(6)
.item_style(ItemStyle::new().color(color));
let mut line_style = LineStyle::new();
if dashed {
line_style = line_style.type_(LineStyleType::Dashed);
}
line = line.line_style(line_style.color(color));
Ok(chart.clone().series(line))
}
fn add_dataset_min_max_band(
chart: &Chart,
df: &DataFrame,
dataset_label: &str,
fn add_mean_series(
chart: Chart,
stats: &[(f32, Points)],
name: &str,
col_prefix: &str,
color: &str,
) -> Result<Chart, Box<dyn std::error::Error>> {
// Filter for specific dataset
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
let filtered = df.filter(&mask)?;
let x = filtered.column("x")?.f32()?;
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
let max_data: Vec<Vec<f32>> = x
.into_iter()
.zip(max_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.collect();
let min_data: Vec<Vec<f32>> = x
.into_iter()
.zip(min_col.into_iter())
.filter_map(|(x, y)| Some(vec![x?, y?]))
.rev()
.collect();
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
Ok(chart.clone().series(
extract: impl Fn(&Points) -> f32,
) -> Chart {
chart.series(
Line::new()
.name(name)
.data(data)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points)])
.collect(),
)
.symbol_size(6),
)
}
fn add_min_max_band(
chart: Chart,
stats: &[(f32, Points)],
name: &str,
extract: impl Fn(&Points) -> &DataPoint,
color: &str,
) -> Chart {
chart.series(
Line::new()
.name(name)
.data(
stats
.iter()
.map(|(x, points)| vec![*x, extract(points).max])
.chain(
stats
.iter()
.rev()
.map(|(x, points)| vec![*x, extract(points).min]),
)
.collect(),
)
.show_symbol(false)
.line_style(LineStyle::new().opacity(0.0))
.area_style(AreaStyle::new().opacity(0.3).color(color)),
))
)
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -32,13 +32,18 @@ use crate::debug_prelude::*;
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Target {
#[default]
Native,
Browser,
}
/// `Target::Native` is the default target.
impl Default for Target {
fn default() -> Self {
Self::Native
}
}
struct Runner {
network: Network,
server_fixture: ServerFixture,

View File

@@ -1,25 +0,0 @@
#### Bandwidth ####
[[group]]
name = "bandwidth"
protocol_latency = 25
[[bench]]
group = "bandwidth"
bandwidth = 10
[[bench]]
group = "bandwidth"
bandwidth = 50
[[bench]]
group = "bandwidth"
bandwidth = 100
[[bench]]
group = "bandwidth"
bandwidth = 250
[[bench]]
group = "bandwidth"
bandwidth = 1000

View File

@@ -1,37 +0,0 @@
[[group]]
name = "download_size"
protocol_latency = 10
bandwidth = 200
upload-size = 2048
[[bench]]
group = "download_size"
download-size = 1024
[[bench]]
group = "download_size"
download-size = 2048
[[bench]]
group = "download_size"
download-size = 4096
[[bench]]
group = "download_size"
download-size = 8192
[[bench]]
group = "download_size"
download-size = 16384
[[bench]]
group = "download_size"
download-size = 32768
[[bench]]
group = "download_size"
download-size = 65536
[[bench]]
group = "download_size"
download-size = 131072

View File

@@ -1,25 +0,0 @@
#### Latency ####
[[group]]
name = "latency"
bandwidth = 1000
[[bench]]
group = "latency"
protocol_latency = 10
[[bench]]
group = "latency"
protocol_latency = 25
[[bench]]
group = "latency"
protocol_latency = 50
[[bench]]
group = "latency"
protocol_latency = 100
[[bench]]
group = "latency"
protocol_latency = 200

View File

@@ -24,7 +24,7 @@ use std::{
};
#[cfg(feature = "tracing")]
use tracing::{debug, debug_span, trace, warn, Instrument};
use tracing::{debug, debug_span, error, trace, warn, Instrument};
use tls_client::ClientConnection;

View File

@@ -1,6 +1,5 @@
use super::{Backend, BackendError};
use crate::{DecryptMode, EncryptMode, Error};
#[allow(deprecated)]
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead, Payload},
Aes128Gcm,
@@ -508,7 +507,6 @@ impl Encrypter {
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(explicit_nonce);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap();
// ciphertext will have the MAC appended
@@ -570,7 +568,6 @@ impl Decrypter {
let mut nonce = [0u8; 12];
nonce[..4].copy_from_slice(&self.write_iv);
nonce[4..].copy_from_slice(&m.payload.0[0..8]);
#[allow(deprecated)]
let nonce = GenericArray::from_slice(&nonce);
let plaintext = cipher
.decrypt(nonce, aes_payload)

View File

@@ -1,7 +1,7 @@
use std::ops::Range;
use mpz_memory_core::{Vector, binary::U8};
use rangeset::set::RangeSet;
use rangeset::RangeSet;
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct RangeMap<T> {
@@ -77,7 +77,7 @@ where
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
let mut map = Vec::new();
for idx in idx.iter() {
for idx in idx.iter_ranges() {
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
Ok(i) => i,
Err(0) => return None,

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
ProverOutput,
config::prove::ProveConfig,

View File

@@ -12,7 +12,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Call, CallableExt, Vm};
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
use rangeset::{Difference, RangeSet, Union};
use tlsn_core::transcript::Record;
use crate::transcript_internal::ReferenceMap;
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal).into_set()
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
}
} else {
let private = commit.difference(reveal).into_set();
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
commit.clone()
} else {
// The plaintext is only partially revealed, so we need to authenticate in ZK.
commit.union(reveal).into_set()
commit.union(reveal)
};
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
ciphertext,
})
} else {
let private = commit.difference(reveal).into_set();
let private = commit.difference(reveal);
for (_, slice) in plaintext_refs
.index(&private)
.expect("all ranges are allocated")
@@ -175,13 +175,15 @@ fn alloc_plaintext(
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = plaintext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_ciphertext<'a>(
@@ -210,13 +212,15 @@ fn alloc_ciphertext<'a>(
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
let mut pos = 0;
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
})))
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
move |range| {
let chunk = ciphertext
.get(pos..pos + range.len())
.expect("length was checked");
pos += range.len();
(range.start, chunk)
},
)))
}
fn alloc_keystream<'a>(
@@ -229,7 +233,7 @@ fn alloc_keystream<'a>(
let mut keystream = Vec::new();
let mut pos = 0;
let mut range_iter = ranges.iter();
let mut range_iter = ranges.iter_ranges();
let mut current_range = range_iter.next();
for record in records {
let mut explicit_nonce = None;
@@ -504,7 +508,7 @@ mod tests {
for record in records {
let mut record_keystream = vec![0u8; record.len];
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
for mut range in ranges.iter() {
for mut range in ranges.iter_ranges() {
range.start = range.start.max(pos);
range.end = range.end.min(pos + record.len);
if range.start < range.end {

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
correlated::{Delta, Key, Mac},
};
use rand::Rng;
use rangeset::set::RangeSet;
use rangeset::RangeSet;
use serde::{Deserialize, Serialize};
use serio::{SinkExt, stream::IoStreamExt};
use tlsn_core::{

View File

@@ -9,7 +9,7 @@ use mpz_memory_core::{
binary::{Binary, U8},
};
use mpz_vm_core::{Vm, VmError, prelude::*};
use rangeset::set::RangeSet;
use rangeset::RangeSet;
use tlsn_core::{
hash::{Blinder, Hash, HashAlgId, TypedHash},
transcript::{
@@ -155,7 +155,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
}
@@ -176,7 +176,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;
@@ -201,7 +201,7 @@ fn hash_commit_inner(
Direction::Received => &refs.recv,
};
for range in idx.iter() {
for range in idx.iter_ranges() {
hasher
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
.map_err(HashCommitError::hasher)?;

View File

@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
use mpz_common::Context;
use mpz_memory_core::binary::Binary;
use mpz_vm_core::Vm;
use rangeset::set::RangeSet;
use rangeset::{RangeSet, UnionMut};
use tlsn_core::{
VerifierOutput,
config::prove::ProveRequest,

View File

@@ -1,5 +1,5 @@
use futures::{AsyncReadExt, AsyncWriteExt};
use rangeset::set::RangeSet;
use rangeset::RangeSet;
use tlsn::{
config::{
prove::ProveConfig,
@@ -51,11 +51,19 @@ async fn test() {
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
assert!(!partial_transcript.is_complete());
assert_eq!(
partial_transcript.sent_authed().iter().next().unwrap(),
partial_transcript
.sent_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);
assert_eq!(
partial_transcript.received_authed().iter().next().unwrap(),
partial_transcript
.received_authed()
.iter_ranges()
.next()
.unwrap(),
0..10
);

View File

@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
Self {
sent: value.sent_unsafe().to_vec(),
sent_authed: value.sent_authed().iter().collect(),
sent_authed: value.sent_authed().iter_ranges().collect(),
recv: value.received_unsafe().to_vec(),
recv_authed: value.received_authed().iter().collect(),
recv_authed: value.received_authed().iter_ranges().collect(),
}
}
}