mirror of
https://github.com/tlsnotary/tlsn.git
synced 2026-01-12 08:08:29 -05:00
Compare commits
12 Commits
feat/pest_
...
plot_py
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b76775fc7c | ||
|
|
72041d1f07 | ||
|
|
ac1df8fc75 | ||
|
|
3cb7c5c0b4 | ||
|
|
b41d678829 | ||
|
|
1ebefa27d8 | ||
|
|
4fe5c1defd | ||
|
|
0e8e547300 | ||
|
|
22cc88907a | ||
|
|
cec4756e0e | ||
|
|
0919e1f2b3 | ||
|
|
43b9f57e1f |
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@@ -21,7 +21,8 @@ env:
|
||||
# - https://github.com/privacy-ethereum/mpz/issues/178
|
||||
# 32 seems to be big enough for the foreseeable future
|
||||
RAYON_NUM_THREADS: 32
|
||||
RUST_VERSION: 1.90.0
|
||||
RUST_VERSION: 1.92.0
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
jobs:
|
||||
clippy:
|
||||
|
||||
2531
Cargo.lock
generated
2531
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
35
Cargo.toml
35
Cargo.toml
@@ -66,26 +66,27 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
|
||||
tlsn-wasm = { path = "crates/wasm" }
|
||||
tlsn = { path = "crates/tlsn" }
|
||||
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
|
||||
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "9c343f8" }
|
||||
|
||||
rangeset = { version = "0.2" }
|
||||
rangeset = { version = "0.4" }
|
||||
serio = { version = "0.2" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
uid-mux = { version = "0.2" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
|
||||
websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6f1a934" }
|
||||
|
||||
aead = { version = "0.4" }
|
||||
aes = { version = "0.8" }
|
||||
|
||||
@@ -15,7 +15,7 @@ use mpz_vm_core::{
|
||||
memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
|
||||
Call, Callable, Execute, Vm, VmError,
|
||||
};
|
||||
use rangeset::{Difference, RangeSet, UnionMut};
|
||||
use rangeset::{ops::Set, set::RangeSet};
|
||||
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};
|
||||
|
||||
type Error = DeapError;
|
||||
@@ -210,10 +210,12 @@ where
|
||||
}
|
||||
|
||||
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
|
||||
let slice_range = slice.to_range();
|
||||
|
||||
// Follower's private inputs are not committed in the ZK VM until finalization.
|
||||
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
|
||||
let input_minus_follower = slice_range.difference(&self.follower_input_ranges);
|
||||
let mut zk = self.zk.try_lock().unwrap();
|
||||
for input in input_minus_follower.iter_ranges() {
|
||||
for input in input_minus_follower {
|
||||
zk.commit_raw(
|
||||
self.memory_map
|
||||
.try_get(Slice::from_range_unchecked(input))?,
|
||||
@@ -266,7 +268,7 @@ where
|
||||
mpc.mark_private_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
}
|
||||
@@ -282,7 +284,7 @@ where
|
||||
mpc.mark_blind_raw(slice)?;
|
||||
// Follower's private inputs will become public during finalization.
|
||||
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
|
||||
self.follower_input_ranges.union_mut(&slice.to_range());
|
||||
self.follower_input_ranges.union_mut(slice.to_range());
|
||||
self.follower_inputs.push(slice);
|
||||
}
|
||||
Role::Follower => {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_vm_core::{memory::Slice, VmError};
|
||||
use rangeset::Subset;
|
||||
use rangeset::ops::Set;
|
||||
|
||||
/// A mapping between the memories of the MPC and ZK VMs.
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
@@ -27,12 +27,6 @@ tlsn-data-fixtures = { workspace = true, optional = true }
|
||||
tlsn-tls-core = { workspace = true, features = ["serde"] }
|
||||
tlsn-utils = { workspace = true }
|
||||
rangeset = { workspace = true, features = ["serde"] }
|
||||
mpz-circuits = { workspace = true }
|
||||
pest = "*"
|
||||
pest_derive = "*"
|
||||
pest_meta = "*"
|
||||
|
||||
|
||||
|
||||
aead = { workspace = true, features = ["alloc"], optional = true }
|
||||
aes-gcm = { workspace = true, optional = true }
|
||||
@@ -65,5 +59,7 @@ generic-array = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rstest = { workspace = true }
|
||||
tlsn-core = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-attestation = { workspace = true, features = ["fixtures"] }
|
||||
tlsn-data-fixtures = { workspace = true }
|
||||
webpki-root-certs = { workspace = true }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Proving configuration.
|
||||
|
||||
use rangeset::{RangeSet, ToRangeSet, UnionMut};
|
||||
use rangeset::set::{RangeSet, ToRangeSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::transcript::{Direction, Transcript, TranscriptCommitConfig, TranscriptCommitRequest};
|
||||
|
||||
@@ -185,22 +185,17 @@ impl MpcTlsConfigBuilder {
|
||||
///
|
||||
/// Provides optimization options to adapt the protocol to different network
|
||||
/// situations.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
|
||||
pub enum NetworkSetting {
|
||||
/// Reduces network round-trips at the expense of consuming more network
|
||||
/// bandwidth.
|
||||
Bandwidth,
|
||||
/// Reduces network bandwidth utilization at the expense of more network
|
||||
/// round-trips.
|
||||
#[default]
|
||||
Latency,
|
||||
}
|
||||
|
||||
impl Default for NetworkSetting {
|
||||
fn default() -> Self {
|
||||
Self::Latency
|
||||
}
|
||||
}
|
||||
|
||||
/// Error for [`MpcTlsConfig`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
|
||||
pub(crate) struct FmtRangeSet<'a>(pub &'a RangeSet<usize>);
|
||||
|
||||
impl<'a> std::fmt::Display for FmtRangeSet<'a> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("{")?;
|
||||
for range in self.0.iter_ranges() {
|
||||
for range in self.0.iter() {
|
||||
write!(f, "{}..{}", range.start, range.end)?;
|
||||
if range.end < self.0.end().unwrap_or(0) {
|
||||
f.write_str(", ")?;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use aead::Payload as AeadPayload;
|
||||
use aes_gcm::{aead::Aead, Aes128Gcm, NewAead};
|
||||
#[allow(deprecated)]
|
||||
use generic_array::GenericArray;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use tls_core::msgs::{
|
||||
@@ -180,6 +181,7 @@ fn aes_gcm_encrypt(
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&iv);
|
||||
nonce[4..].copy_from_slice(&explicit_nonce);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let cipher = Aes128Gcm::new_from_slice(&key).unwrap();
|
||||
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
use crate::predicates::Pred;
|
||||
use pest::{
|
||||
iterators::{Pair, Pairs},
|
||||
Parser,
|
||||
};
|
||||
use pest_derive::Parser;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[grammar = "expr.pest"]
|
||||
struct ExprParser;
|
||||
|
||||
fn parse_expr(input: &str) -> Result<Pred, pest::error::Error<Rule>> {
|
||||
let mut pairs = ExprParser::parse(Rule::expr, input)?;
|
||||
Ok(build_expr(pairs.next().unwrap()))
|
||||
}
|
||||
|
||||
fn build_expr(pair: Pair<Rule>) -> Pred {
|
||||
match pair.as_rule() {
|
||||
Rule::expr | Rule::or_expr => build_left_assoc(pair.into_inner(), Rule::and_expr, Pred::Or),
|
||||
Rule::and_expr => build_left_assoc(pair.into_inner(), Rule::not_expr, Pred::And),
|
||||
Rule::not_expr => {
|
||||
// NOT* cmp
|
||||
let mut inner = pair.into_inner(); // possibly multiple NOT then a cmp
|
||||
// Count NOTs, then parse cmp
|
||||
let mut not_count = 0;
|
||||
let mut rest = Vec::new();
|
||||
for p in inner {
|
||||
match p.as_rule() {
|
||||
Rule::NOT => not_count += 1,
|
||||
_ => {
|
||||
rest.push(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut node = build_cmp(rest.into_iter().next().expect("cmp missing"));
|
||||
if not_count % 2 == 1 {
|
||||
node = Pred::Not(Box::new(node));
|
||||
}
|
||||
node
|
||||
}
|
||||
Rule::cmp => build_cmp(pair),
|
||||
Rule::primary => build_expr(pair.into_inner().next().unwrap()),
|
||||
Rule::paren => build_expr(pair.into_inner().next().unwrap()),
|
||||
_ => unreachable!("unexpected rule: {:?}", pair.as_rule()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_left_assoc(
|
||||
mut inner: Pairs<Rule>,
|
||||
unit_rule: Rule,
|
||||
mk_node: impl Fn(Vec<Pred>) -> Pred,
|
||||
) -> Pred {
|
||||
// pattern: unit (OP unit)*
|
||||
let mut nodes = Vec::new();
|
||||
// First unit
|
||||
if let Some(first) = inner.next() {
|
||||
assert_eq!(first.as_rule(), unit_rule);
|
||||
nodes.push(build_expr(first));
|
||||
}
|
||||
// Remaining are: OP unit pairs; we only collect the units and wrap later.
|
||||
while let Some(next) = inner.next() {
|
||||
// next is the operator token pair (AND/OR), skip it
|
||||
// then the unit:
|
||||
if let Some(unit) = inner.next() {
|
||||
assert_eq!(unit.as_rule(), unit_rule);
|
||||
nodes.push(build_expr(unit));
|
||||
}
|
||||
}
|
||||
if nodes.len() == 1 {
|
||||
nodes.pop().unwrap()
|
||||
} else {
|
||||
mk_node(nodes)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cmp(pair: Pair<Rule>) -> Pred {
|
||||
// cmp: primary (cmp_op primary)?
|
||||
let mut inner = pair.into_inner();
|
||||
let lhs = inner.next().unwrap();
|
||||
let lhs_term = parse_term(lhs);
|
||||
if let Some(op_pair) = inner.next() {
|
||||
let op = match op_pair.as_str() {
|
||||
"==" => CmpOp::Eq,
|
||||
"!=" => CmpOp::Ne,
|
||||
"<" => CmpOp::Lt,
|
||||
"<=" => CmpOp::Lte,
|
||||
">" => CmpOp::Gt,
|
||||
">=" => CmpOp::Gte,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let rhs = parse_term(inner.next().unwrap());
|
||||
// Map to your Atom constraint form (LHS must be x[idx]):
|
||||
let (index, rhs_val) = match (lhs_term, rhs) {
|
||||
(Term::Idx(i), Term::Const(c)) => (i, Rhs::Const(c)),
|
||||
(Term::Idx(i1), Term::Idx(i2)) => (i1, Rhs::Idx(i2)),
|
||||
// If you want to allow const OP idx or const OP const, handle here (flip, etc.)
|
||||
other => panic!("unsupported comparison pattern: {:?}", other),
|
||||
};
|
||||
Pred::Atom(Atom {
|
||||
index,
|
||||
op,
|
||||
rhs: rhs_val,
|
||||
})
|
||||
} else {
|
||||
// A bare primary is treated as a boolean atom; you can decide policy.
|
||||
// Here we treat "x[i]" as (x[i] != 0) and const as (const != 0).
|
||||
match lhs_term {
|
||||
Term::Idx(i) => Pred::Atom(Atom {
|
||||
index: i,
|
||||
op: CmpOp::Ne,
|
||||
rhs: Rhs::Const(0),
|
||||
}),
|
||||
Term::Const(c) => {
|
||||
if c != 0 {
|
||||
Pred::Or(vec![])
|
||||
} else {
|
||||
Pred::And(vec![])
|
||||
} // true/false constants if you add Const
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Term {
|
||||
Idx(usize),
|
||||
Const(u8),
|
||||
}
|
||||
|
||||
fn parse_term(pair: Pair<Rule>) -> Term {
|
||||
match pair.as_rule() {
|
||||
Rule::atom => parse_term(pair.into_inner().next().unwrap()),
|
||||
Rule::byte_idx => {
|
||||
// "x" "[" number "]"
|
||||
let mut i = pair.into_inner();
|
||||
let num = i.find(|p| p.as_rule() == Rule::number).unwrap();
|
||||
Term::Idx(num.as_str().parse::<usize>().unwrap())
|
||||
}
|
||||
Rule::byte_const => {
|
||||
let n = pair.into_inner().next().unwrap(); // number
|
||||
Term::Const(n.as_str().parse::<u8>().unwrap())
|
||||
}
|
||||
Rule::paren => parse_term(pair.into_inner().next().unwrap()),
|
||||
Rule::primary => parse_term(pair.into_inner().next().unwrap()),
|
||||
_ => unreachable!("term {:?}", pair.as_rule()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_and() {
|
||||
let pred = parse_expr("x[100] < x[300] && x[200] == 2 || ! (x[5] >= 57)").unwrap();
|
||||
// `pred` is a Pred::Or with an And on the left and a Not on the right,
|
||||
// with Atoms inside.
|
||||
}
|
||||
}
|
||||
@@ -296,14 +296,14 @@ mod sha2 {
|
||||
fn hash(&self, data: &[u8]) -> super::Hash {
|
||||
let mut hasher = ::sha2::Sha256::default();
|
||||
hasher.update(data);
|
||||
super::Hash::new(hasher.finalize().as_slice())
|
||||
super::Hash::new(hasher.finalize().as_ref())
|
||||
}
|
||||
|
||||
fn hash_prefixed(&self, prefix: &[u8], data: &[u8]) -> super::Hash {
|
||||
let mut hasher = ::sha2::Sha256::default();
|
||||
hasher.update(prefix);
|
||||
hasher.update(data);
|
||||
super::Hash::new(hasher.finalize().as_slice())
|
||||
super::Hash::new(hasher.finalize().as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
// pest. The Elegant Parser
|
||||
// Copyright (c) 2018 Dragoș Tiselice
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0
|
||||
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
|
||||
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||
// option. All files in the project carrying such notice may not be copied,
|
||||
// modified, or distributed except according to those terms.
|
||||
|
||||
//! A parser for JSON file.
|
||||
//!
|
||||
//! And this is a example for JSON parser.
|
||||
json = _{ SOI ~ value ~ eoi }
|
||||
eoi = _{ !ANY }
|
||||
|
||||
/// Matches object, e.g.: `{ "foo": "bar" }`
|
||||
/// Foobar
|
||||
object = { "{" ~ pair ~ (pair)* ~ "}" | "{" ~ "}" }
|
||||
pair = { quoted_string ~ ":" ~ value ~ (",")? }
|
||||
|
||||
array = { "[" ~ value ~ ("," ~ value)* ~ "]" | "[" ~ "]" }
|
||||
|
||||
//////////////////////
|
||||
/// Matches value, e.g.: `"foo"`, `42`, `true`, `null`, `[]`, `{}`.
|
||||
//////////////////////
|
||||
value = _{ quoted_string | number | object | array | bool | null }
|
||||
|
||||
quoted_string = _{ "\"" ~ string ~ "\"" }
|
||||
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ string)? }
|
||||
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
|
||||
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }
|
||||
|
||||
number = @{ "-"? ~ int ~ ("." ~ ASCII_DIGIT+ ~ exp? | exp)? }
|
||||
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }
|
||||
exp = @{ ("E" | "e") ~ ("+" | "-")? ~ ASCII_DIGIT+ }
|
||||
|
||||
bool = { "true" | "false" }
|
||||
|
||||
null = { "null" }
|
||||
|
||||
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
|
||||
@@ -1,760 +0,0 @@
|
||||
//!
|
||||
use crate::predicates::Pred;
|
||||
use pest::{
|
||||
iterators::{Pair, Pairs},
|
||||
Parser,
|
||||
};
|
||||
use pest_derive::Parser;
|
||||
use pest_meta::{ast, parser as meta, parser::consume_rules, validator};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use core::panic;
|
||||
use std::cmp::{max, min};
|
||||
|
||||
use crate::{
|
||||
config::prove::ProveConfig,
|
||||
predicates::{eval_pred, is_unicode, Atom, CmpOp, Compiler, Rhs},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use mpz_circuits::{
|
||||
evaluate,
|
||||
ops::{all, any},
|
||||
};
|
||||
use pest_meta::ast::Expr;
|
||||
use rangeset::RangeSet;
|
||||
|
||||
const MAX_LEN: usize = 999_999;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Ex {
|
||||
RepEx(Rep),
|
||||
RepExactEx(RepExact),
|
||||
SeqEx(Seq),
|
||||
StrEx(Str),
|
||||
ChoiceEx(Choice),
|
||||
NegPredEx(NegPred),
|
||||
OptEx(Opt),
|
||||
// An expression which must be replaced with a copy of the rule.
|
||||
NestedEx,
|
||||
#[allow(non_camel_case_types)]
|
||||
ASCII_NONZERO_DIGIT,
|
||||
#[allow(non_camel_case_types)]
|
||||
ASCII_DIGIT,
|
||||
#[allow(non_camel_case_types)]
|
||||
// A single Unicode character
|
||||
ANY,
|
||||
#[allow(non_camel_case_types)]
|
||||
ASCII_HEX_DIGIT,
|
||||
}
|
||||
|
||||
impl Ex {
|
||||
fn min_len(&self) -> usize {
|
||||
match self {
|
||||
Ex::RepEx(e) => 0,
|
||||
Ex::RepExactEx(e) => e.0 .0.min_len() * e.0 .1 as usize,
|
||||
Ex::StrEx(e) => e.0.len(),
|
||||
Ex::SeqEx(e) => e.0.min_len() + e.1.min_len(),
|
||||
Ex::ChoiceEx(e) => min(e.0.min_len(), e.1.min_len()),
|
||||
Ex::NegPredEx(e) => 0,
|
||||
Ex::ASCII_NONZERO_DIGIT => 1,
|
||||
Ex::ASCII_DIGIT => 1,
|
||||
Ex::ANY => 1,
|
||||
Ex::ASCII_HEX_DIGIT => 1,
|
||||
Ex::OptEx(e) => 0,
|
||||
Ex::NestedEx => 0,
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn max_len(&self) -> usize {
|
||||
match self {
|
||||
Ex::RepEx(e) => MAX_LEN,
|
||||
Ex::RepExactEx(e) => e.0 .0.max_len() * e.0 .1 as usize,
|
||||
Ex::StrEx(e) => e.0.len(),
|
||||
Ex::SeqEx(e) => e.0.max_len() + e.1.max_len(),
|
||||
Ex::ChoiceEx(e) => max(e.0.max_len(), e.1.max_len()),
|
||||
Ex::NegPredEx(e) => 0,
|
||||
Ex::ASCII_NONZERO_DIGIT => 1,
|
||||
Ex::ASCII_DIGIT => 1,
|
||||
Ex::ANY => 4,
|
||||
Ex::ASCII_HEX_DIGIT => 1,
|
||||
Ex::OptEx(e) => e.0.max_len(),
|
||||
Ex::NestedEx => 0,
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Rep(Box<Ex>);
|
||||
#[derive(Debug, Clone)]
|
||||
struct RepExact((Box<Ex>, u32));
|
||||
#[derive(Debug, Clone)]
|
||||
struct Str(String);
|
||||
#[derive(Debug, Clone)]
|
||||
struct Seq(Box<Ex>, Box<Ex>);
|
||||
#[derive(Debug, Clone)]
|
||||
struct Choice(Box<Ex>, Box<Ex>);
|
||||
#[derive(Debug, Clone)]
|
||||
struct NegPred(Box<Ex>);
|
||||
#[derive(Debug, Clone)]
|
||||
struct Opt(Box<Ex>);
|
||||
|
||||
struct Rule {
|
||||
name: String,
|
||||
pub ex: Ex,
|
||||
}
|
||||
|
||||
/// Builds the rules, returning the final expression.
|
||||
fn build_rules(ast_rules: &[ast::Rule]) -> Ex {
|
||||
let mut rules = Vec::new();
|
||||
// build from the bottom up
|
||||
let iter = ast_rules.iter().rev();
|
||||
for r in iter {
|
||||
println!("building rule with name {:?}", r.name);
|
||||
|
||||
let ex = build_expr(&r.expr, &rules, &r.name, false);
|
||||
|
||||
// TODO deal with recursive rules
|
||||
|
||||
rules.push(Rule {
|
||||
name: r.name.clone(),
|
||||
ex,
|
||||
});
|
||||
}
|
||||
let ex = rules.last().unwrap().ex.clone();
|
||||
ex
|
||||
}
|
||||
|
||||
/// Builds expression from pest expression.
|
||||
/// passes in current rule's name to deal with recursion.
|
||||
/// depth is used to prevent infinite recursion.
|
||||
fn build_expr(exp: &Expr, rules: &[Rule], this_name: &String, is_nested: bool) -> Ex {
|
||||
match exp {
|
||||
Expr::Rep(exp) => {
|
||||
Ex::RepEx(Rep(Box::new(build_expr(exp, rules, this_name, is_nested))))
|
||||
}
|
||||
Expr::RepExact(exp, count) => Ex::RepExactEx(RepExact((
|
||||
Box::new(build_expr(exp, rules, this_name, is_nested)),
|
||||
*count,
|
||||
))),
|
||||
Expr::Str(str) => Ex::StrEx(Str(str.clone())),
|
||||
Expr::NegPred(exp) => Ex::NegPredEx(NegPred(Box::new(build_expr(
|
||||
exp, rules, this_name, is_nested,
|
||||
)))),
|
||||
Expr::Seq(a, b) => {
|
||||
//
|
||||
let a = build_expr(a, rules, this_name, is_nested);
|
||||
Ex::SeqEx(Seq(
|
||||
Box::new(a),
|
||||
Box::new(build_expr(b, rules, this_name, is_nested)),
|
||||
))
|
||||
}
|
||||
Expr::Choice(a, b) => Ex::ChoiceEx(Choice(
|
||||
Box::new(build_expr(a, rules, this_name, is_nested)),
|
||||
Box::new(build_expr(b, rules, this_name, is_nested)),
|
||||
)),
|
||||
Expr::Opt(exp) => {
|
||||
Ex::OptEx(Opt(Box::new(build_expr(exp, rules, this_name, is_nested))))
|
||||
}
|
||||
Expr::Ident(ident) => {
|
||||
let ex = match ident.as_str() {
|
||||
"ASCII_NONZERO_DIGIT" => Ex::ASCII_NONZERO_DIGIT,
|
||||
"ASCII_DIGIT" => Ex::ASCII_DIGIT,
|
||||
"ANY" => Ex::ANY,
|
||||
"ASCII_HEX_DIGIT" => Ex::ASCII_HEX_DIGIT,
|
||||
_ => {
|
||||
if *ident == *this_name {
|
||||
return Ex::NestedEx;
|
||||
}
|
||||
|
||||
for rule in rules {
|
||||
return rule.ex.clone();
|
||||
}
|
||||
panic!("couldnt find rule {:?}", ident);
|
||||
}
|
||||
};
|
||||
ex
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
// This method must be called when we know that there is enough
|
||||
// data remained starting from the offset to match the expression
|
||||
// at least once.
|
||||
//
|
||||
// returns the predicate and the offset from which the next expression
|
||||
// should be matched.
|
||||
// Returns multiple predicates if the expression caused multiple branches.
|
||||
// A top level expr always returns a single predicate, in which all branches
|
||||
// are coalesced.
|
||||
fn expr_to_pred(
|
||||
exp: &Ex,
|
||||
offset: usize,
|
||||
data_len: usize,
|
||||
is_top_level: bool,
|
||||
) -> Vec<(Pred, usize)> {
|
||||
// if is_top_level {
|
||||
// println!("top level exps {:?}", exp);
|
||||
// } else {
|
||||
// println!("Non-top level exps {:?}", exp);
|
||||
// }
|
||||
match exp {
|
||||
Ex::SeqEx(s) => {
|
||||
let a = &s.0;
|
||||
let b = &s.1;
|
||||
|
||||
if is_top_level && (offset + a.max_len() + b.max_len() < data_len) {
|
||||
panic!();
|
||||
}
|
||||
if offset + a.min_len() + b.min_len() > data_len {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// The first expression must not try to match in the
|
||||
// data of the next expression
|
||||
let pred1 = expr_to_pred(a, offset, data_len - b.min_len(), false);
|
||||
|
||||
// interlace all branches
|
||||
let mut interlaced = Vec::new();
|
||||
|
||||
for (p1, offset) in pred1.iter() {
|
||||
// if the seq expr was top-level, the 2nd expr becomes top-level
|
||||
let mut pred2 = expr_to_pred(b, *offset, data_len, is_top_level);
|
||||
for (p2, offser_inner) in pred2.iter() {
|
||||
let pred = Pred::And(vec![p1.clone(), p2.clone()]);
|
||||
interlaced.push((pred, *offser_inner));
|
||||
}
|
||||
}
|
||||
|
||||
if is_top_level {
|
||||
// coalesce all branches
|
||||
let preds: Vec<Pred> = interlaced.into_iter().map(|(a, _b)| a).collect();
|
||||
if preds.len() == 1 {
|
||||
vec![(preds[0].clone(), 0)]
|
||||
} else {
|
||||
vec![(Pred::Or(preds), 0)]
|
||||
}
|
||||
} else {
|
||||
interlaced
|
||||
}
|
||||
}
|
||||
Ex::ChoiceEx(s) => {
|
||||
let a = &s.0;
|
||||
let b = &s.1;
|
||||
|
||||
let mut skip_a = false;
|
||||
let mut skip_b = false;
|
||||
|
||||
if is_top_level {
|
||||
if offset + a.max_len() != data_len {
|
||||
skip_a = true
|
||||
}
|
||||
if offset + b.max_len() != data_len {
|
||||
skip_b = true;
|
||||
}
|
||||
} else {
|
||||
// if not top level, we may skip an expression when it will
|
||||
// overflow the data len
|
||||
if offset + a.min_len() > data_len {
|
||||
skip_a = true
|
||||
}
|
||||
if offset + b.min_len() > data_len {
|
||||
skip_b = true
|
||||
}
|
||||
}
|
||||
|
||||
if skip_a && skip_b {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let mut preds_a = Vec::new();
|
||||
let mut preds_b = Vec::new();
|
||||
|
||||
if !skip_a {
|
||||
preds_a = expr_to_pred(a, offset, data_len, is_top_level);
|
||||
}
|
||||
if !skip_b {
|
||||
preds_b = expr_to_pred(b, offset, data_len, is_top_level);
|
||||
}
|
||||
|
||||
// combine all branches
|
||||
let mut combined = Vec::new();
|
||||
if preds_a.is_empty() {
|
||||
combined = preds_b.clone();
|
||||
} else if preds_b.is_empty() {
|
||||
combined = preds_a.clone();
|
||||
} else {
|
||||
assert!(!(preds_a.is_empty() && preds_b.is_empty()));
|
||||
|
||||
combined.append(&mut preds_a);
|
||||
combined.append(&mut preds_b);
|
||||
}
|
||||
|
||||
if is_top_level {
|
||||
// coalesce all branches
|
||||
let preds: Vec<Pred> = combined.into_iter().map(|(a, _b)| a).collect();
|
||||
if preds.len() == 1 {
|
||||
vec![(preds[0].clone(), 0)]
|
||||
} else {
|
||||
vec![(Pred::Or(preds), 0)]
|
||||
}
|
||||
} else {
|
||||
combined
|
||||
}
|
||||
}
|
||||
Ex::RepEx(r) => {
|
||||
let e = &r.0;
|
||||
|
||||
if offset + e.min_len() > data_len {
|
||||
if is_top_level {
|
||||
panic!();
|
||||
}
|
||||
// zero matches
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mut interlaced = Vec::new();
|
||||
|
||||
let mut preds = expr_to_pred(&e, offset, data_len, false);
|
||||
|
||||
// for (i, (pred, depth)) in preds.iter().enumerate() {
|
||||
// println!("preds[{i}] (depth {depth}):");
|
||||
// println!("{pred}");
|
||||
// }
|
||||
|
||||
// Append single matches.
|
||||
interlaced.append(&mut preds.clone());
|
||||
|
||||
let mut was_found = true;
|
||||
|
||||
while was_found {
|
||||
was_found = false;
|
||||
|
||||
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
|
||||
if offset_outer + e.min_len() > data_len {
|
||||
// cannot match any more
|
||||
continue;
|
||||
}
|
||||
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
|
||||
|
||||
// for (i, (pred, depth)) in preds_inner.iter().enumerate() {
|
||||
// println!("preds[{i}] (depth {depth}):");
|
||||
// println!("{pred}");
|
||||
// }
|
||||
for (pred_inner, offset_inner) in preds_inner {
|
||||
let pred = (
|
||||
Pred::And(vec![pred_outer.clone(), pred_inner]),
|
||||
offset_inner,
|
||||
);
|
||||
preds.push(pred);
|
||||
was_found = true;
|
||||
}
|
||||
}
|
||||
interlaced.append(&mut preds.clone());
|
||||
}
|
||||
|
||||
// for (i, (pred, depth)) in interlaced.iter().enumerate() {
|
||||
// println!("preds[{i}] (depth {depth}):");
|
||||
// println!("{pred}");
|
||||
// }
|
||||
|
||||
if is_top_level {
|
||||
// drop all branches which do not match exactly at the data length
|
||||
// border and coalesce the rest
|
||||
let preds: Vec<Pred> = interlaced
|
||||
.into_iter()
|
||||
.filter(|(_a, b)| *b == data_len)
|
||||
.map(|(a, _b)| a)
|
||||
.collect();
|
||||
if preds.is_empty() {
|
||||
panic!()
|
||||
}
|
||||
if preds.len() == 1 {
|
||||
vec![(preds[0].clone(), 0)]
|
||||
} else {
|
||||
// coalesce all branches
|
||||
vec![(Pred::Or(preds), 0)]
|
||||
}
|
||||
} else {
|
||||
interlaced
|
||||
}
|
||||
}
|
||||
Ex::RepExactEx(r) => {
|
||||
let e = &r.0 .0;
|
||||
let count = r.0 .1;
|
||||
assert!(count > 0);
|
||||
|
||||
if is_top_level && (offset + e.max_len() * count as usize <= data_len) {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let mut preds = expr_to_pred(&e, offset, data_len, false);
|
||||
|
||||
for i in 1..count {
|
||||
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
|
||||
if offset_outer + e.min_len() > data_len {
|
||||
// cannot match any more
|
||||
continue;
|
||||
}
|
||||
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
|
||||
for (pred_inner, offset_inner) in preds_inner {
|
||||
let pred = (
|
||||
Pred::And(vec![pred_outer.clone(), pred_inner]),
|
||||
offset_inner,
|
||||
);
|
||||
preds.push(pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_top_level {
|
||||
// drop all branches which do not match exactly at the data length
|
||||
// border and coalesce the rest
|
||||
let preds: Vec<Pred> = preds
|
||||
.into_iter()
|
||||
.filter(|(_a, b)| *b != data_len)
|
||||
.map(|(a, _b)| a)
|
||||
.collect();
|
||||
if preds.is_empty() {
|
||||
panic!()
|
||||
}
|
||||
|
||||
if preds.len() == 1 {
|
||||
vec![(preds[0].clone(), 0)]
|
||||
} else {
|
||||
// coalesce all branches
|
||||
vec![(Pred::Or(preds), 0)]
|
||||
}
|
||||
} else {
|
||||
preds
|
||||
}
|
||||
}
|
||||
Ex::NegPredEx(e) => {
|
||||
assert!(offset <= data_len);
|
||||
if offset == data_len {
|
||||
// the internal expression cannot be match since there is no data left,
|
||||
// this means that the negative expression matched
|
||||
if is_top_level {
|
||||
panic!("always true predicate doesnt make sense")
|
||||
}
|
||||
|
||||
// TODO this is hacky.
|
||||
return vec![(Pred::True, offset)];
|
||||
}
|
||||
|
||||
let e = &e.0;
|
||||
let preds = expr_to_pred(&e, offset, data_len, is_top_level);
|
||||
|
||||
let preds: Vec<Pred> = preds.into_iter().map(|(a, _b)| a).collect();
|
||||
let len = preds.len();
|
||||
|
||||
// coalesce all branches, offset doesnt matter since those
|
||||
// offset will never be used anymore.
|
||||
let pred = if preds.len() == 1 {
|
||||
Pred::Not(Box::new(preds[0].clone()))
|
||||
} else {
|
||||
Pred::Not(Box::new(Pred::Or(preds)))
|
||||
};
|
||||
|
||||
if is_top_level && len == 0 {
|
||||
panic!()
|
||||
}
|
||||
|
||||
// all offset if negative predicate are ignored since no matching
|
||||
// will be done from those offsets.
|
||||
vec![(pred, offset)]
|
||||
}
|
||||
Ex::OptEx(e) => {
|
||||
let e = &e.0;
|
||||
|
||||
if is_top_level {
|
||||
return vec![(Pred::True, 0)];
|
||||
}
|
||||
|
||||
// add an always-matching branch
|
||||
let mut preds = vec![(Pred::True, offset)];
|
||||
|
||||
if e.min_len() + offset <= data_len {
|
||||
// try to match only if there is enough data
|
||||
let mut p = expr_to_pred(&e, offset, data_len, is_top_level);
|
||||
preds.append(&mut p);
|
||||
}
|
||||
|
||||
preds
|
||||
}
|
||||
Ex::StrEx(s) => {
|
||||
if is_top_level && offset + s.0.len() != data_len {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let mut preds = Vec::new();
|
||||
for (idx, byte) in s.0.clone().into_bytes().iter().enumerate() {
|
||||
let a = Atom {
|
||||
index: offset + idx,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Const(*byte),
|
||||
};
|
||||
preds.push(Pred::Atom(a));
|
||||
}
|
||||
|
||||
if preds.len() == 1 {
|
||||
vec![(preds[0].clone(), offset + s.0.len())]
|
||||
} else {
|
||||
vec![(Pred::And(preds), offset + s.0.len())]
|
||||
}
|
||||
}
|
||||
Ex::ASCII_NONZERO_DIGIT => {
|
||||
if is_top_level && (offset + 1 != data_len) {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: offset,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(49u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: offset,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(57u8),
|
||||
});
|
||||
vec![(Pred::And(vec![gte, lte]), offset + 1)]
|
||||
}
|
||||
Ex::ASCII_DIGIT => {
|
||||
if is_top_level && (offset + 1 != data_len) {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: offset,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(48u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: offset,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(57u8),
|
||||
});
|
||||
vec![(Pred::And(vec![gte, lte]), offset + 1)]
|
||||
}
|
||||
Ex::ANY => {
|
||||
if is_top_level && (offset + 1 > data_len) {
|
||||
panic!();
|
||||
}
|
||||
let start = offset;
|
||||
let end = min(offset + 4, data_len);
|
||||
let mut branches = Vec::new();
|
||||
for branch_end in start + 1..end {
|
||||
branches.push((is_unicode(RangeSet::from(start..branch_end)), branch_end))
|
||||
}
|
||||
|
||||
if is_top_level {
|
||||
assert!(branches.len() == 1);
|
||||
}
|
||||
branches
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_int() {
|
||||
use rand::{distr::Alphanumeric, rng, Rng};
|
||||
|
||||
let grammar = include_str!("json_int.pest");
|
||||
|
||||
// Parse the grammar file into Pairs (the grammar’s own parse tree)
|
||||
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
|
||||
|
||||
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
|
||||
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
|
||||
|
||||
// 4) Convert the parsed pairs into the stable AST representation
|
||||
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
|
||||
|
||||
let exp = build_rules(&rules_ast);
|
||||
|
||||
// 5) Inspect the AST however you like For a quick look, the Debug print is the
|
||||
// safest (works across versions)
|
||||
for rule in &rules_ast {
|
||||
println!("{:#?}", rule);
|
||||
}
|
||||
|
||||
const LENGTH: usize = 7; // Adjustable constant
|
||||
|
||||
let pred = expr_to_pred(&exp, 0, LENGTH, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
|
||||
println!("{:?} and gates", circ.and_count());
|
||||
|
||||
for i in 0..1000000 {
|
||||
let s: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(LENGTH)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let out = eval_pred(pred, s.as_bytes());
|
||||
|
||||
let is_int = s.chars().all(|c| c.is_ascii_digit()) && !s.starts_with('0');
|
||||
|
||||
if out != is_int {
|
||||
println!("failed at index {:?} with {:?}", i, s);
|
||||
}
|
||||
assert_eq!(out, is_int)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_str() {
|
||||
use rand::{distr::Alphanumeric, rng, Rng};
|
||||
const LENGTH: usize = 10; // Adjustable constant
|
||||
|
||||
let grammar = include_str!("json_str.pest");
|
||||
|
||||
// Parse the grammar file into Pairs (the grammar’s own parse tree)
|
||||
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
|
||||
|
||||
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
|
||||
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
|
||||
|
||||
// 4) Convert the parsed pairs into the stable AST representation
|
||||
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
|
||||
|
||||
for rule in &rules_ast {
|
||||
println!("{:#?}", rule);
|
||||
}
|
||||
|
||||
let exp = build_rules(&rules_ast);
|
||||
|
||||
for len in LENGTH..LENGTH + 7 {
|
||||
let pred = expr_to_pred(&exp, 0, len, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
let circ = Compiler::new().compile(pred);
|
||||
|
||||
println!(
|
||||
"JSON string length: {:?}; circuit AND gate count {:?}",
|
||||
len,
|
||||
circ.and_count()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_choice() {
|
||||
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
|
||||
let b = Expr::Ident("ASCII_DIGIT".to_string());
|
||||
let rule = ast::Rule {
|
||||
name: "test".to_string(),
|
||||
ty: ast::RuleType::Atomic,
|
||||
expr: Expr::Choice(Box::new(a), Box::new(b)),
|
||||
};
|
||||
|
||||
let exp = build_rules(&vec![rule]);
|
||||
let pred = expr_to_pred(&exp, 0, 1, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
println!("pred is {:?}", pred);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seq() {
|
||||
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
|
||||
let b = Expr::Ident("ASCII_DIGIT".to_string());
|
||||
let rule = ast::Rule {
|
||||
name: "test".to_string(),
|
||||
ty: ast::RuleType::Atomic,
|
||||
expr: Expr::Seq(Box::new(a), Box::new(b)),
|
||||
};
|
||||
|
||||
let exp = build_rules(&vec![rule]);
|
||||
let pred = expr_to_pred(&exp, 0, 2, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
println!("pred is {:?}", pred);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rep() {
|
||||
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
|
||||
let b = Expr::Ident("ASCII_DIGIT".to_string());
|
||||
|
||||
let rule = ast::Rule {
|
||||
name: "test".to_string(),
|
||||
ty: ast::RuleType::Atomic,
|
||||
expr: Expr::Rep(Box::new(a)),
|
||||
};
|
||||
|
||||
let exp = build_rules(&vec![rule]);
|
||||
let pred = expr_to_pred(&exp, 0, 3, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
println!("pred is {:?}", pred);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rep_choice() {
|
||||
const LENGTH: usize = 5; // Adjustable constant
|
||||
|
||||
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
|
||||
let b = Expr::Ident("ASCII_DIGIT".to_string());
|
||||
// Number of predicates needed to represent the expressions.
|
||||
let a_weight = 2usize;
|
||||
let b_weight = 2usize;
|
||||
|
||||
let rep_a = Expr::Rep(Box::new(a));
|
||||
let rep_b = Expr::Rep(Box::new(b));
|
||||
|
||||
let rule = ast::Rule {
|
||||
name: "test".to_string(),
|
||||
ty: ast::RuleType::Atomic,
|
||||
expr: Expr::Choice(Box::new(rep_a), Box::new(rep_b)),
|
||||
};
|
||||
|
||||
let exp = build_rules(&vec![rule]);
|
||||
let pred = expr_to_pred(&exp, 0, LENGTH, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
println!("pred is {}", pred);
|
||||
// This is for sanity that no extra predicates are being added.
|
||||
assert_eq!(pred.leaves(), a_weight * LENGTH + b_weight * LENGTH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_neg_choice() {
|
||||
let a = Expr::Str("4".to_string());
|
||||
let b = Expr::Str("5".to_string());
|
||||
let choice = Expr::Choice(Box::new(a), Box::new(b));
|
||||
let neg_choice = Expr::NegPred(Box::new(choice));
|
||||
let c = Expr::Str("a".to_string());
|
||||
let d = Expr::Str("BC".to_string());
|
||||
let choice2 = Expr::Choice(Box::new(c), Box::new(d));
|
||||
|
||||
let rule = ast::Rule {
|
||||
name: "test".to_string(),
|
||||
ty: ast::RuleType::Atomic,
|
||||
expr: Expr::Seq(Box::new(neg_choice), Box::new(choice2)),
|
||||
};
|
||||
|
||||
let exp = build_rules(&vec![rule]);
|
||||
let pred = expr_to_pred(&exp, 0, 2, true);
|
||||
assert!(pred.len() == 1);
|
||||
let pred = &pred[0].0;
|
||||
|
||||
println!("pred is {:?}", pred);
|
||||
assert_eq!(pred.leaves(), 4);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
// Copied from pest.json
|
||||
|
||||
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copied from pest.json
|
||||
// Replaced "string" with "X" to avoid recursion.
|
||||
|
||||
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ "X")? }
|
||||
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
|
||||
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }
|
||||
@@ -14,9 +14,6 @@ pub mod webpki;
|
||||
pub use rangeset;
|
||||
pub mod config;
|
||||
pub(crate) mod display;
|
||||
//pub mod grammar;
|
||||
pub mod json;
|
||||
pub mod predicates;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
@@ -1,790 +0,0 @@
|
||||
//! Predicate and compiler.
|
||||
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
use mpz_circuits::{itybity::ToBits, ops, Circuit, CircuitBuilder, Feed, Node};
|
||||
use rangeset::RangeSet;
|
||||
|
||||
/// ddd
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum Pred {
|
||||
And(Vec<Pred>),
|
||||
Or(Vec<Pred>),
|
||||
Not(Box<Pred>),
|
||||
Atom(Atom),
|
||||
// An always-true predicate.
|
||||
True,
|
||||
// An always-false predicate.
|
||||
False,
|
||||
}
|
||||
|
||||
impl Pred {
|
||||
/// Returns sorted unique byte indices of this predicate.
|
||||
pub(crate) fn indices(&self) -> Vec<usize> {
|
||||
let mut indices = self.indices_internal(self);
|
||||
indices.sort_unstable();
|
||||
indices.dedup();
|
||||
indices
|
||||
}
|
||||
|
||||
// Returns the number of leaves (i.e atoms) the AST of this predicate has.
|
||||
pub(crate) fn leaves(&self) -> usize {
|
||||
match self {
|
||||
Pred::And(vec) => vec.iter().map(|p| p.leaves()).sum(),
|
||||
Pred::Or(vec) => vec.iter().map(|p| p.leaves()).sum(),
|
||||
Pred::Not(p) => p.leaves(),
|
||||
Pred::Atom(atom) => 1,
|
||||
Pred::True => 0,
|
||||
Pred::False => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all byte indices of the given `pred`icate.
|
||||
fn indices_internal(&self, pred: &Pred) -> Vec<usize> {
|
||||
match pred {
|
||||
Pred::And(vec) => vec
|
||||
.iter()
|
||||
.flat_map(|p| self.indices_internal(p))
|
||||
.collect::<Vec<_>>(),
|
||||
Pred::Or(vec) => vec
|
||||
.iter()
|
||||
.flat_map(|p| self.indices_internal(p))
|
||||
.collect::<Vec<_>>(),
|
||||
Pred::Not(p) => self.indices_internal(p),
|
||||
Pred::Atom(atom) => {
|
||||
let mut indices = Vec::new();
|
||||
indices.push(atom.index);
|
||||
if let Rhs::Idx(idx) = atom.rhs {
|
||||
indices.push(idx);
|
||||
}
|
||||
indices
|
||||
}
|
||||
Pred::True => vec![],
|
||||
Pred::False => vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Pred {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.fmt_with_indent(f, 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Pred {
|
||||
fn fmt_with_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
|
||||
// helper to write the current indentation
|
||||
fn pad(f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
|
||||
// 2 spaces per level; tweak as you like
|
||||
write!(f, "{:indent$}", "", indent = indent * 2)
|
||||
}
|
||||
|
||||
match self {
|
||||
Pred::And(preds) => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "And(")?;
|
||||
for p in preds {
|
||||
p.fmt_with_indent(f, indent + 1)?;
|
||||
}
|
||||
pad(f, indent)?;
|
||||
writeln!(f, ")")
|
||||
}
|
||||
Pred::Or(preds) => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "Or(")?;
|
||||
for p in preds {
|
||||
p.fmt_with_indent(f, indent + 1)?;
|
||||
}
|
||||
pad(f, indent)?;
|
||||
writeln!(f, ")")
|
||||
}
|
||||
Pred::Not(p) => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "Not(")?;
|
||||
p.fmt_with_indent(f, indent + 1)?;
|
||||
pad(f, indent)?;
|
||||
writeln!(f, ")")
|
||||
}
|
||||
Pred::Atom(a) => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "Atom({:?})", a)
|
||||
}
|
||||
Pred::True => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "True")
|
||||
}
|
||||
Pred::False => {
|
||||
pad(f, indent)?;
|
||||
writeln!(f, "False")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Atomic predicate of the form:
|
||||
/// x[index] (op) rhs
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Atom {
|
||||
/// Left-hand side byte index `i` (x_i).
|
||||
pub index: usize,
|
||||
/// Comparison operator.
|
||||
pub op: CmpOp,
|
||||
/// Right-hand side operand (constant or x_j).
|
||||
pub rhs: Rhs,
|
||||
}
|
||||
|
||||
/// ddd
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum CmpOp {
|
||||
Eq, // ==
|
||||
Ne, // !=
|
||||
Gt, // >
|
||||
Gte, // >=
|
||||
Lt, // <
|
||||
Lte, // <=
|
||||
}
|
||||
|
||||
/// RHS of a comparison
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Rhs {
|
||||
/// Byte at index
|
||||
Idx(usize),
|
||||
/// Literal constant.
|
||||
Const(u8),
|
||||
}
|
||||
|
||||
/// Compiles a predicate into a circuit.
|
||||
pub struct Compiler {
|
||||
/// A <byte index, circuit feeds> map.
|
||||
map: HashMap<usize, [Node<Feed>; 8]>,
|
||||
}
|
||||
|
||||
impl Compiler {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compiles the given predicate into a circuit, consuming the
|
||||
/// compiler.
|
||||
pub(crate) fn compile(&mut self, pred: &Pred) -> Circuit {
|
||||
let mut builder = CircuitBuilder::new();
|
||||
|
||||
for idx in pred.indices() {
|
||||
let feeds = (0..8).map(|_| builder.add_input()).collect::<Vec<_>>();
|
||||
self.map.insert(idx, feeds.try_into().unwrap());
|
||||
}
|
||||
|
||||
let out = self.process(&mut builder, pred);
|
||||
|
||||
builder.add_output(out);
|
||||
builder.build().unwrap()
|
||||
}
|
||||
|
||||
// Processes a single predicate.
|
||||
fn process(&mut self, builder: &mut CircuitBuilder, pred: &Pred) -> Node<Feed> {
|
||||
match pred {
|
||||
Pred::And(vec) => {
|
||||
let out = vec
|
||||
.iter()
|
||||
.map(|p| self.process(builder, p))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ops::all(builder, &out)
|
||||
}
|
||||
Pred::Or(vec) => {
|
||||
let out = vec
|
||||
.iter()
|
||||
.map(|p| self.process(builder, p))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ops::any(builder, &out)
|
||||
}
|
||||
Pred::Not(p) => {
|
||||
let pred_out = self.process(builder, p);
|
||||
let inv = ops::inv(builder, [pred_out]);
|
||||
inv[0]
|
||||
}
|
||||
Pred::Atom(atom) => {
|
||||
let lhs = self.map.get(&atom.index).unwrap().clone();
|
||||
let rhs = match atom.rhs {
|
||||
Rhs::Const(c) => const_feeds(builder, c),
|
||||
Rhs::Idx(s) => self.map.get(&s).unwrap().clone(),
|
||||
};
|
||||
match atom.op {
|
||||
CmpOp::Eq => ops::eq(builder, lhs, rhs),
|
||||
CmpOp::Ne => ops::neq(builder, lhs, rhs),
|
||||
CmpOp::Lt => ops::lt(builder, lhs, rhs),
|
||||
CmpOp::Lte => ops::lte(builder, lhs, rhs),
|
||||
CmpOp::Gt => ops::gt(builder, lhs, rhs),
|
||||
CmpOp::Gte => ops::gte(builder, lhs, rhs),
|
||||
}
|
||||
}
|
||||
Pred::True => builder.get_const_one(),
|
||||
Pred::False => builder.get_const_zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns circuit feeds for the given constant u8 value.
|
||||
fn const_feeds(builder: &CircuitBuilder, cnst: u8) -> [Node<Feed>; 8] {
|
||||
cnst.iter_lsb0()
|
||||
.map(|b| {
|
||||
if b {
|
||||
builder.get_const_one()
|
||||
} else {
|
||||
builder.get_const_zero()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("u8 has 8 feeds")
|
||||
}
|
||||
|
||||
// Evaluates the predicate on the input `data`.
|
||||
pub(crate) fn eval_pred(pred: &Pred, data: &[u8]) -> bool {
|
||||
match pred {
|
||||
Pred::And(vec) => vec.iter().map(|p| eval_pred(p, data)).all(|b| b),
|
||||
Pred::Or(vec) => vec.iter().map(|p| eval_pred(p, data)).any(|b| b),
|
||||
Pred::Not(p) => !eval_pred(p, data),
|
||||
Pred::Atom(atom) => {
|
||||
let lhs = data[atom.index];
|
||||
let rhs = match atom.rhs {
|
||||
Rhs::Const(c) => c,
|
||||
Rhs::Idx(s) => data[s],
|
||||
};
|
||||
match atom.op {
|
||||
CmpOp::Eq => lhs == rhs,
|
||||
CmpOp::Ne => lhs != rhs,
|
||||
CmpOp::Lt => lhs < rhs,
|
||||
CmpOp::Lte => lhs <= rhs,
|
||||
CmpOp::Gt => lhs > rhs,
|
||||
CmpOp::Gte => lhs >= rhs,
|
||||
}
|
||||
}
|
||||
Pred::True => true,
|
||||
Pred::False => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a predicate that an ascii integer is contained in the ranges.
|
||||
fn is_ascii_integer(range: RangeSet<usize>) -> Pred {
|
||||
let mut preds = Vec::new();
|
||||
for idx in range.iter() {
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: idx,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(57u8),
|
||||
});
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: idx,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(48u8),
|
||||
});
|
||||
preds.push(Pred::And(vec![lte, gte]));
|
||||
}
|
||||
Pred::And(preds)
|
||||
}
|
||||
|
||||
/// Builds a predicate that a valid HTTP header value is contained in the
|
||||
/// ranges.
|
||||
fn is_valid_http_header_value(range: RangeSet<usize>) -> Pred {
|
||||
let mut preds = Vec::new();
|
||||
for idx in range.iter() {
|
||||
let ne = Pred::Atom(Atom {
|
||||
index: idx,
|
||||
op: CmpOp::Ne,
|
||||
// ascii code for carriage return \r
|
||||
rhs: Rhs::Const(13u8),
|
||||
});
|
||||
preds.push(ne);
|
||||
}
|
||||
Pred::And(preds)
|
||||
}
|
||||
|
||||
/// Builds a predicate that a valid JSON string is contained in the
|
||||
/// ranges.
|
||||
fn is_valid_json_string(range: RangeSet<usize>) -> Pred {
|
||||
assert!(
|
||||
range.len_ranges() == 1,
|
||||
"only a contiguous range is allowed"
|
||||
);
|
||||
|
||||
const BACKSLASH: u8 = 92;
|
||||
|
||||
// check if all unicode chars are allowed
|
||||
let mut preds = Vec::new();
|
||||
|
||||
// Find all /u sequences
|
||||
for (i, idx) in range.iter().enumerate() {
|
||||
if i == range.len() - 1 {
|
||||
// if this is a last char, skip it
|
||||
continue;
|
||||
}
|
||||
let is_backslash = Pred::Atom(Atom {
|
||||
index: idx,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Const(BACKSLASH),
|
||||
});
|
||||
}
|
||||
Pred::And(preds)
|
||||
}
|
||||
|
||||
// Returns a predicate that a unicode char is contained in the range
|
||||
pub(crate) fn is_unicode(range: RangeSet<usize>) -> Pred {
|
||||
assert!(range.len() <= 4);
|
||||
match range.len() {
|
||||
1 => is_1_byte_unicode(range.max().unwrap()),
|
||||
2 => is_2_byte_unicode(range),
|
||||
3 => is_3_byte_unicode(range),
|
||||
4 => is_4_byte_unicode(range),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_1_byte_unicode(pos: usize) -> Pred {
|
||||
Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(127u8),
|
||||
})
|
||||
}
|
||||
|
||||
fn is_2_byte_unicode(range: RangeSet<usize>) -> Pred {
|
||||
assert!(range.len() == 2);
|
||||
let mut iter = range.iter();
|
||||
|
||||
// should be 110xxxxx
|
||||
let first = iter.next().unwrap();
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(0xC0),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(0xDF),
|
||||
});
|
||||
|
||||
let second = iter.next().unwrap();
|
||||
Pred::And(vec![lte, gte, is_unicode_continuation(second)])
|
||||
}
|
||||
|
||||
fn is_3_byte_unicode(range: RangeSet<usize>) -> Pred {
|
||||
assert!(range.len() == 3);
|
||||
let mut iter = range.iter();
|
||||
|
||||
let first = iter.next().unwrap();
|
||||
// should be 1110xxxx
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(0xE0),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(0xEF),
|
||||
});
|
||||
|
||||
let second = iter.next().unwrap();
|
||||
let third = iter.next().unwrap();
|
||||
|
||||
Pred::And(vec![
|
||||
lte,
|
||||
gte,
|
||||
is_unicode_continuation(second),
|
||||
is_unicode_continuation(third),
|
||||
])
|
||||
}
|
||||
|
||||
fn is_4_byte_unicode(range: RangeSet<usize>) -> Pred {
|
||||
assert!(range.len() == 4);
|
||||
let mut iter = range.iter();
|
||||
|
||||
let first = iter.next().unwrap();
|
||||
// should be 11110xxx
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(0xF0),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: first,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(0xF7),
|
||||
});
|
||||
|
||||
let second = iter.next().unwrap();
|
||||
let third = iter.next().unwrap();
|
||||
let fourth = iter.next().unwrap();
|
||||
|
||||
Pred::And(vec![
|
||||
lte,
|
||||
gte,
|
||||
is_unicode_continuation(second),
|
||||
is_unicode_continuation(third),
|
||||
is_unicode_continuation(fourth),
|
||||
])
|
||||
}
|
||||
|
||||
fn is_unicode_continuation(pos: usize) -> Pred {
|
||||
// should be 10xxxxxx
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(0x80),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(0xBF),
|
||||
});
|
||||
Pred::And(vec![lte, gte])
|
||||
}
|
||||
|
||||
fn is_ascii_hex_digit(pos: usize) -> Pred {
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(48u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(57u8),
|
||||
});
|
||||
let is_digit = Pred::And(vec![lte, gte]);
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(65u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(70u8),
|
||||
});
|
||||
let is_upper = Pred::And(vec![lte, gte]);
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(97u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(102u8),
|
||||
});
|
||||
let is_lower = Pred::And(vec![lte, gte]);
|
||||
Pred::Or(vec![is_digit, is_lower, is_upper])
|
||||
}
|
||||
|
||||
fn is_ascii_lowercase(pos: usize) -> Pred {
|
||||
let gte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Const(48u8),
|
||||
});
|
||||
let lte = Pred::Atom(Atom {
|
||||
index: pos,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Const(57u8),
|
||||
});
|
||||
Pred::And(vec![lte, gte])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use mpz_circuits::evaluate;
|
||||
use rand::rng;
|
||||
|
||||
#[test]
|
||||
fn test_and() {
|
||||
let pred = Pred::And(vec![
|
||||
Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(300),
|
||||
}),
|
||||
Pred::Atom(Atom {
|
||||
index: 200,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Const(2u8),
|
||||
}),
|
||||
]);
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [1u8, 2, 3]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 3, 3]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_or() {
|
||||
let pred = Pred::Or(vec![
|
||||
Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(300),
|
||||
}),
|
||||
Pred::Atom(Atom {
|
||||
index: 200,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Const(2u8),
|
||||
}),
|
||||
]);
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [1u8, 0, 3]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 3, 0]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not() {
|
||||
let pred = Pred::Not(Box::new(Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(300),
|
||||
})));
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [5u8, 3]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
// Tests when RHS is a const.
|
||||
#[test]
|
||||
fn test_rhs_const() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Const(22u8),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, 5u8).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, 23u8).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
// Tests when RHS is an index.
|
||||
#[test]
|
||||
fn test_rhs_idx() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(200),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, 5u8, 10u8).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, 23u8, 5u8).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
// Tests when same index is used in the predicate.
|
||||
#[test]
|
||||
fn test_same_idx() {
|
||||
let pred1 = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Idx(100),
|
||||
});
|
||||
|
||||
let pred2 = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(100),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred1);
|
||||
let out: bool = evaluate!(circ, 5u8).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let circ = Compiler::new().compile(&pred2);
|
||||
let out: bool = evaluate!(circ, 5u8).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_eq() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Eq,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [5u8, 5]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_neq() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Ne,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [5u8, 6]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_gt() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Gt,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [7u8, 6]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_gte() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Gte,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [7u8, 7]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [0u8, 1]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_lt() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lt,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [2u8, 7]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atom_lte() {
|
||||
let pred = Pred::Atom(Atom {
|
||||
index: 100,
|
||||
op: CmpOp::Lte,
|
||||
rhs: Rhs::Idx(300),
|
||||
});
|
||||
|
||||
let circ = Compiler::new().compile(&pred);
|
||||
let out: bool = evaluate!(circ, [2u8, 2]).unwrap();
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_ascii_integer() {
|
||||
let text = "text with integers 123456 text";
|
||||
let pos = text.find("123456").unwrap();
|
||||
|
||||
let pred = is_ascii_integer(RangeSet::from(pos..pos + 6));
|
||||
let bytes: &[u8] = text.as_bytes();
|
||||
|
||||
let out = eval_pred(&pred, bytes);
|
||||
assert_eq!(out, true);
|
||||
|
||||
let out = eval_pred(&pred, &[&[0u8], bytes].concat());
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_http_header_value() {
|
||||
let valid = "valid header value";
|
||||
let invalid = "invalid header \r value";
|
||||
|
||||
let pred = is_valid_http_header_value(RangeSet::from(0..valid.len()));
|
||||
let out: bool = eval_pred(&pred, valid.as_bytes());
|
||||
assert_eq!(out, true);
|
||||
|
||||
let pred = is_valid_http_header_value(RangeSet::from(0..invalid.len()));
|
||||
let out = eval_pred(&pred, invalid.as_bytes());
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_unicode() {
|
||||
use rand::{distr::Alphanumeric, rng, Rng};
|
||||
let mut rng = rng();
|
||||
|
||||
for _ in 0..1000000 {
|
||||
let mut s = String::from("HelloWorld");
|
||||
let insert_pos = 5; // logical character index (after "Hello")
|
||||
|
||||
let byte_index = s
|
||||
.char_indices()
|
||||
.nth(insert_pos)
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or_else(|| s.len());
|
||||
// Pick a random Unicode scalar value (0x0000..=0x10FFFF)
|
||||
// Retry if it's in the surrogate range (U+D800..=U+DFFF)
|
||||
let c = loop {
|
||||
let code = rng.random_range(0x0000u32..=0x10FFFF);
|
||||
if !(0xD800..=0xDFFF).contains(&code) {
|
||||
if let Some(ch) = char::from_u32(code) {
|
||||
break ch;
|
||||
}
|
||||
}
|
||||
};
|
||||
let mut buf = [0u8; 4]; // max UTF-8 length
|
||||
let encoded = c.encode_utf8(&mut buf); // returns &str
|
||||
let len = encoded.len();
|
||||
|
||||
s.insert_str(byte_index, &c.to_string());
|
||||
|
||||
let pred = is_unicode(RangeSet::from(byte_index..byte_index + len));
|
||||
let out = eval_pred(&pred, s.as_bytes());
|
||||
assert_eq!(out, true);
|
||||
}
|
||||
|
||||
let bad_unicode = 255u8;
|
||||
let pred = is_unicode(RangeSet::from(0..1));
|
||||
let out = eval_pred(&pred, &[bad_unicode]);
|
||||
assert_eq!(out, false);
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,11 @@ mod tls;
|
||||
|
||||
use std::{fmt, ops::Range};
|
||||
|
||||
use rangeset::{Difference, IndexRanges, RangeSet, Union};
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Index, Set},
|
||||
set::RangeSet,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::connection::TranscriptLength;
|
||||
@@ -106,8 +110,14 @@ impl Transcript {
|
||||
}
|
||||
|
||||
Some(
|
||||
Subsequence::new(idx.clone(), data.index_ranges(idx))
|
||||
.expect("data is same length as index"),
|
||||
Subsequence::new(
|
||||
idx.clone(),
|
||||
data.index(idx).fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
)
|
||||
.expect("data is same length as index"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -129,11 +139,11 @@ impl Transcript {
|
||||
let mut sent = vec![0; self.sent.len()];
|
||||
let mut received = vec![0; self.received.len()];
|
||||
|
||||
for range in sent_idx.iter_ranges() {
|
||||
for range in sent_idx.iter() {
|
||||
sent[range.clone()].copy_from_slice(&self.sent[range]);
|
||||
}
|
||||
|
||||
for range in recv_idx.iter_ranges() {
|
||||
for range in recv_idx.iter() {
|
||||
received[range.clone()].copy_from_slice(&self.received[range]);
|
||||
}
|
||||
|
||||
@@ -186,12 +196,20 @@ pub struct CompressedPartialTranscript {
|
||||
impl From<PartialTranscript> for CompressedPartialTranscript {
|
||||
fn from(uncompressed: PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent_authed: uncompressed
|
||||
.sent
|
||||
.index_ranges(&uncompressed.sent_authed_idx),
|
||||
sent_authed: uncompressed.sent.index(&uncompressed.sent_authed_idx).fold(
|
||||
Vec::new(),
|
||||
|mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
},
|
||||
),
|
||||
received_authed: uncompressed
|
||||
.received
|
||||
.index_ranges(&uncompressed.received_authed_idx),
|
||||
.index(&uncompressed.received_authed_idx)
|
||||
.fold(Vec::new(), |mut acc, s| {
|
||||
acc.extend_from_slice(s);
|
||||
acc
|
||||
}),
|
||||
sent_idx: uncompressed.sent_authed_idx,
|
||||
recv_idx: uncompressed.received_authed_idx,
|
||||
sent_total: uncompressed.sent.len(),
|
||||
@@ -207,7 +225,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.sent_idx.iter_ranges() {
|
||||
for range in compressed.sent_idx.iter() {
|
||||
sent[range.clone()]
|
||||
.copy_from_slice(&compressed.sent_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -215,7 +233,7 @@ impl From<CompressedPartialTranscript> for PartialTranscript {
|
||||
|
||||
let mut offset = 0;
|
||||
|
||||
for range in compressed.recv_idx.iter_ranges() {
|
||||
for range in compressed.recv_idx.iter() {
|
||||
received[range.clone()]
|
||||
.copy_from_slice(&compressed.received_authed[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
@@ -304,12 +322,16 @@ impl PartialTranscript {
|
||||
|
||||
/// Returns the index of sent data which haven't been authenticated.
|
||||
pub fn sent_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.sent.len()).difference(&self.sent_authed_idx)
|
||||
(0..self.sent.len())
|
||||
.difference(&self.sent_authed_idx)
|
||||
.into_set()
|
||||
}
|
||||
|
||||
/// Returns the index of received data which haven't been authenticated.
|
||||
pub fn received_unauthed(&self) -> RangeSet<usize> {
|
||||
(0..self.received.len()).difference(&self.received_authed_idx)
|
||||
(0..self.received.len())
|
||||
.difference(&self.received_authed_idx)
|
||||
.into_set()
|
||||
}
|
||||
|
||||
/// Returns an iterator over the authenticated data in the transcript.
|
||||
@@ -319,7 +341,7 @@ impl PartialTranscript {
|
||||
Direction::Received => (&self.received, &self.received_authed_idx),
|
||||
};
|
||||
|
||||
authed.iter().map(|i| data[i])
|
||||
authed.iter_values().map(move |i| data[i])
|
||||
}
|
||||
|
||||
/// Unions the authenticated data of this transcript with another.
|
||||
@@ -339,24 +361,20 @@ impl PartialTranscript {
|
||||
"received data are not the same length"
|
||||
);
|
||||
|
||||
for range in other
|
||||
.sent_authed_idx
|
||||
.difference(&self.sent_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
for range in other.sent_authed_idx.difference(&self.sent_authed_idx) {
|
||||
self.sent[range.clone()].copy_from_slice(&other.sent[range]);
|
||||
}
|
||||
|
||||
for range in other
|
||||
.received_authed_idx
|
||||
.difference(&self.received_authed_idx)
|
||||
.iter_ranges()
|
||||
{
|
||||
self.received[range.clone()].copy_from_slice(&other.received[range]);
|
||||
}
|
||||
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&other.sent_authed_idx);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&other.received_authed_idx);
|
||||
self.sent_authed_idx.union_mut(&other.sent_authed_idx);
|
||||
self.received_authed_idx
|
||||
.union_mut(&other.received_authed_idx);
|
||||
}
|
||||
|
||||
/// Unions an authenticated subsequence into this transcript.
|
||||
@@ -368,11 +386,11 @@ impl PartialTranscript {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
seq.copy_to(&mut self.sent);
|
||||
self.sent_authed_idx = self.sent_authed_idx.union(&seq.idx);
|
||||
self.sent_authed_idx.union_mut(&seq.idx);
|
||||
}
|
||||
Direction::Received => {
|
||||
seq.copy_to(&mut self.received);
|
||||
self.received_authed_idx = self.received_authed_idx.union(&seq.idx);
|
||||
self.received_authed_idx.union_mut(&seq.idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -383,10 +401,10 @@ impl PartialTranscript {
|
||||
///
|
||||
/// * `value` - The value to set the unauthenticated bytes to
|
||||
pub fn set_unauthed(&mut self, value: u8) {
|
||||
for range in self.sent_unauthed().iter_ranges() {
|
||||
for range in self.sent_unauthed().iter() {
|
||||
self.sent[range].fill(value);
|
||||
}
|
||||
for range in self.received_unauthed().iter_ranges() {
|
||||
for range in self.received_unauthed().iter() {
|
||||
self.received[range].fill(value);
|
||||
}
|
||||
}
|
||||
@@ -401,13 +419,13 @@ impl PartialTranscript {
|
||||
pub fn set_unauthed_range(&mut self, value: u8, direction: Direction, range: Range<usize>) {
|
||||
match direction {
|
||||
Direction::Sent => {
|
||||
for range in range.difference(&self.sent_authed_idx).iter_ranges() {
|
||||
self.sent[range].fill(value);
|
||||
for r in range.difference(&self.sent_authed_idx) {
|
||||
self.sent[r].fill(value);
|
||||
}
|
||||
}
|
||||
Direction::Received => {
|
||||
for range in range.difference(&self.received_authed_idx).iter_ranges() {
|
||||
self.received[range].fill(value);
|
||||
for r in range.difference(&self.received_authed_idx) {
|
||||
self.received[r].fill(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -485,7 +503,7 @@ impl Subsequence {
|
||||
/// Panics if the subsequence ranges are out of bounds.
|
||||
pub(crate) fn copy_to(&self, dest: &mut [u8]) {
|
||||
let mut offset = 0;
|
||||
for range in self.idx.iter_ranges() {
|
||||
for range in self.idx.iter() {
|
||||
dest[range.clone()].copy_from_slice(&self.data[offset..offset + range.len()]);
|
||||
offset += range.len();
|
||||
}
|
||||
@@ -610,12 +628,7 @@ mod validation {
|
||||
mut partial_transcript: CompressedPartialTranscriptUnchecked,
|
||||
) {
|
||||
// Change the total to be less than the last range's end bound.
|
||||
let end = partial_transcript
|
||||
.sent_idx
|
||||
.iter_ranges()
|
||||
.next_back()
|
||||
.unwrap()
|
||||
.end;
|
||||
let end = partial_transcript.sent_idx.iter().next_back().unwrap().end;
|
||||
|
||||
partial_transcript.sent_total = end - 1;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
use rangeset::{ToRangeSet, UnionMut};
|
||||
use rangeset::set::ToRangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{collections::HashMap, fmt};
|
||||
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -103,7 +103,7 @@ impl EncodingProof {
|
||||
}
|
||||
|
||||
expected_leaf.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
encoder.encode_data(*direction, range.clone(), &data[range], &mut expected_leaf);
|
||||
}
|
||||
expected_leaf.extend_from_slice(blinder.as_bytes());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bimap::BiMap;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
@@ -99,7 +99,7 @@ impl EncodingTree {
|
||||
let blinder: Blinder = rand::random();
|
||||
|
||||
encoding.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
provider
|
||||
.provide_encoding(direction, range, &mut encoding)
|
||||
.map_err(|_| EncodingTreeError::MissingEncoding { index: idx.clone() })?;
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
//! Transcript proofs.
|
||||
|
||||
use rangeset::{Cover, Difference, Subset, ToRangeSet, UnionMut};
|
||||
use rangeset::{
|
||||
iter::RangeIterator,
|
||||
ops::{Cover, Set},
|
||||
set::ToRangeSet,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashSet, fmt};
|
||||
|
||||
@@ -144,7 +148,7 @@ impl TranscriptProof {
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
buffer.extend_from_slice(&plaintext[range]);
|
||||
}
|
||||
|
||||
@@ -366,7 +370,7 @@ impl<'a> TranscriptProofBuilder<'a> {
|
||||
if idx.is_subset(committed) {
|
||||
self.query_idx.union(&direction, &idx);
|
||||
} else {
|
||||
let missing = idx.difference(committed);
|
||||
let missing = idx.difference(committed).into_set();
|
||||
return Err(TranscriptProofBuilderError::new(
|
||||
BuilderErrorKind::MissingCommitment,
|
||||
format!(
|
||||
@@ -582,7 +586,7 @@ impl fmt::Display for TranscriptProofBuilderError {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::prelude::*;
|
||||
use rstest::rstest;
|
||||
use tlsn_data_fixtures::http::{request::GET_WITH_HEADER, response::OK_JSON};
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ fn prepare_zk_proof_input(
|
||||
hasher.update(&blinder);
|
||||
let computed_hash = hasher.finalize();
|
||||
|
||||
if committed_hash != computed_hash.as_slice() {
|
||||
if committed_hash != computed_hash.as_ref() as &[u8] {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Computed hash does not match committed hash"
|
||||
));
|
||||
|
||||
@@ -9,6 +9,7 @@ pub const DEFAULT_UPLOAD_SIZE: usize = 1024;
|
||||
pub const DEFAULT_DOWNLOAD_SIZE: usize = 4096;
|
||||
pub const DEFAULT_DEFER_DECRYPTION: bool = true;
|
||||
pub const DEFAULT_MEMORY_PROFILE: bool = false;
|
||||
pub const DEFAULT_REVEAL_ALL: bool = false;
|
||||
|
||||
pub const WARM_UP_BENCH: Bench = Bench {
|
||||
group: None,
|
||||
@@ -20,6 +21,7 @@ pub const WARM_UP_BENCH: Bench = Bench {
|
||||
download_size: 4096,
|
||||
defer_decryption: true,
|
||||
memory_profile: false,
|
||||
reveal_all: true,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -79,6 +81,8 @@ pub struct BenchGroupItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -97,6 +101,8 @@ pub struct BenchItem {
|
||||
pub defer_decryption: Option<bool>,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: Option<bool>,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: Option<bool>,
|
||||
}
|
||||
|
||||
impl BenchItem {
|
||||
@@ -132,6 +138,10 @@ impl BenchItem {
|
||||
if self.memory_profile.is_none() {
|
||||
self.memory_profile = group.memory_profile;
|
||||
}
|
||||
|
||||
if self.reveal_all.is_none() {
|
||||
self.reveal_all = group.reveal_all;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_bench(&self) -> Bench {
|
||||
@@ -145,6 +155,7 @@ impl BenchItem {
|
||||
download_size: self.download_size.unwrap_or(DEFAULT_DOWNLOAD_SIZE),
|
||||
defer_decryption: self.defer_decryption.unwrap_or(DEFAULT_DEFER_DECRYPTION),
|
||||
memory_profile: self.memory_profile.unwrap_or(DEFAULT_MEMORY_PROFILE),
|
||||
reveal_all: self.reveal_all.unwrap_or(DEFAULT_REVEAL_ALL),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -164,6 +175,8 @@ pub struct Bench {
|
||||
pub defer_decryption: bool,
|
||||
#[serde(rename = "memory-profile")]
|
||||
pub memory_profile: bool,
|
||||
#[serde(rename = "reveal-all")]
|
||||
pub reveal_all: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -22,7 +22,10 @@ pub enum CmdOutput {
|
||||
GetTests(Vec<String>),
|
||||
Test(TestOutput),
|
||||
Bench(BenchOutput),
|
||||
Fail { reason: Option<String> },
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
Fail {
|
||||
reason: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -98,14 +98,27 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result<Prove
|
||||
|
||||
let mut builder = ProveConfig::builder(prover.transcript());
|
||||
|
||||
// When reveal_all is false (the default), we exclude 1 byte to avoid the
|
||||
// reveal-all optimization and benchmark the realistic ZK authentication path.
|
||||
let reveal_sent_range = if config.reveal_all {
|
||||
0..sent_len
|
||||
} else {
|
||||
0..sent_len.saturating_sub(1)
|
||||
};
|
||||
let reveal_recv_range = if config.reveal_all {
|
||||
0..recv_len
|
||||
} else {
|
||||
0..recv_len.saturating_sub(1)
|
||||
};
|
||||
|
||||
builder
|
||||
.server_identity()
|
||||
.reveal_sent(&(0..sent_len))?
|
||||
.reveal_recv(&(0..recv_len))?;
|
||||
.reveal_sent(&reveal_sent_range)?
|
||||
.reveal_recv(&reveal_recv_range)?;
|
||||
|
||||
let config = builder.build()?;
|
||||
let prove_config = builder.build()?;
|
||||
|
||||
prover.prove(&config).await?;
|
||||
prover.prove(&prove_config).await?;
|
||||
prover.close().await?;
|
||||
|
||||
let time_total = time_start.elapsed().as_millis();
|
||||
|
||||
@@ -7,10 +7,9 @@ publish = false
|
||||
[dependencies]
|
||||
tlsn-harness-core = { workspace = true }
|
||||
# tlsn-server-fixture = { workspace = true }
|
||||
charming = { version = "0.5.1", features = ["ssr"] }
|
||||
csv = "1.3.0"
|
||||
charming = { version = "0.6.0", features = ["ssr"] }
|
||||
clap = { workspace = true, features = ["derive", "env"] }
|
||||
itertools = "0.14.0"
|
||||
polars = { version = "0.44", features = ["csv", "lazy"] }
|
||||
toml = { workspace = true }
|
||||
|
||||
|
||||
|
||||
111
crates/harness/plot/README.md
Normal file
111
crates/harness/plot/README.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# TLSNotary Benchmark Plot Tool
|
||||
|
||||
Generates interactive HTML and SVG plots from TLSNotary benchmark results. Supports comparing multiple benchmark runs (e.g., before/after optimization, native vs browser).
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot <TOML> <CSV>... [OPTIONS]
|
||||
```
|
||||
|
||||
### Arguments
|
||||
|
||||
- `<TOML>` - Path to Bench.toml file defining benchmark structure
|
||||
- `<CSV>...` - One or more CSV files with benchmark results
|
||||
|
||||
### Options
|
||||
|
||||
- `-l, --labels <LABEL>...` - Labels for each dataset (optional)
|
||||
- If omitted, datasets are labeled "Dataset 1", "Dataset 2", etc.
|
||||
- Number of labels must match number of CSV files
|
||||
- `--min-max-band` - Add min/max bands to plots showing variance
|
||||
- `-h, --help` - Print help information
|
||||
|
||||
## Examples
|
||||
|
||||
### Single Dataset
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml results.csv
|
||||
```
|
||||
|
||||
Generates plots from a single benchmark run.
|
||||
|
||||
### Compare Two Runs
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml before.csv after.csv \
|
||||
--labels "Before Optimization" "After Optimization"
|
||||
```
|
||||
|
||||
Overlays two datasets to compare performance improvements.
|
||||
|
||||
### Multiple Datasets
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml native.csv browser.csv wasm.csv \
|
||||
--labels "Native" "Browser" "WASM"
|
||||
```
|
||||
|
||||
Compare three different runtime environments.
|
||||
|
||||
### With Min/Max Bands
|
||||
|
||||
```bash
|
||||
tlsn-harness-plot bench.toml run1.csv run2.csv \
|
||||
--labels "Config A" "Config B" \
|
||||
--min-max-band
|
||||
```
|
||||
|
||||
Shows variance ranges for each dataset.
|
||||
|
||||
## Output Files
|
||||
|
||||
The tool generates two files per benchmark group:
|
||||
|
||||
- `<output>.html` - Interactive HTML chart (zoomable, hoverable)
|
||||
- `<output>.svg` - Static SVG image for documentation
|
||||
|
||||
Default output filenames:
|
||||
- `runtime_vs_bandwidth.{html,svg}` - When `protocol_latency` is defined in group
|
||||
- `runtime_vs_latency.{html,svg}` - When `bandwidth` is defined in group
|
||||
|
||||
## Plot Format
|
||||
|
||||
Each dataset displays:
|
||||
- **Solid line** - Total runtime (preprocessing + online phase)
|
||||
- **Dashed line** - Online phase only
|
||||
- **Shaded area** (optional) - Min/max variance bands
|
||||
|
||||
Different datasets automatically use distinct colors for easy comparison.
|
||||
|
||||
## CSV Format
|
||||
|
||||
Expected columns in each CSV file:
|
||||
- `group` - Benchmark group name (must match TOML)
|
||||
- `bandwidth` - Network bandwidth in Kbps (for bandwidth plots)
|
||||
- `latency` - Network latency in ms (for latency plots)
|
||||
- `time_preprocess` - Preprocessing time in ms
|
||||
- `time_online` - Online phase time in ms
|
||||
- `time_total` - Total runtime in ms
|
||||
|
||||
## TOML Format
|
||||
|
||||
The benchmark TOML file defines groups with either:
|
||||
|
||||
```toml
|
||||
[[group]]
|
||||
name = "my_benchmark"
|
||||
protocol_latency = 50 # Fixed latency for bandwidth plots
|
||||
# OR
|
||||
bandwidth = 10000 # Fixed bandwidth for latency plots
|
||||
```
|
||||
|
||||
All datasets must use the same TOML file to ensure consistent benchmark structure.
|
||||
|
||||
## Tips
|
||||
|
||||
- Use descriptive labels to make plots self-documenting
|
||||
- Keep CSV files from the same benchmark configuration for valid comparisons
|
||||
- Min/max bands are useful for showing stability but can clutter plots with many datasets
|
||||
- Interactive HTML plots support zooming and hovering for detailed values
|
||||
@@ -1,17 +1,18 @@
|
||||
use std::f32;
|
||||
|
||||
use charming::{
|
||||
Chart, HtmlRenderer,
|
||||
Chart, HtmlRenderer, ImageRenderer,
|
||||
component::{Axis, Legend, Title},
|
||||
element::{AreaStyle, LineStyle, NameLocation, Orient, TextStyle, Tooltip, Trigger},
|
||||
element::{
|
||||
AreaStyle, ItemStyle, LineStyle, LineStyleType, NameLocation, Orient, TextStyle, Tooltip,
|
||||
Trigger,
|
||||
},
|
||||
series::Line,
|
||||
theme::Theme,
|
||||
};
|
||||
use clap::Parser;
|
||||
use harness_core::bench::{BenchItems, Measurement};
|
||||
use itertools::Itertools;
|
||||
|
||||
const THEME: Theme = Theme::Default;
|
||||
use harness_core::bench::BenchItems;
|
||||
use polars::prelude::*;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about)]
|
||||
@@ -19,72 +20,131 @@ struct Cli {
|
||||
/// Path to the Bench.toml file with benchmark spec
|
||||
toml: String,
|
||||
|
||||
/// Path to the CSV file with benchmark results
|
||||
csv: String,
|
||||
/// Paths to CSV files with benchmark results (one or more)
|
||||
csv: Vec<String>,
|
||||
|
||||
/// Prover kind: native or browser
|
||||
#[arg(short, long, value_enum, default_value = "native")]
|
||||
prover_kind: ProverKind,
|
||||
/// Labels for each dataset (optional, defaults to "Dataset 1", "Dataset 2", etc.)
|
||||
#[arg(short, long, num_args = 0..)]
|
||||
labels: Vec<String>,
|
||||
|
||||
/// Add min/max bands to plots
|
||||
#[arg(long, default_value_t = false)]
|
||||
min_max_band: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum ProverKind {
|
||||
Native,
|
||||
Browser,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProverKind {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ProverKind::Native => write!(f, "Native"),
|
||||
ProverKind::Browser => write!(f, "Browser"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
let mut rdr = csv::Reader::from_path(&cli.csv)?;
|
||||
if cli.csv.is_empty() {
|
||||
return Err("At least one CSV file must be provided".into());
|
||||
}
|
||||
|
||||
// Generate labels if not provided
|
||||
let labels: Vec<String> = if cli.labels.is_empty() {
|
||||
cli.csv
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| format!("Dataset {}", i + 1))
|
||||
.collect()
|
||||
} else if cli.labels.len() != cli.csv.len() {
|
||||
return Err(format!(
|
||||
"Number of labels ({}) must match number of CSV files ({})",
|
||||
cli.labels.len(),
|
||||
cli.csv.len()
|
||||
)
|
||||
.into());
|
||||
} else {
|
||||
cli.labels.clone()
|
||||
};
|
||||
|
||||
// Load all CSVs and add dataset label
|
||||
let mut dfs = Vec::new();
|
||||
for (csv_path, label) in cli.csv.iter().zip(labels.iter()) {
|
||||
let mut df = CsvReadOptions::default()
|
||||
.try_into_reader_with_file_path(Some(csv_path.clone().into()))?
|
||||
.finish()?;
|
||||
|
||||
let label_series = Series::new("dataset_label".into(), vec![label.as_str(); df.height()]);
|
||||
df.with_column(label_series)?;
|
||||
dfs.push(df);
|
||||
}
|
||||
|
||||
// Combine all dataframes
|
||||
let df = dfs
|
||||
.into_iter()
|
||||
.reduce(|acc, df| acc.vstack(&df).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let items: BenchItems = toml::from_str(&std::fs::read_to_string(&cli.toml)?)?;
|
||||
let groups = items.group;
|
||||
|
||||
// Prepare data for plotting.
|
||||
let all_data: Vec<Measurement> = rdr
|
||||
.deserialize::<Measurement>()
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
for group in groups {
|
||||
if group.protocol_latency.is_some() {
|
||||
let latency = group.protocol_latency.unwrap();
|
||||
plot_runtime_vs(
|
||||
&all_data,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
|r| r.bandwidth as f32 / 1000.0, // Kbps to Mbps
|
||||
"Runtime vs Bandwidth",
|
||||
format!("{} ms Latency, {} mode", latency, cli.prover_kind),
|
||||
"runtime_vs_bandwidth.html",
|
||||
"Bandwidth (Mbps)",
|
||||
)?;
|
||||
// Determine which field varies in benches for this group
|
||||
let benches_in_group: Vec<_> = items
|
||||
.bench
|
||||
.iter()
|
||||
.filter(|b| b.group.as_deref() == Some(&group.name))
|
||||
.collect();
|
||||
|
||||
if benches_in_group.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if group.bandwidth.is_some() {
|
||||
let bandwidth = group.bandwidth.unwrap();
|
||||
// Check which field has varying values
|
||||
let bandwidth_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].bandwidth != w[1].bandwidth);
|
||||
let latency_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].protocol_latency != w[1].protocol_latency);
|
||||
let download_size_varies = benches_in_group
|
||||
.windows(2)
|
||||
.any(|w| w[0].download_size != w[1].download_size);
|
||||
|
||||
if download_size_varies {
|
||||
let upload_size = group.upload_size.unwrap_or(1024);
|
||||
plot_runtime_vs(
|
||||
&all_data,
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
|r| r.latency as f32,
|
||||
"download_size",
|
||||
1.0 / 1024.0, // bytes to KB
|
||||
"Runtime vs Response Size",
|
||||
format!("{} bytes upload size", upload_size),
|
||||
"runtime_vs_download_size",
|
||||
"Response Size (KB)",
|
||||
true, // legend on left
|
||||
)?;
|
||||
} else if bandwidth_varies {
|
||||
let latency = group.protocol_latency.unwrap_or(50);
|
||||
plot_runtime_vs(
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
"bandwidth",
|
||||
1.0 / 1000.0, // Kbps to Mbps
|
||||
"Runtime vs Bandwidth",
|
||||
format!("{} ms Latency", latency),
|
||||
"runtime_vs_bandwidth",
|
||||
"Bandwidth (Mbps)",
|
||||
false, // legend on right
|
||||
)?;
|
||||
} else if latency_varies {
|
||||
let bandwidth = group.bandwidth.unwrap_or(1000);
|
||||
plot_runtime_vs(
|
||||
&df,
|
||||
&labels,
|
||||
cli.min_max_band,
|
||||
&group.name,
|
||||
"latency",
|
||||
1.0,
|
||||
"Runtime vs Latency",
|
||||
format!("{} bps bandwidth, {} mode", bandwidth, cli.prover_kind),
|
||||
"runtime_vs_latency.html",
|
||||
format!("{} bps bandwidth", bandwidth),
|
||||
"runtime_vs_latency",
|
||||
"Latency (ms)",
|
||||
true, // legend on left
|
||||
)?;
|
||||
}
|
||||
}
|
||||
@@ -92,83 +152,51 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct DataPoint {
|
||||
min: f32,
|
||||
mean: f32,
|
||||
max: f32,
|
||||
}
|
||||
|
||||
struct Points {
|
||||
preprocess: DataPoint,
|
||||
online: DataPoint,
|
||||
total: DataPoint,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn plot_runtime_vs<Fx>(
|
||||
all_data: &[Measurement],
|
||||
fn plot_runtime_vs(
|
||||
df: &DataFrame,
|
||||
labels: &[String],
|
||||
show_min_max: bool,
|
||||
group: &str,
|
||||
x_value: Fx,
|
||||
x_col: &str,
|
||||
x_scale: f32,
|
||||
title: &str,
|
||||
subtitle: String,
|
||||
output_file: &str,
|
||||
x_axis_label: &str,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>>
|
||||
where
|
||||
Fx: Fn(&Measurement) -> f32,
|
||||
{
|
||||
fn data_point(values: &[f32]) -> DataPoint {
|
||||
let mean = values.iter().copied().sum::<f32>() / values.len() as f32;
|
||||
let max = values.iter().copied().reduce(f32::max).unwrap_or_default();
|
||||
let min = values.iter().copied().reduce(f32::min).unwrap_or_default();
|
||||
DataPoint { min, mean, max }
|
||||
}
|
||||
legend_left: bool,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
let stats_df = df
|
||||
.clone()
|
||||
.lazy()
|
||||
.filter(col("group").eq(lit(group)))
|
||||
.with_column((col(x_col).cast(DataType::Float32) * lit(x_scale)).alias("x"))
|
||||
.with_columns([
|
||||
(col("time_preprocess").cast(DataType::Float32) / lit(1000.0)).alias("preprocess"),
|
||||
(col("time_online").cast(DataType::Float32) / lit(1000.0)).alias("online"),
|
||||
(col("time_total").cast(DataType::Float32) / lit(1000.0)).alias("total"),
|
||||
])
|
||||
.group_by([col("x"), col("dataset_label")])
|
||||
.agg([
|
||||
col("preprocess").min().alias("preprocess_min"),
|
||||
col("preprocess").mean().alias("preprocess_mean"),
|
||||
col("preprocess").max().alias("preprocess_max"),
|
||||
col("online").min().alias("online_min"),
|
||||
col("online").mean().alias("online_mean"),
|
||||
col("online").max().alias("online_max"),
|
||||
col("total").min().alias("total_min"),
|
||||
col("total").mean().alias("total_mean"),
|
||||
col("total").max().alias("total_max"),
|
||||
])
|
||||
.sort(["dataset_label", "x"], Default::default())
|
||||
.collect()?;
|
||||
|
||||
let stats: Vec<(f32, Points)> = all_data
|
||||
.iter()
|
||||
.filter(|r| r.group.as_deref() == Some(group))
|
||||
.map(|r| {
|
||||
(
|
||||
x_value(r),
|
||||
r.time_preprocess as f32 / 1000.0, // ms to s
|
||||
r.time_online as f32 / 1000.0,
|
||||
r.time_total as f32 / 1000.0,
|
||||
)
|
||||
})
|
||||
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
|
||||
.chunk_by(|entry| entry.0)
|
||||
.into_iter()
|
||||
.map(|(x, group)| {
|
||||
let group_vec: Vec<_> = group.collect();
|
||||
let preprocess = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, t, _, _)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
let online = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, _, t, _)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
let total = data_point(
|
||||
&group_vec
|
||||
.iter()
|
||||
.map(|(_, _, _, t)| *t)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
(
|
||||
x,
|
||||
Points {
|
||||
preprocess,
|
||||
online,
|
||||
total,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
// Build legend entries
|
||||
let mut legend_data = Vec::new();
|
||||
for label in labels {
|
||||
legend_data.push(format!("Total Mean ({})", label));
|
||||
legend_data.push(format!("Online Mean ({})", label));
|
||||
}
|
||||
|
||||
let mut chart = Chart::new()
|
||||
.title(
|
||||
@@ -179,14 +207,6 @@ where
|
||||
.subtext_style(TextStyle::new().font_size(16)),
|
||||
)
|
||||
.tooltip(Tooltip::new().trigger(Trigger::Axis))
|
||||
.legend(
|
||||
Legend::new()
|
||||
.data(vec!["Preprocess Mean", "Online Mean", "Total Mean"])
|
||||
.top("80")
|
||||
.right("110")
|
||||
.orient(Orient::Vertical)
|
||||
.item_gap(10),
|
||||
)
|
||||
.x_axis(
|
||||
Axis::new()
|
||||
.name(x_axis_label)
|
||||
@@ -205,73 +225,156 @@ where
|
||||
.name_text_style(TextStyle::new().font_size(21)),
|
||||
);
|
||||
|
||||
chart = add_mean_series(chart, &stats, "Preprocess Mean", |p| p.preprocess.mean);
|
||||
chart = add_mean_series(chart, &stats, "Online Mean", |p| p.online.mean);
|
||||
chart = add_mean_series(chart, &stats, "Total Mean", |p| p.total.mean);
|
||||
// Add legend with conditional positioning
|
||||
let legend = Legend::new()
|
||||
.data(legend_data)
|
||||
.top("80")
|
||||
.orient(Orient::Vertical)
|
||||
.item_gap(10);
|
||||
|
||||
if show_min_max {
|
||||
chart = add_min_max_band(
|
||||
chart,
|
||||
&stats,
|
||||
"Preprocess Min/Max",
|
||||
|p| &p.preprocess,
|
||||
"#ccc",
|
||||
);
|
||||
chart = add_min_max_band(chart, &stats, "Online Min/Max", |p| &p.online, "#ccc");
|
||||
chart = add_min_max_band(chart, &stats, "Total Min/Max", |p| &p.total, "#ccc");
|
||||
let legend = if legend_left {
|
||||
legend.left("110")
|
||||
} else {
|
||||
legend.right("110")
|
||||
};
|
||||
|
||||
chart = chart.legend(legend);
|
||||
|
||||
// Define colors for each dataset
|
||||
let colors = vec![
|
||||
"#5470c6", "#91cc75", "#fac858", "#ee6666", "#73c0de", "#3ba272", "#fc8452", "#9a60b4",
|
||||
];
|
||||
|
||||
for (idx, label) in labels.iter().enumerate() {
|
||||
let color = colors.get(idx % colors.len()).unwrap();
|
||||
|
||||
// Total time - solid line
|
||||
chart = add_dataset_series(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Total Mean ({})", label),
|
||||
"total_mean",
|
||||
false,
|
||||
color,
|
||||
)?;
|
||||
|
||||
// Online time - dashed line (same color as total)
|
||||
chart = add_dataset_series(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Online Mean ({})", label),
|
||||
"online_mean",
|
||||
true,
|
||||
color,
|
||||
)?;
|
||||
|
||||
if show_min_max {
|
||||
chart = add_dataset_min_max_band(
|
||||
&chart,
|
||||
&stats_df,
|
||||
label,
|
||||
&format!("Total Min/Max ({})", label),
|
||||
"total",
|
||||
color,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
// Save the chart as HTML file.
|
||||
// Save the chart as HTML file (no theme)
|
||||
HtmlRenderer::new(title, 1000, 800)
|
||||
.theme(THEME)
|
||||
.save(&chart, output_file)
|
||||
.save(&chart, &format!("{}.html", output_file))
|
||||
.unwrap();
|
||||
|
||||
// Save SVG with default theme
|
||||
ImageRenderer::new(1000, 800)
|
||||
.theme(Theme::Default)
|
||||
.save(&chart, &format!("{}.svg", output_file))
|
||||
.unwrap();
|
||||
|
||||
// Save SVG with dark theme
|
||||
ImageRenderer::new(1000, 800)
|
||||
.theme(Theme::Dark)
|
||||
.save(&chart, &format!("{}_dark.svg", output_file))
|
||||
.unwrap();
|
||||
|
||||
Ok(chart)
|
||||
}
|
||||
|
||||
fn add_mean_series(
|
||||
chart: Chart,
|
||||
stats: &[(f32, Points)],
|
||||
name: &str,
|
||||
extract: impl Fn(&Points) -> f32,
|
||||
) -> Chart {
|
||||
chart.series(
|
||||
Line::new()
|
||||
.name(name)
|
||||
.data(
|
||||
stats
|
||||
.iter()
|
||||
.map(|(x, points)| vec![*x, extract(points)])
|
||||
.collect(),
|
||||
)
|
||||
.symbol_size(6),
|
||||
)
|
||||
fn add_dataset_series(
|
||||
chart: &Chart,
|
||||
df: &DataFrame,
|
||||
dataset_label: &str,
|
||||
series_name: &str,
|
||||
col_name: &str,
|
||||
dashed: bool,
|
||||
color: &str,
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
// Filter for specific dataset
|
||||
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
|
||||
let filtered = df.filter(&mask)?;
|
||||
|
||||
let x = filtered.column("x")?.f32()?;
|
||||
let y = filtered.column(col_name)?.f32()?;
|
||||
|
||||
let data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(y.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.collect();
|
||||
|
||||
let mut line = Line::new()
|
||||
.name(series_name)
|
||||
.data(data)
|
||||
.symbol_size(6)
|
||||
.item_style(ItemStyle::new().color(color));
|
||||
|
||||
let mut line_style = LineStyle::new();
|
||||
if dashed {
|
||||
line_style = line_style.type_(LineStyleType::Dashed);
|
||||
}
|
||||
line = line.line_style(line_style.color(color));
|
||||
|
||||
Ok(chart.clone().series(line))
|
||||
}
|
||||
|
||||
fn add_min_max_band(
|
||||
chart: Chart,
|
||||
stats: &[(f32, Points)],
|
||||
fn add_dataset_min_max_band(
|
||||
chart: &Chart,
|
||||
df: &DataFrame,
|
||||
dataset_label: &str,
|
||||
name: &str,
|
||||
extract: impl Fn(&Points) -> &DataPoint,
|
||||
col_prefix: &str,
|
||||
color: &str,
|
||||
) -> Chart {
|
||||
chart.series(
|
||||
) -> Result<Chart, Box<dyn std::error::Error>> {
|
||||
// Filter for specific dataset
|
||||
let mask = df.column("dataset_label")?.str()?.equal(dataset_label);
|
||||
let filtered = df.filter(&mask)?;
|
||||
|
||||
let x = filtered.column("x")?.f32()?;
|
||||
let min_col = filtered.column(&format!("{}_min", col_prefix))?.f32()?;
|
||||
let max_col = filtered.column(&format!("{}_max", col_prefix))?.f32()?;
|
||||
|
||||
let max_data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(max_col.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.collect();
|
||||
|
||||
let min_data: Vec<Vec<f32>> = x
|
||||
.into_iter()
|
||||
.zip(min_col.into_iter())
|
||||
.filter_map(|(x, y)| Some(vec![x?, y?]))
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
let data: Vec<Vec<f32>> = max_data.into_iter().chain(min_data).collect();
|
||||
|
||||
Ok(chart.clone().series(
|
||||
Line::new()
|
||||
.name(name)
|
||||
.data(
|
||||
stats
|
||||
.iter()
|
||||
.map(|(x, points)| vec![*x, extract(points).max])
|
||||
.chain(
|
||||
stats
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|(x, points)| vec![*x, extract(points).min]),
|
||||
)
|
||||
.collect(),
|
||||
)
|
||||
.data(data)
|
||||
.show_symbol(false)
|
||||
.line_style(LineStyle::new().opacity(0.0))
|
||||
.area_style(AreaStyle::new().opacity(0.3).color(color)),
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
105
crates/harness/plot/data/bandwidth.ipynb
Normal file
105
crates/harness/plot/data/bandwidth.ipynb
Normal file
File diff suppressed because one or more lines are too long
163
crates/harness/plot/data/download.ipynb
Normal file
163
crates/harness/plot/data/download.ipynb
Normal file
File diff suppressed because one or more lines are too long
92
crates/harness/plot/data/latency.ipynb
Normal file
92
crates/harness/plot/data/latency.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -32,18 +32,13 @@ use crate::debug_prelude::*;
|
||||
|
||||
use crate::{cli::Route, network::Network, wasm_server::WasmServer, ws_proxy::WsProxy};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum, Default)]
|
||||
pub enum Target {
|
||||
#[default]
|
||||
Native,
|
||||
Browser,
|
||||
}
|
||||
|
||||
impl Default for Target {
|
||||
fn default() -> Self {
|
||||
Self::Native
|
||||
}
|
||||
}
|
||||
|
||||
struct Runner {
|
||||
network: Network,
|
||||
server_fixture: ServerFixture,
|
||||
|
||||
25
crates/harness/toml/bandwidth.toml
Normal file
25
crates/harness/toml/bandwidth.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
#### Bandwidth ####
|
||||
|
||||
[[group]]
|
||||
name = "bandwidth"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 10
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 50
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 100
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 250
|
||||
|
||||
[[bench]]
|
||||
group = "bandwidth"
|
||||
bandwidth = 1000
|
||||
37
crates/harness/toml/download.toml
Normal file
37
crates/harness/toml/download.toml
Normal file
@@ -0,0 +1,37 @@
|
||||
[[group]]
|
||||
name = "download_size"
|
||||
protocol_latency = 10
|
||||
bandwidth = 200
|
||||
upload-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 1024
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 2048
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 4096
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 8192
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 16384
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 32768
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 65536
|
||||
|
||||
[[bench]]
|
||||
group = "download_size"
|
||||
download-size = 131072
|
||||
25
crates/harness/toml/latency.toml
Normal file
25
crates/harness/toml/latency.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
#### Latency ####
|
||||
|
||||
[[group]]
|
||||
name = "latency"
|
||||
bandwidth = 1000
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 10
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 25
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 50
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 100
|
||||
|
||||
[[bench]]
|
||||
group = "latency"
|
||||
protocol_latency = 200
|
||||
@@ -24,7 +24,7 @@ use std::{
|
||||
};
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
use tracing::{debug, debug_span, error, trace, warn, Instrument};
|
||||
use tracing::{debug, debug_span, trace, warn, Instrument};
|
||||
|
||||
use tls_client::ClientConnection;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::{Backend, BackendError};
|
||||
use crate::{DecryptMode, EncryptMode, Error};
|
||||
#[allow(deprecated)]
|
||||
use aes_gcm::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead, Payload},
|
||||
Aes128Gcm,
|
||||
@@ -507,6 +508,7 @@ impl Encrypter {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&self.write_iv);
|
||||
nonce[4..].copy_from_slice(explicit_nonce);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap();
|
||||
// ciphertext will have the MAC appended
|
||||
@@ -568,6 +570,7 @@ impl Decrypter {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[..4].copy_from_slice(&self.write_iv);
|
||||
nonce[4..].copy_from_slice(&m.payload.0[0..8]);
|
||||
#[allow(deprecated)]
|
||||
let nonce = GenericArray::from_slice(&nonce);
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, aes_payload)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use mpz_memory_core::{Vector, binary::U8};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) struct RangeMap<T> {
|
||||
@@ -77,7 +77,7 @@ where
|
||||
|
||||
pub(crate) fn index(&self, idx: &RangeSet<usize>) -> Option<Self> {
|
||||
let mut map = Vec::new();
|
||||
for idx in idx.iter_ranges() {
|
||||
for idx in idx.iter() {
|
||||
let pos = match self.map.binary_search_by(|(base, _)| base.cmp(&idx.start)) {
|
||||
Ok(i) => i,
|
||||
Err(0) => return None,
|
||||
|
||||
@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
ProverOutput,
|
||||
config::prove::ProveConfig,
|
||||
|
||||
@@ -12,7 +12,7 @@ use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Call, CallableExt, Vm};
|
||||
use rangeset::{Difference, RangeSet, Union};
|
||||
use rangeset::{iter::RangeIterator, ops::Set, set::RangeSet};
|
||||
use tlsn_core::transcript::Record;
|
||||
|
||||
use crate::transcript_internal::ReferenceMap;
|
||||
@@ -32,7 +32,7 @@ pub(crate) fn prove_plaintext<'a>(
|
||||
commit.clone()
|
||||
} else {
|
||||
// The plaintext is only partially revealed, so we need to authenticate in ZK.
|
||||
commit.union(reveal)
|
||||
commit.union(reveal).into_set()
|
||||
};
|
||||
|
||||
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
|
||||
@@ -49,7 +49,7 @@ pub(crate) fn prove_plaintext<'a>(
|
||||
vm.commit(*slice).map_err(PlaintextAuthError::vm)?;
|
||||
}
|
||||
} else {
|
||||
let private = commit.difference(reveal);
|
||||
let private = commit.difference(reveal).into_set();
|
||||
for (_, slice) in plaintext_refs
|
||||
.index(&private)
|
||||
.expect("all ranges are allocated")
|
||||
@@ -98,7 +98,7 @@ pub(crate) fn verify_plaintext<'a>(
|
||||
commit.clone()
|
||||
} else {
|
||||
// The plaintext is only partially revealed, so we need to authenticate in ZK.
|
||||
commit.union(reveal)
|
||||
commit.union(reveal).into_set()
|
||||
};
|
||||
|
||||
let plaintext_refs = alloc_plaintext(vm, &alloc_ranges)?;
|
||||
@@ -123,7 +123,7 @@ pub(crate) fn verify_plaintext<'a>(
|
||||
ciphertext,
|
||||
})
|
||||
} else {
|
||||
let private = commit.difference(reveal);
|
||||
let private = commit.difference(reveal).into_set();
|
||||
for (_, slice) in plaintext_refs
|
||||
.index(&private)
|
||||
.expect("all ranges are allocated")
|
||||
@@ -175,15 +175,13 @@ fn alloc_plaintext(
|
||||
let plaintext = vm.alloc_vec::<U8>(len).map_err(PlaintextAuthError::vm)?;
|
||||
|
||||
let mut pos = 0;
|
||||
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
|
||||
move |range| {
|
||||
let chunk = plaintext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
},
|
||||
)))
|
||||
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
|
||||
let chunk = plaintext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
})))
|
||||
}
|
||||
|
||||
fn alloc_ciphertext<'a>(
|
||||
@@ -212,15 +210,13 @@ fn alloc_ciphertext<'a>(
|
||||
let ciphertext: Vector<U8> = vm.call(call).map_err(PlaintextAuthError::vm)?;
|
||||
|
||||
let mut pos = 0;
|
||||
Ok(ReferenceMap::from_iter(ranges.iter_ranges().map(
|
||||
move |range| {
|
||||
let chunk = ciphertext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
},
|
||||
)))
|
||||
Ok(ReferenceMap::from_iter(ranges.iter().map(move |range| {
|
||||
let chunk = ciphertext
|
||||
.get(pos..pos + range.len())
|
||||
.expect("length was checked");
|
||||
pos += range.len();
|
||||
(range.start, chunk)
|
||||
})))
|
||||
}
|
||||
|
||||
fn alloc_keystream<'a>(
|
||||
@@ -233,7 +229,7 @@ fn alloc_keystream<'a>(
|
||||
let mut keystream = Vec::new();
|
||||
|
||||
let mut pos = 0;
|
||||
let mut range_iter = ranges.iter_ranges();
|
||||
let mut range_iter = ranges.iter();
|
||||
let mut current_range = range_iter.next();
|
||||
for record in records {
|
||||
let mut explicit_nonce = None;
|
||||
@@ -508,7 +504,7 @@ mod tests {
|
||||
for record in records {
|
||||
let mut record_keystream = vec![0u8; record.len];
|
||||
aes_ctr_apply_keystream(&key, &iv, &record.explicit_nonce, &mut record_keystream);
|
||||
for mut range in ranges.iter_ranges() {
|
||||
for mut range in ranges.iter() {
|
||||
range.start = range.start.max(pos);
|
||||
range.end = range.end.min(pos + record.len);
|
||||
if range.start < range.end {
|
||||
|
||||
@@ -9,7 +9,7 @@ use mpz_memory_core::{
|
||||
correlated::{Delta, Key, Mac},
|
||||
};
|
||||
use rand::Rng;
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serio::{SinkExt, stream::IoStreamExt};
|
||||
use tlsn_core::{
|
||||
|
||||
@@ -9,7 +9,7 @@ use mpz_memory_core::{
|
||||
binary::{Binary, U8},
|
||||
};
|
||||
use mpz_vm_core::{Vm, VmError, prelude::*};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
hash::{Blinder, Hash, HashAlgId, TypedHash},
|
||||
transcript::{
|
||||
@@ -155,7 +155,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher.update(&refs.get(range).expect("plaintext refs are valid"));
|
||||
}
|
||||
|
||||
@@ -176,7 +176,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher
|
||||
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
|
||||
.map_err(HashCommitError::hasher)?;
|
||||
@@ -201,7 +201,7 @@ fn hash_commit_inner(
|
||||
Direction::Received => &refs.recv,
|
||||
};
|
||||
|
||||
for range in idx.iter_ranges() {
|
||||
for range in idx.iter() {
|
||||
hasher
|
||||
.update(vm, &refs.get(range).expect("plaintext refs are valid"))
|
||||
.map_err(HashCommitError::hasher)?;
|
||||
|
||||
@@ -2,7 +2,7 @@ use mpc_tls::SessionKeys;
|
||||
use mpz_common::Context;
|
||||
use mpz_memory_core::binary::Binary;
|
||||
use mpz_vm_core::Vm;
|
||||
use rangeset::{RangeSet, UnionMut};
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn_core::{
|
||||
VerifierOutput,
|
||||
config::prove::ProveRequest,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use futures::{AsyncReadExt, AsyncWriteExt};
|
||||
use rangeset::RangeSet;
|
||||
use rangeset::set::RangeSet;
|
||||
use tlsn::{
|
||||
config::{
|
||||
prove::ProveConfig,
|
||||
@@ -51,19 +51,11 @@ async fn test() {
|
||||
assert_eq!(server_name.as_str(), SERVER_DOMAIN);
|
||||
assert!(!partial_transcript.is_complete());
|
||||
assert_eq!(
|
||||
partial_transcript
|
||||
.sent_authed()
|
||||
.iter_ranges()
|
||||
.next()
|
||||
.unwrap(),
|
||||
partial_transcript.sent_authed().iter().next().unwrap(),
|
||||
0..10
|
||||
);
|
||||
assert_eq!(
|
||||
partial_transcript
|
||||
.received_authed()
|
||||
.iter_ranges()
|
||||
.next()
|
||||
.unwrap(),
|
||||
partial_transcript.received_authed().iter().next().unwrap(),
|
||||
0..10
|
||||
);
|
||||
|
||||
|
||||
@@ -151,9 +151,9 @@ impl From<tlsn::transcript::PartialTranscript> for PartialTranscript {
|
||||
fn from(value: tlsn::transcript::PartialTranscript) -> Self {
|
||||
Self {
|
||||
sent: value.sent_unsafe().to_vec(),
|
||||
sent_authed: value.sent_authed().iter_ranges().collect(),
|
||||
sent_authed: value.sent_authed().iter().collect(),
|
||||
recv: value.received_unsafe().to_vec(),
|
||||
recv_authed: value.received_authed().iter_ranges().collect(),
|
||||
recv_authed: value.received_authed().iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user