Compare commits

...

2 Commits

Author SHA1 Message Date
dan
a34355ef77 fix test 2025-11-19 09:17:38 +02:00
dan
bf6a23a61e convert pest ast to predicate ast 2025-11-19 08:54:21 +02:00
10 changed files with 1810 additions and 39 deletions

52
Cargo.lock generated
View File

@@ -2009,7 +2009,7 @@ checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "clmul"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"bytemuck",
"cfg-if",
@@ -4201,7 +4201,7 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrix-transpose"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"thiserror 1.0.69",
]
@@ -4258,7 +4258,7 @@ dependencies = [
[[package]]
name = "mpz-circuits"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"mpz-circuits-core",
"mpz-circuits-data",
@@ -4267,7 +4267,7 @@ dependencies = [
[[package]]
name = "mpz-circuits-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"bincode 1.3.3",
"itybity 0.3.1",
@@ -4282,7 +4282,7 @@ dependencies = [
[[package]]
name = "mpz-circuits-data"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"bincode 1.3.3",
"mpz-circuits-core",
@@ -4292,7 +4292,7 @@ dependencies = [
[[package]]
name = "mpz-cointoss"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"futures",
"mpz-cointoss-core",
@@ -4305,7 +4305,7 @@ dependencies = [
[[package]]
name = "mpz-cointoss-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"mpz-core",
"opaque-debug",
@@ -4316,7 +4316,7 @@ dependencies = [
[[package]]
name = "mpz-common"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"bytes",
@@ -4336,7 +4336,7 @@ dependencies = [
[[package]]
name = "mpz-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"aes 0.9.0-rc.1",
"bcs",
@@ -4362,7 +4362,7 @@ dependencies = [
[[package]]
name = "mpz-fields"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"ark-ff 0.4.2",
"ark-secp256r1",
@@ -4382,7 +4382,7 @@ dependencies = [
[[package]]
name = "mpz-garble"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"derive_builder 0.11.2",
@@ -4408,7 +4408,7 @@ dependencies = [
[[package]]
name = "mpz-garble-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"aes 0.9.0-rc.1",
"bitvec",
@@ -4439,7 +4439,7 @@ dependencies = [
[[package]]
name = "mpz-hash"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"blake3",
"itybity 0.3.1",
@@ -4452,7 +4452,7 @@ dependencies = [
[[package]]
name = "mpz-ideal-vm"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"futures",
@@ -4469,7 +4469,7 @@ dependencies = [
[[package]]
name = "mpz-memory-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"blake3",
"futures",
@@ -4484,7 +4484,7 @@ dependencies = [
[[package]]
name = "mpz-ole"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"futures",
@@ -4502,7 +4502,7 @@ dependencies = [
[[package]]
name = "mpz-ole-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"hybrid-array",
"itybity 0.3.1",
@@ -4518,7 +4518,7 @@ dependencies = [
[[package]]
name = "mpz-ot"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"cfg-if",
@@ -4541,7 +4541,7 @@ dependencies = [
[[package]]
name = "mpz-ot-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"aes 0.9.0-rc.1",
"blake3",
@@ -4572,7 +4572,7 @@ dependencies = [
[[package]]
name = "mpz-share-conversion"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"mpz-common",
@@ -4588,7 +4588,7 @@ dependencies = [
[[package]]
name = "mpz-share-conversion-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"mpz-common",
"mpz-core",
@@ -4602,7 +4602,7 @@ dependencies = [
[[package]]
name = "mpz-vm-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"futures",
@@ -4615,7 +4615,7 @@ dependencies = [
[[package]]
name = "mpz-zk"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"async-trait",
"blake3",
@@ -4633,7 +4633,7 @@ dependencies = [
[[package]]
name = "mpz-zk-core"
version = "0.1.0-alpha.4"
source = "git+https://github.com/privacy-ethereum/mpz?rev=bd80826#bd808262ecb010ca7b162633e4582a897a2fac12"
source = "git+https://github.com/privacy-ethereum/mpz?rev=5250a78#5250a787b8f6783aac9c104e5e7ce307eb50725a"
dependencies = [
"blake3",
"cfg-if",
@@ -7275,7 +7275,11 @@ dependencies = [
"generic-array",
"hex",
"itybity 0.2.1",
"mpz-circuits",
"opaque-debug",
"pest",
"pest_derive",
"pest_meta",
"rand 0.9.2",
"rand_chacha 0.9.0",
"rand_core 0.9.3",

View File

@@ -66,21 +66,20 @@ tlsn-harness-runner = { path = "crates/harness/runner" }
tlsn-wasm = { path = "crates/wasm" }
tlsn = { path = "crates/tlsn" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-circuits-data = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "bd80826" }
mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
mpz-ideal-vm = { git = "https://github.com/privacy-ethereum/mpz", rev = "5250a78" }
rangeset = { version = "0.2" }
serio = { version = "0.2" }

View File

@@ -27,6 +27,12 @@ tlsn-data-fixtures = { workspace = true, optional = true }
tlsn-tls-core = { workspace = true, features = ["serde"] }
tlsn-utils = { workspace = true }
rangeset = { workspace = true, features = ["serde"] }
mpz-circuits = { workspace = true }
pest = "*"
pest_derive = "*"
pest_meta = "*"
aead = { workspace = true, features = ["alloc"], optional = true }
aes-gcm = { workspace = true, optional = true }

159
crates/core/src/grammar.rs Normal file
View File

@@ -0,0 +1,159 @@
use crate::predicates::Pred;
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pest_derive::Parser;
#[derive(Parser)]
#[grammar = "expr.pest"]
struct ExprParser;
fn parse_expr(input: &str) -> Result<Pred, pest::error::Error<Rule>> {
let mut pairs = ExprParser::parse(Rule::expr, input)?;
Ok(build_expr(pairs.next().unwrap()))
}
fn build_expr(pair: Pair<Rule>) -> Pred {
match pair.as_rule() {
Rule::expr | Rule::or_expr => build_left_assoc(pair.into_inner(), Rule::and_expr, Pred::Or),
Rule::and_expr => build_left_assoc(pair.into_inner(), Rule::not_expr, Pred::And),
Rule::not_expr => {
// NOT* cmp
let mut inner = pair.into_inner(); // possibly multiple NOT then a cmp
// Count NOTs, then parse cmp
let mut not_count = 0;
let mut rest = Vec::new();
for p in inner {
match p.as_rule() {
Rule::NOT => not_count += 1,
_ => {
rest.push(p);
}
}
}
let mut node = build_cmp(rest.into_iter().next().expect("cmp missing"));
if not_count % 2 == 1 {
node = Pred::Not(Box::new(node));
}
node
}
Rule::cmp => build_cmp(pair),
Rule::primary => build_expr(pair.into_inner().next().unwrap()),
Rule::paren => build_expr(pair.into_inner().next().unwrap()),
_ => unreachable!("unexpected rule: {:?}", pair.as_rule()),
}
}
fn build_left_assoc(
mut inner: Pairs<Rule>,
unit_rule: Rule,
mk_node: impl Fn(Vec<Pred>) -> Pred,
) -> Pred {
// pattern: unit (OP unit)*
let mut nodes = Vec::new();
// First unit
if let Some(first) = inner.next() {
assert_eq!(first.as_rule(), unit_rule);
nodes.push(build_expr(first));
}
// Remaining are: OP unit pairs; we only collect the units and wrap later.
while let Some(next) = inner.next() {
// next is the operator token pair (AND/OR), skip it
// then the unit:
if let Some(unit) = inner.next() {
assert_eq!(unit.as_rule(), unit_rule);
nodes.push(build_expr(unit));
}
}
if nodes.len() == 1 {
nodes.pop().unwrap()
} else {
mk_node(nodes)
}
}
fn build_cmp(pair: Pair<Rule>) -> Pred {
// cmp: primary (cmp_op primary)?
let mut inner = pair.into_inner();
let lhs = inner.next().unwrap();
let lhs_term = parse_term(lhs);
if let Some(op_pair) = inner.next() {
let op = match op_pair.as_str() {
"==" => CmpOp::Eq,
"!=" => CmpOp::Ne,
"<" => CmpOp::Lt,
"<=" => CmpOp::Lte,
">" => CmpOp::Gt,
">=" => CmpOp::Gte,
_ => unreachable!(),
};
let rhs = parse_term(inner.next().unwrap());
// Map to your Atom constraint form (LHS must be x[idx]):
let (index, rhs_val) = match (lhs_term, rhs) {
(Term::Idx(i), Term::Const(c)) => (i, Rhs::Const(c)),
(Term::Idx(i1), Term::Idx(i2)) => (i1, Rhs::Idx(i2)),
// If you want to allow const OP idx or const OP const, handle here (flip, etc.)
other => panic!("unsupported comparison pattern: {:?}", other),
};
Pred::Atom(Atom {
index,
op,
rhs: rhs_val,
})
} else {
// A bare primary is treated as a boolean atom; you can decide policy.
// Here we treat "x[i]" as (x[i] != 0) and const as (const != 0).
match lhs_term {
Term::Idx(i) => Pred::Atom(Atom {
index: i,
op: CmpOp::Ne,
rhs: Rhs::Const(0),
}),
Term::Const(c) => {
if c != 0 {
Pred::Or(vec![])
} else {
Pred::And(vec![])
} // true/false constants if you add Const
}
}
}
}
#[derive(Debug, Clone, Copy)]
enum Term {
Idx(usize),
Const(u8),
}
fn parse_term(pair: Pair<Rule>) -> Term {
match pair.as_rule() {
Rule::atom => parse_term(pair.into_inner().next().unwrap()),
Rule::byte_idx => {
// "x" "[" number "]"
let mut i = pair.into_inner();
let num = i.find(|p| p.as_rule() == Rule::number).unwrap();
Term::Idx(num.as_str().parse::<usize>().unwrap())
}
Rule::byte_const => {
let n = pair.into_inner().next().unwrap(); // number
Term::Const(n.as_str().parse::<u8>().unwrap())
}
Rule::paren => parse_term(pair.into_inner().next().unwrap()),
Rule::primary => parse_term(pair.into_inner().next().unwrap()),
_ => unreachable!("term {:?}", pair.as_rule()),
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_and() {
let pred = parse_expr("x[100] < x[300] && x[200] == 2 || ! (x[5] >= 57)").unwrap();
// `pred` is a Pred::Or with an And on the left and a Not on the right,
// with Atoms inside.
}
}

41
crates/core/src/json.pest Normal file
View File

@@ -0,0 +1,41 @@
// pest. The Elegant Parser
// Copyright (c) 2018 Dragoș Tiselice
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
//! A parser for JSON file.
//!
//! And this is a example for JSON parser.
json = _{ SOI ~ value ~ eoi }
eoi = _{ !ANY }
/// Matches object, e.g.: `{ "foo": "bar" }`
/// Foobar
object = { "{" ~ pair ~ (pair)* ~ "}" | "{" ~ "}" }
pair = { quoted_string ~ ":" ~ value ~ (",")? }
array = { "[" ~ value ~ ("," ~ value)* ~ "]" | "[" ~ "]" }
//////////////////////
/// Matches value, e.g.: `"foo"`, `42`, `true`, `null`, `[]`, `{}`.
//////////////////////
value = _{ quoted_string | number | object | array | bool | null }
quoted_string = _{ "\"" ~ string ~ "\"" }
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ string)? }
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }
number = @{ "-"? ~ int ~ ("." ~ ASCII_DIGIT+ ~ exp? | exp)? }
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }
exp = @{ ("E" | "e") ~ ("+" | "-")? ~ ASCII_DIGIT+ }
bool = { "true" | "false" }
null = { "null" }
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }

760
crates/core/src/json.rs Normal file
View File

@@ -0,0 +1,760 @@
//!
use crate::predicates::Pred;
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pest_derive::Parser;
use pest_meta::{ast, parser as meta, parser::consume_rules, validator};
#[cfg(test)]
mod test {
use core::panic;
use std::cmp::{max, min};
use crate::{
config::prove::ProveConfig,
predicates::{eval_pred, is_unicode, Atom, CmpOp, Compiler, Rhs},
};
use super::*;
use mpz_circuits::{
evaluate,
ops::{all, any},
};
use pest_meta::ast::Expr;
use rangeset::RangeSet;
const MAX_LEN: usize = 999_999;
#[derive(Debug, Clone)]
enum Ex {
RepEx(Rep),
RepExactEx(RepExact),
SeqEx(Seq),
StrEx(Str),
ChoiceEx(Choice),
NegPredEx(NegPred),
OptEx(Opt),
// An expression which must be replaced with a copy of the rule.
NestedEx,
#[allow(non_camel_case_types)]
ASCII_NONZERO_DIGIT,
#[allow(non_camel_case_types)]
ASCII_DIGIT,
#[allow(non_camel_case_types)]
// A single Unicode character
ANY,
#[allow(non_camel_case_types)]
ASCII_HEX_DIGIT,
}
impl Ex {
fn min_len(&self) -> usize {
match self {
Ex::RepEx(e) => 0,
Ex::RepExactEx(e) => e.0 .0.min_len() * e.0 .1 as usize,
Ex::StrEx(e) => e.0.len(),
Ex::SeqEx(e) => e.0.min_len() + e.1.min_len(),
Ex::ChoiceEx(e) => min(e.0.min_len(), e.1.min_len()),
Ex::NegPredEx(e) => 0,
Ex::ASCII_NONZERO_DIGIT => 1,
Ex::ASCII_DIGIT => 1,
Ex::ANY => 1,
Ex::ASCII_HEX_DIGIT => 1,
Ex::OptEx(e) => 0,
Ex::NestedEx => 0,
_ => unimplemented!(),
}
}
fn max_len(&self) -> usize {
match self {
Ex::RepEx(e) => MAX_LEN,
Ex::RepExactEx(e) => e.0 .0.max_len() * e.0 .1 as usize,
Ex::StrEx(e) => e.0.len(),
Ex::SeqEx(e) => e.0.max_len() + e.1.max_len(),
Ex::ChoiceEx(e) => max(e.0.max_len(), e.1.max_len()),
Ex::NegPredEx(e) => 0,
Ex::ASCII_NONZERO_DIGIT => 1,
Ex::ASCII_DIGIT => 1,
Ex::ANY => 4,
Ex::ASCII_HEX_DIGIT => 1,
Ex::OptEx(e) => e.0.max_len(),
Ex::NestedEx => 0,
_ => unimplemented!(),
}
}
}
#[derive(Debug, Clone)]
struct Rep(Box<Ex>);
#[derive(Debug, Clone)]
struct RepExact((Box<Ex>, u32));
#[derive(Debug, Clone)]
struct Str(String);
#[derive(Debug, Clone)]
struct Seq(Box<Ex>, Box<Ex>);
#[derive(Debug, Clone)]
struct Choice(Box<Ex>, Box<Ex>);
#[derive(Debug, Clone)]
struct NegPred(Box<Ex>);
#[derive(Debug, Clone)]
struct Opt(Box<Ex>);
struct Rule {
name: String,
pub ex: Ex,
}
/// Builds the rules, returning the final expression.
fn build_rules(ast_rules: &[ast::Rule]) -> Ex {
let mut rules = Vec::new();
// build from the bottom up
let iter = ast_rules.iter().rev();
for r in iter {
println!("building rule with name {:?}", r.name);
let ex = build_expr(&r.expr, &rules, &r.name, false);
// TODO deal with recursive rules
rules.push(Rule {
name: r.name.clone(),
ex,
});
}
let ex = rules.last().unwrap().ex.clone();
ex
}
/// Builds expression from pest expression.
/// passes in current rule's name to deal with recursion.
/// depth is used to prevent infinite recursion.
fn build_expr(exp: &Expr, rules: &[Rule], this_name: &String, is_nested: bool) -> Ex {
match exp {
Expr::Rep(exp) => {
Ex::RepEx(Rep(Box::new(build_expr(exp, rules, this_name, is_nested))))
}
Expr::RepExact(exp, count) => Ex::RepExactEx(RepExact((
Box::new(build_expr(exp, rules, this_name, is_nested)),
*count,
))),
Expr::Str(str) => Ex::StrEx(Str(str.clone())),
Expr::NegPred(exp) => Ex::NegPredEx(NegPred(Box::new(build_expr(
exp, rules, this_name, is_nested,
)))),
Expr::Seq(a, b) => {
//
let a = build_expr(a, rules, this_name, is_nested);
Ex::SeqEx(Seq(
Box::new(a),
Box::new(build_expr(b, rules, this_name, is_nested)),
))
}
Expr::Choice(a, b) => Ex::ChoiceEx(Choice(
Box::new(build_expr(a, rules, this_name, is_nested)),
Box::new(build_expr(b, rules, this_name, is_nested)),
)),
Expr::Opt(exp) => {
Ex::OptEx(Opt(Box::new(build_expr(exp, rules, this_name, is_nested))))
}
Expr::Ident(ident) => {
let ex = match ident.as_str() {
"ASCII_NONZERO_DIGIT" => Ex::ASCII_NONZERO_DIGIT,
"ASCII_DIGIT" => Ex::ASCII_DIGIT,
"ANY" => Ex::ANY,
"ASCII_HEX_DIGIT" => Ex::ASCII_HEX_DIGIT,
_ => {
if *ident == *this_name {
return Ex::NestedEx;
}
for rule in rules {
return rule.ex.clone();
}
panic!("couldnt find rule {:?}", ident);
}
};
ex
}
_ => unimplemented!(),
}
}
// This method must be called when we know that there is enough
// data remained starting from the offset to match the expression
// at least once.
//
// returns the predicate and the offset from which the next expression
// should be matched.
// Returns multiple predicates if the expression caused multiple branches.
// A top level expr always returns a single predicate, in which all branches
// are coalesced.
fn expr_to_pred(
exp: &Ex,
offset: usize,
data_len: usize,
is_top_level: bool,
) -> Vec<(Pred, usize)> {
// if is_top_level {
// println!("top level exps {:?}", exp);
// } else {
// println!("Non-top level exps {:?}", exp);
// }
match exp {
Ex::SeqEx(s) => {
let a = &s.0;
let b = &s.1;
if is_top_level && (offset + a.max_len() + b.max_len() < data_len) {
panic!();
}
if offset + a.min_len() + b.min_len() > data_len {
panic!();
}
// The first expression must not try to match in the
// data of the next expression
let pred1 = expr_to_pred(a, offset, data_len - b.min_len(), false);
// interlace all branches
let mut interlaced = Vec::new();
for (p1, offset) in pred1.iter() {
// if the seq expr was top-level, the 2nd expr becomes top-level
let mut pred2 = expr_to_pred(b, *offset, data_len, is_top_level);
for (p2, offser_inner) in pred2.iter() {
let pred = Pred::And(vec![p1.clone(), p2.clone()]);
interlaced.push((pred, *offser_inner));
}
}
if is_top_level {
// coalesce all branches
let preds: Vec<Pred> = interlaced.into_iter().map(|(a, _b)| a).collect();
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
vec![(Pred::Or(preds), 0)]
}
} else {
interlaced
}
}
Ex::ChoiceEx(s) => {
let a = &s.0;
let b = &s.1;
let mut skip_a = false;
let mut skip_b = false;
if is_top_level {
if offset + a.max_len() != data_len {
skip_a = true
}
if offset + b.max_len() != data_len {
skip_b = true;
}
} else {
// if not top level, we may skip an expression when it will
// overflow the data len
if offset + a.min_len() > data_len {
skip_a = true
}
if offset + b.min_len() > data_len {
skip_b = true
}
}
if skip_a && skip_b {
panic!();
}
let mut preds_a = Vec::new();
let mut preds_b = Vec::new();
if !skip_a {
preds_a = expr_to_pred(a, offset, data_len, is_top_level);
}
if !skip_b {
preds_b = expr_to_pred(b, offset, data_len, is_top_level);
}
// combine all branches
let mut combined = Vec::new();
if preds_a.is_empty() {
combined = preds_b.clone();
} else if preds_b.is_empty() {
combined = preds_a.clone();
} else {
assert!(!(preds_a.is_empty() && preds_b.is_empty()));
combined.append(&mut preds_a);
combined.append(&mut preds_b);
}
if is_top_level {
// coalesce all branches
let preds: Vec<Pred> = combined.into_iter().map(|(a, _b)| a).collect();
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
vec![(Pred::Or(preds), 0)]
}
} else {
combined
}
}
Ex::RepEx(r) => {
let e = &r.0;
if offset + e.min_len() > data_len {
if is_top_level {
panic!();
}
// zero matches
return vec![];
}
let mut interlaced = Vec::new();
let mut preds = expr_to_pred(&e, offset, data_len, false);
// for (i, (pred, depth)) in preds.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
// Append single matches.
interlaced.append(&mut preds.clone());
let mut was_found = true;
while was_found {
was_found = false;
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
if offset_outer + e.min_len() > data_len {
// cannot match any more
continue;
}
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
// for (i, (pred, depth)) in preds_inner.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
for (pred_inner, offset_inner) in preds_inner {
let pred = (
Pred::And(vec![pred_outer.clone(), pred_inner]),
offset_inner,
);
preds.push(pred);
was_found = true;
}
}
interlaced.append(&mut preds.clone());
}
// for (i, (pred, depth)) in interlaced.iter().enumerate() {
// println!("preds[{i}] (depth {depth}):");
// println!("{pred}");
// }
if is_top_level {
// drop all branches which do not match exactly at the data length
// border and coalesce the rest
let preds: Vec<Pred> = interlaced
.into_iter()
.filter(|(_a, b)| *b == data_len)
.map(|(a, _b)| a)
.collect();
if preds.is_empty() {
panic!()
}
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
// coalesce all branches
vec![(Pred::Or(preds), 0)]
}
} else {
interlaced
}
}
Ex::RepExactEx(r) => {
let e = &r.0 .0;
let count = r.0 .1;
assert!(count > 0);
if is_top_level && (offset + e.max_len() * count as usize <= data_len) {
panic!();
}
let mut preds = expr_to_pred(&e, offset, data_len, false);
for i in 1..count {
for (pred_outer, offset_outer) in std::mem::take(&mut preds).into_iter() {
if offset_outer + e.min_len() > data_len {
// cannot match any more
continue;
}
let mut preds_inner = expr_to_pred(&e, offset_outer, data_len, false);
for (pred_inner, offset_inner) in preds_inner {
let pred = (
Pred::And(vec![pred_outer.clone(), pred_inner]),
offset_inner,
);
preds.push(pred);
}
}
}
if is_top_level {
// drop all branches which do not match exactly at the data length
// border and coalesce the rest
let preds: Vec<Pred> = preds
.into_iter()
.filter(|(_a, b)| *b != data_len)
.map(|(a, _b)| a)
.collect();
if preds.is_empty() {
panic!()
}
if preds.len() == 1 {
vec![(preds[0].clone(), 0)]
} else {
// coalesce all branches
vec![(Pred::Or(preds), 0)]
}
} else {
preds
}
}
Ex::NegPredEx(e) => {
assert!(offset <= data_len);
if offset == data_len {
// the internal expression cannot be match since there is no data left,
// this means that the negative expression matched
if is_top_level {
panic!("always true predicate doesnt make sense")
}
// TODO this is hacky.
return vec![(Pred::True, offset)];
}
let e = &e.0;
let preds = expr_to_pred(&e, offset, data_len, is_top_level);
let preds: Vec<Pred> = preds.into_iter().map(|(a, _b)| a).collect();
let len = preds.len();
// coalesce all branches, offset doesnt matter since those
// offset will never be used anymore.
let pred = if preds.len() == 1 {
Pred::Not(Box::new(preds[0].clone()))
} else {
Pred::Not(Box::new(Pred::Or(preds)))
};
if is_top_level && len == 0 {
panic!()
}
// all offset if negative predicate are ignored since no matching
// will be done from those offsets.
vec![(pred, offset)]
}
Ex::OptEx(e) => {
let e = &e.0;
if is_top_level {
return vec![(Pred::True, 0)];
}
// add an always-matching branch
let mut preds = vec![(Pred::True, offset)];
if e.min_len() + offset <= data_len {
// try to match only if there is enough data
let mut p = expr_to_pred(&e, offset, data_len, is_top_level);
preds.append(&mut p);
}
preds
}
Ex::StrEx(s) => {
if is_top_level && offset + s.0.len() != data_len {
panic!();
}
let mut preds = Vec::new();
for (idx, byte) in s.0.clone().into_bytes().iter().enumerate() {
let a = Atom {
index: offset + idx,
op: CmpOp::Eq,
rhs: Rhs::Const(*byte),
};
preds.push(Pred::Atom(a));
}
if preds.len() == 1 {
vec![(preds[0].clone(), offset + s.0.len())]
} else {
vec![(Pred::And(preds), offset + s.0.len())]
}
}
Ex::ASCII_NONZERO_DIGIT => {
if is_top_level && (offset + 1 != data_len) {
panic!();
}
let gte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Gte,
rhs: Rhs::Const(49u8),
});
let lte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
vec![(Pred::And(vec![gte, lte]), offset + 1)]
}
Ex::ASCII_DIGIT => {
if is_top_level && (offset + 1 != data_len) {
panic!();
}
let gte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Gte,
rhs: Rhs::Const(48u8),
});
let lte = Pred::Atom(Atom {
index: offset,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
vec![(Pred::And(vec![gte, lte]), offset + 1)]
}
Ex::ANY => {
if is_top_level && (offset + 1 > data_len) {
panic!();
}
let start = offset;
let end = min(offset + 4, data_len);
let mut branches = Vec::new();
for branch_end in start + 1..end {
branches.push((is_unicode(RangeSet::from(start..branch_end)), branch_end))
}
if is_top_level {
assert!(branches.len() == 1);
}
branches
}
_ => unimplemented!(),
}
}
#[test]
fn test_json_int() {
use rand::{distr::Alphanumeric, rng, Rng};
let grammar = include_str!("json_int.pest");
// Parse the grammar file into Pairs (the grammars own parse tree)
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
// 4) Convert the parsed pairs into the stable AST representation
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
let exp = build_rules(&rules_ast);
// 5) Inspect the AST however you like For a quick look, the Debug print is the
// safest (works across versions)
for rule in &rules_ast {
println!("{:#?}", rule);
}
const LENGTH: usize = 7; // Adjustable constant
let pred = expr_to_pred(&exp, 0, LENGTH, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
let circ = Compiler::new().compile(&pred);
println!("{:?} and gates", circ.and_count());
for i in 0..1000000 {
let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(LENGTH)
.map(char::from)
.collect();
let out = eval_pred(pred, s.as_bytes());
let is_int = s.chars().all(|c| c.is_ascii_digit()) && !s.starts_with('0');
if out != is_int {
println!("failed at index {:?} with {:?}", i, s);
}
assert_eq!(out, is_int)
}
}
#[test]
fn test_json_str() {
use rand::{distr::Alphanumeric, rng, Rng};
const LENGTH: usize = 10; // Adjustable constant
let grammar = include_str!("json_str.pest");
// Parse the grammar file into Pairs (the grammars own parse tree)
let pairs = meta::parse(meta::Rule::grammar_rules, grammar).expect("grammar parse error");
// Optional: validate (reports duplicate rules, unreachable rules, etc.)
validator::validate_pairs(pairs.clone()).expect("invalid grammar");
// 4) Convert the parsed pairs into the stable AST representation
let rules_ast: Vec<ast::Rule> = consume_rules(pairs).unwrap();
for rule in &rules_ast {
println!("{:#?}", rule);
}
let exp = build_rules(&rules_ast);
for len in LENGTH..LENGTH + 7 {
let pred = expr_to_pred(&exp, 0, len, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
let circ = Compiler::new().compile(pred);
println!(
"JSON string length: {:?}; circuit AND gate count {:?}",
len,
circ.and_count()
);
}
}
#[test]
fn test_choice() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Choice(Box::new(a), Box::new(b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 1, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_seq() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Seq(Box::new(a), Box::new(b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 2, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_rep() {
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Rep(Box::new(a)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 3, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
}
#[test]
fn test_rep_choice() {
const LENGTH: usize = 5; // Adjustable constant
let a = Expr::Ident("ASCII_NONZERO_DIGIT".to_string());
let b = Expr::Ident("ASCII_DIGIT".to_string());
// Number of predicates needed to represent the expressions.
let a_weight = 2usize;
let b_weight = 2usize;
let rep_a = Expr::Rep(Box::new(a));
let rep_b = Expr::Rep(Box::new(b));
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Choice(Box::new(rep_a), Box::new(rep_b)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, LENGTH, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {}", pred);
// This is for sanity that no extra predicates are being added.
assert_eq!(pred.leaves(), a_weight * LENGTH + b_weight * LENGTH);
}
#[test]
fn test_neg_choice() {
let a = Expr::Str("4".to_string());
let b = Expr::Str("5".to_string());
let choice = Expr::Choice(Box::new(a), Box::new(b));
let neg_choice = Expr::NegPred(Box::new(choice));
let c = Expr::Str("a".to_string());
let d = Expr::Str("BC".to_string());
let choice2 = Expr::Choice(Box::new(c), Box::new(d));
let rule = ast::Rule {
name: "test".to_string(),
ty: ast::RuleType::Atomic,
expr: Expr::Seq(Box::new(neg_choice), Box::new(choice2)),
};
let exp = build_rules(&vec![rule]);
let pred = expr_to_pred(&exp, 0, 2, true);
assert!(pred.len() == 1);
let pred = &pred[0].0;
println!("pred is {:?}", pred);
assert_eq!(pred.leaves(), 4);
}
}

View File

@@ -0,0 +1,3 @@
// Copied from pest.json
int = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* }

View File

@@ -0,0 +1,6 @@
// Copied from pest.json
// Replaced "string" with "X" to avoid recursion.
string = @{ (!("\"" | "\\") ~ ANY)* ~ (escape ~ "X")? }
escape = @{ "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode) }
unicode = @{ "u" ~ ASCII_HEX_DIGIT{4} }

View File

@@ -14,6 +14,9 @@ pub mod webpki;
pub use rangeset;
pub mod config;
pub(crate) mod display;
//pub mod grammar;
pub mod json;
pub mod predicates;
use serde::{Deserialize, Serialize};

View File

@@ -0,0 +1,790 @@
//! Predicate and compiler.
use std::{collections::HashMap, fmt};
use mpz_circuits::{itybity::ToBits, ops, Circuit, CircuitBuilder, Feed, Node};
use rangeset::RangeSet;
/// ddd
#[derive(Debug, Clone)]
pub(crate) enum Pred {
And(Vec<Pred>),
Or(Vec<Pred>),
Not(Box<Pred>),
Atom(Atom),
// An always-true predicate.
True,
// An always-false predicate.
False,
}
impl Pred {
/// Returns sorted unique byte indices of this predicate.
pub(crate) fn indices(&self) -> Vec<usize> {
let mut indices = self.indices_internal(self);
indices.sort_unstable();
indices.dedup();
indices
}
// Returns the number of leaves (i.e atoms) the AST of this predicate has.
pub(crate) fn leaves(&self) -> usize {
match self {
Pred::And(vec) => vec.iter().map(|p| p.leaves()).sum(),
Pred::Or(vec) => vec.iter().map(|p| p.leaves()).sum(),
Pred::Not(p) => p.leaves(),
Pred::Atom(atom) => 1,
Pred::True => 0,
Pred::False => 0,
}
}
/// Returns all byte indices of the given `pred`icate.
fn indices_internal(&self, pred: &Pred) -> Vec<usize> {
match pred {
Pred::And(vec) => vec
.iter()
.flat_map(|p| self.indices_internal(p))
.collect::<Vec<_>>(),
Pred::Or(vec) => vec
.iter()
.flat_map(|p| self.indices_internal(p))
.collect::<Vec<_>>(),
Pred::Not(p) => self.indices_internal(p),
Pred::Atom(atom) => {
let mut indices = Vec::new();
indices.push(atom.index);
if let Rhs::Idx(idx) = atom.rhs {
indices.push(idx);
}
indices
}
Pred::True => vec![],
Pred::False => vec![],
}
}
}
impl fmt::Display for Pred {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with_indent(f, 0)
}
}
impl Pred {
fn fmt_with_indent(&self, f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
// helper to write the current indentation
fn pad(f: &mut fmt::Formatter<'_>, indent: usize) -> fmt::Result {
// 2 spaces per level; tweak as you like
write!(f, "{:indent$}", "", indent = indent * 2)
}
match self {
Pred::And(preds) => {
pad(f, indent)?;
writeln!(f, "And(")?;
for p in preds {
p.fmt_with_indent(f, indent + 1)?;
}
pad(f, indent)?;
writeln!(f, ")")
}
Pred::Or(preds) => {
pad(f, indent)?;
writeln!(f, "Or(")?;
for p in preds {
p.fmt_with_indent(f, indent + 1)?;
}
pad(f, indent)?;
writeln!(f, ")")
}
Pred::Not(p) => {
pad(f, indent)?;
writeln!(f, "Not(")?;
p.fmt_with_indent(f, indent + 1)?;
pad(f, indent)?;
writeln!(f, ")")
}
Pred::Atom(a) => {
pad(f, indent)?;
writeln!(f, "Atom({:?})", a)
}
Pred::True => {
pad(f, indent)?;
writeln!(f, "True")
}
Pred::False => {
pad(f, indent)?;
writeln!(f, "False")
}
}
}
}
/// Atomic predicate of the form:
/// x[index] (op) rhs
#[derive(Debug, Clone)]
pub struct Atom {
/// Left-hand side byte index `i` (x_i).
pub index: usize,
/// Comparison operator.
pub op: CmpOp,
/// Right-hand side operand (constant or x_j).
pub rhs: Rhs,
}
/// ddd
#[derive(Debug, Clone)]
pub(crate) enum CmpOp {
Eq, // ==
Ne, // !=
Gt, // >
Gte, // >=
Lt, // <
Lte, // <=
}
/// RHS of a comparison
#[derive(Debug, Clone)]
pub enum Rhs {
/// Byte at index
Idx(usize),
/// Literal constant.
Const(u8),
}
/// Compiles a predicate into a circuit.
pub struct Compiler {
/// A <byte index, circuit feeds> map.
map: HashMap<usize, [Node<Feed>; 8]>,
}
impl Compiler {
pub(crate) fn new() -> Self {
Self {
map: HashMap::new(),
}
}
/// Compiles the given predicate into a circuit, consuming the
/// compiler.
pub(crate) fn compile(&mut self, pred: &Pred) -> Circuit {
let mut builder = CircuitBuilder::new();
for idx in pred.indices() {
let feeds = (0..8).map(|_| builder.add_input()).collect::<Vec<_>>();
self.map.insert(idx, feeds.try_into().unwrap());
}
let out = self.process(&mut builder, pred);
builder.add_output(out);
builder.build().unwrap()
}
// Processes a single predicate.
fn process(&mut self, builder: &mut CircuitBuilder, pred: &Pred) -> Node<Feed> {
match pred {
Pred::And(vec) => {
let out = vec
.iter()
.map(|p| self.process(builder, p))
.collect::<Vec<_>>();
ops::all(builder, &out)
}
Pred::Or(vec) => {
let out = vec
.iter()
.map(|p| self.process(builder, p))
.collect::<Vec<_>>();
ops::any(builder, &out)
}
Pred::Not(p) => {
let pred_out = self.process(builder, p);
let inv = ops::inv(builder, [pred_out]);
inv[0]
}
Pred::Atom(atom) => {
let lhs = self.map.get(&atom.index).unwrap().clone();
let rhs = match atom.rhs {
Rhs::Const(c) => const_feeds(builder, c),
Rhs::Idx(s) => self.map.get(&s).unwrap().clone(),
};
match atom.op {
CmpOp::Eq => ops::eq(builder, lhs, rhs),
CmpOp::Ne => ops::neq(builder, lhs, rhs),
CmpOp::Lt => ops::lt(builder, lhs, rhs),
CmpOp::Lte => ops::lte(builder, lhs, rhs),
CmpOp::Gt => ops::gt(builder, lhs, rhs),
CmpOp::Gte => ops::gte(builder, lhs, rhs),
}
}
Pred::True => builder.get_const_one(),
Pred::False => builder.get_const_zero(),
}
}
}
// Returns circuit feeds for the given constant u8 value.
fn const_feeds(builder: &CircuitBuilder, cnst: u8) -> [Node<Feed>; 8] {
cnst.iter_lsb0()
.map(|b| {
if b {
builder.get_const_one()
} else {
builder.get_const_zero()
}
})
.collect::<Vec<_>>()
.try_into()
.expect("u8 has 8 feeds")
}
// Evaluates the predicate on the input `data`.
pub(crate) fn eval_pred(pred: &Pred, data: &[u8]) -> bool {
match pred {
Pred::And(vec) => vec.iter().map(|p| eval_pred(p, data)).all(|b| b),
Pred::Or(vec) => vec.iter().map(|p| eval_pred(p, data)).any(|b| b),
Pred::Not(p) => !eval_pred(p, data),
Pred::Atom(atom) => {
let lhs = data[atom.index];
let rhs = match atom.rhs {
Rhs::Const(c) => c,
Rhs::Idx(s) => data[s],
};
match atom.op {
CmpOp::Eq => lhs == rhs,
CmpOp::Ne => lhs != rhs,
CmpOp::Lt => lhs < rhs,
CmpOp::Lte => lhs <= rhs,
CmpOp::Gt => lhs > rhs,
CmpOp::Gte => lhs >= rhs,
}
}
Pred::True => true,
Pred::False => true,
}
}
/// Builds a predicate that an ascii integer is contained in the ranges.
fn is_ascii_integer(range: RangeSet<usize>) -> Pred {
let mut preds = Vec::new();
for idx in range.iter() {
let lte = Pred::Atom(Atom {
index: idx,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
let gte = Pred::Atom(Atom {
index: idx,
op: CmpOp::Gte,
rhs: Rhs::Const(48u8),
});
preds.push(Pred::And(vec![lte, gte]));
}
Pred::And(preds)
}
/// Builds a predicate that a valid HTTP header value is contained in the
/// ranges.
fn is_valid_http_header_value(range: RangeSet<usize>) -> Pred {
let mut preds = Vec::new();
for idx in range.iter() {
let ne = Pred::Atom(Atom {
index: idx,
op: CmpOp::Ne,
// ascii code for carriage return \r
rhs: Rhs::Const(13u8),
});
preds.push(ne);
}
Pred::And(preds)
}
/// Builds a predicate that a valid JSON string is contained in the
/// ranges.
fn is_valid_json_string(range: RangeSet<usize>) -> Pred {
assert!(
range.len_ranges() == 1,
"only a contiguous range is allowed"
);
const BACKSLASH: u8 = 92;
// check if all unicode chars are allowed
let mut preds = Vec::new();
// Find all /u sequences
for (i, idx) in range.iter().enumerate() {
if i == range.len() - 1 {
// if this is a last char, skip it
continue;
}
let is_backslash = Pred::Atom(Atom {
index: idx,
op: CmpOp::Eq,
rhs: Rhs::Const(BACKSLASH),
});
}
Pred::And(preds)
}
// Returns a predicate that a unicode char is contained in the range
pub(crate) fn is_unicode(range: RangeSet<usize>) -> Pred {
assert!(range.len() <= 4);
match range.len() {
1 => is_1_byte_unicode(range.max().unwrap()),
2 => is_2_byte_unicode(range),
3 => is_3_byte_unicode(range),
4 => is_4_byte_unicode(range),
_ => unimplemented!(),
}
}
fn is_1_byte_unicode(pos: usize) -> Pred {
Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(127u8),
})
}
fn is_2_byte_unicode(range: RangeSet<usize>) -> Pred {
assert!(range.len() == 2);
let mut iter = range.iter();
// should be 110xxxxx
let first = iter.next().unwrap();
let gte = Pred::Atom(Atom {
index: first,
op: CmpOp::Gte,
rhs: Rhs::Const(0xC0),
});
let lte = Pred::Atom(Atom {
index: first,
op: CmpOp::Lte,
rhs: Rhs::Const(0xDF),
});
let second = iter.next().unwrap();
Pred::And(vec![lte, gte, is_unicode_continuation(second)])
}
fn is_3_byte_unicode(range: RangeSet<usize>) -> Pred {
assert!(range.len() == 3);
let mut iter = range.iter();
let first = iter.next().unwrap();
// should be 1110xxxx
let gte = Pred::Atom(Atom {
index: first,
op: CmpOp::Gte,
rhs: Rhs::Const(0xE0),
});
let lte = Pred::Atom(Atom {
index: first,
op: CmpOp::Lte,
rhs: Rhs::Const(0xEF),
});
let second = iter.next().unwrap();
let third = iter.next().unwrap();
Pred::And(vec![
lte,
gte,
is_unicode_continuation(second),
is_unicode_continuation(third),
])
}
fn is_4_byte_unicode(range: RangeSet<usize>) -> Pred {
assert!(range.len() == 4);
let mut iter = range.iter();
let first = iter.next().unwrap();
// should be 11110xxx
let gte = Pred::Atom(Atom {
index: first,
op: CmpOp::Gte,
rhs: Rhs::Const(0xF0),
});
let lte = Pred::Atom(Atom {
index: first,
op: CmpOp::Lte,
rhs: Rhs::Const(0xF7),
});
let second = iter.next().unwrap();
let third = iter.next().unwrap();
let fourth = iter.next().unwrap();
Pred::And(vec![
lte,
gte,
is_unicode_continuation(second),
is_unicode_continuation(third),
is_unicode_continuation(fourth),
])
}
fn is_unicode_continuation(pos: usize) -> Pred {
// should be 10xxxxxx
let gte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Gte,
rhs: Rhs::Const(0x80),
});
let lte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(0xBF),
});
Pred::And(vec![lte, gte])
}
fn is_ascii_hex_digit(pos: usize) -> Pred {
let gte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Gte,
rhs: Rhs::Const(48u8),
});
let lte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
let is_digit = Pred::And(vec![lte, gte]);
let gte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Gte,
rhs: Rhs::Const(65u8),
});
let lte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(70u8),
});
let is_upper = Pred::And(vec![lte, gte]);
let gte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Gte,
rhs: Rhs::Const(97u8),
});
let lte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(102u8),
});
let is_lower = Pred::And(vec![lte, gte]);
Pred::Or(vec![is_digit, is_lower, is_upper])
}
fn is_ascii_lowercase(pos: usize) -> Pred {
let gte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Gte,
rhs: Rhs::Const(48u8),
});
let lte = Pred::Atom(Atom {
index: pos,
op: CmpOp::Lte,
rhs: Rhs::Const(57u8),
});
Pred::And(vec![lte, gte])
}
#[cfg(test)]
mod test {
use super::*;
use mpz_circuits::evaluate;
use rand::rng;
#[test]
fn test_and() {
let pred = Pred::And(vec![
Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(300),
}),
Pred::Atom(Atom {
index: 200,
op: CmpOp::Eq,
rhs: Rhs::Const(2u8),
}),
]);
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [1u8, 2, 3]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 3, 3]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_or() {
let pred = Pred::Or(vec![
Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(300),
}),
Pred::Atom(Atom {
index: 200,
op: CmpOp::Eq,
rhs: Rhs::Const(2u8),
}),
]);
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [1u8, 0, 3]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 3, 0]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_not() {
let pred = Pred::Not(Box::new(Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(300),
})));
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [5u8, 3]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
assert_eq!(out, false);
}
// Tests when RHS is a const.
#[test]
fn test_rhs_const() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Const(22u8),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, 5u8).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, 23u8).unwrap();
assert_eq!(out, false);
}
// Tests when RHS is an index.
#[test]
fn test_rhs_idx() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(200),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, 5u8, 10u8).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, 23u8, 5u8).unwrap();
assert_eq!(out, false);
}
// Tests when same index is used in the predicate.
#[test]
fn test_same_idx() {
let pred1 = Pred::Atom(Atom {
index: 100,
op: CmpOp::Eq,
rhs: Rhs::Idx(100),
});
let pred2 = Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(100),
});
let circ = Compiler::new().compile(&pred1);
let out: bool = evaluate!(circ, 5u8).unwrap();
assert_eq!(out, true);
let circ = Compiler::new().compile(&pred2);
let out: bool = evaluate!(circ, 5u8).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_eq() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Eq,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [5u8, 5]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 3]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_neq() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Ne,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [5u8, 6]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_gt() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Gt,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [7u8, 6]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [1u8, 1]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_gte() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Gte,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [7u8, 7]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [0u8, 1]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_lt() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Lt,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [2u8, 7]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_atom_lte() {
let pred = Pred::Atom(Atom {
index: 100,
op: CmpOp::Lte,
rhs: Rhs::Idx(300),
});
let circ = Compiler::new().compile(&pred);
let out: bool = evaluate!(circ, [2u8, 2]).unwrap();
assert_eq!(out, true);
let out: bool = evaluate!(circ, [4u8, 1]).unwrap();
assert_eq!(out, false);
}
#[test]
fn test_is_ascii_integer() {
let text = "text with integers 123456 text";
let pos = text.find("123456").unwrap();
let pred = is_ascii_integer(RangeSet::from(pos..pos + 6));
let bytes: &[u8] = text.as_bytes();
let out = eval_pred(&pred, bytes);
assert_eq!(out, true);
let out = eval_pred(&pred, &[&[0u8], bytes].concat());
assert_eq!(out, false);
}
#[test]
fn test_is_valid_http_header_value() {
let valid = "valid header value";
let invalid = "invalid header \r value";
let pred = is_valid_http_header_value(RangeSet::from(0..valid.len()));
let out: bool = eval_pred(&pred, valid.as_bytes());
assert_eq!(out, true);
let pred = is_valid_http_header_value(RangeSet::from(0..invalid.len()));
let out = eval_pred(&pred, invalid.as_bytes());
assert_eq!(out, false);
}
#[test]
fn test_is_unicode() {
use rand::{distr::Alphanumeric, rng, Rng};
let mut rng = rng();
for _ in 0..1000000 {
let mut s = String::from("HelloWorld");
let insert_pos = 5; // logical character index (after "Hello")
let byte_index = s
.char_indices()
.nth(insert_pos)
.map(|(i, _)| i)
.unwrap_or_else(|| s.len());
// Pick a random Unicode scalar value (0x0000..=0x10FFFF)
// Retry if it's in the surrogate range (U+D800..=U+DFFF)
let c = loop {
let code = rng.random_range(0x0000u32..=0x10FFFF);
if !(0xD800..=0xDFFF).contains(&code) {
if let Some(ch) = char::from_u32(code) {
break ch;
}
}
};
let mut buf = [0u8; 4]; // max UTF-8 length
let encoded = c.encode_utf8(&mut buf); // returns &str
let len = encoded.len();
s.insert_str(byte_index, &c.to_string());
let pred = is_unicode(RangeSet::from(byte_index..byte_index + len));
let out = eval_pred(&pred, s.as_bytes());
assert_eq!(out, true);
}
let bad_unicode = 255u8;
let pred = is_unicode(RangeSet::from(0..1));
let out = eval_pred(&pred, &[bad_unicode]);
assert_eq!(out, false);
}
}