reduce traversal memory usage through iterators (#209)

Previously, the traversal stack held (node, children queued) pairs.
When visiting a node without it's children queued, we would queue them
all. They take a lot of memory!

Now, the stack holds children iterators.

Also: this patch fixes many bugs introduced by the prior one.
This commit is contained in:
Alex Ozdemir
2024-09-29 13:27:26 -07:00
committed by GitHub
parent d3e1f6817e
commit 8140b1369e
17 changed files with 394 additions and 160 deletions

View File

@@ -44,7 +44,7 @@ use circ::target::r1cs::{
proof::{CommitProofSystem, ProofSystem},
};
#[cfg(feature = "r1cs")]
use circ::target::r1cs::{opt::reduce_linearities, trans::to_r1cs};
use circ::target::r1cs::{opt::reduce_linearities, trans::to_r1cs, R1csStats};
#[cfg(feature = "smt")]
use circ::target::smt::find_model;
use circ_fields::FieldT;
@@ -311,19 +311,26 @@ fn main() {
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
let mut r1cs = to_r1cs(cs, cfg());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
println!("Pre-opt r1cs stats: {:#?}", r1cs.stats());
}
println!("Running r1cs optimizations ");
r1cs = reduce_linearities(r1cs, cfg());
println!("Final R1cs size: {}", r1cs.constraints().len());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
println!("Post-opt r1cs stats: {:#?}", r1cs.stats());
}
let n_constraints = r1cs.stats().n_constraints;
let n_vars = r1cs.stats().n_vars;
let n_entries = r1cs.stats().n_entries();
let (prover_data, verifier_data) = r1cs.finalize(cs);
println!(
"Final R1cs rounds: {}",
"Final r1cs: {} constraints, {} variables, {} entries, {} rounds",
n_constraints,
n_vars,
n_entries,
prover_data.precompute.stage_sizes().count() - 1
);
println!(
@@ -331,6 +338,7 @@ fn main() {
prover_data.precompute.num_steps(),
prover_data.precompute.num_step_args()
);
match action {
ProofAction::Count => (),
#[cfg(feature = "bellman")]

View File

@@ -1,4 +1,4 @@
non_zero(Z: field) :- ((Z + Y) + Y) + ~!-Y + to_field(Y);
main(Z: field) :- ((Z + Y) + Y) + ~!-Y + to_field(Y);
Z;
exists A: field. A * Z = 1;
exists B: field, C: bool. B * C * Z = 4.

View File

@@ -140,7 +140,7 @@ fn main() {
);
reduce_linearities(r1cs, cfg())
};
println!("Final R1cs size: {}", r1cs.constraints().len());
println!("Final r1cs: {} constraints", r1cs.constraints().len());
match action {
ProofAction::Count => {
if !options.quiet {

View File

@@ -55,7 +55,7 @@ function cs_count_test {
cs_upper_bound=$2
rm -rf P V pi
output=$($BIN $ex_name r1cs --action count |& cat)
n_constraints=$(echo "$output" | grep 'Final R1cs size:' | grep -Eo '\b[0-9]+\b')
n_constraints=$(echo "$output" | grep -E 'Final r1cs: [0-9]+' -o | grep -Eo '\b[0-9]+\b')
[[ $n_constraints -lt $cs_upper_bound ]] || (echo "Got $n_constraints, expected < $cs_upper_bound" && exit 1)
}

View File

@@ -28,7 +28,7 @@ function r1cs_test_count {
zpath=$1
threshold=$2
o=$($BIN --field-custom-modulus $modulus $zpath r1cs --action count)
n_constraints=$(echo $o | grep 'Final R1cs size:' | grep -Eo '\b[0-9]+\b')
n_constraints=$(echo $o | grep -E 'Final r1cs: [0-9]+' -o | grep -Eo '\b[0-9]+\b')
[[ $n_constraints -lt $threshold ]] || (echo "Got $n_constraints, expected < $threshold" && exit 1)
}

View File

@@ -6,20 +6,24 @@ disable -r time
BIN=./target/debug/examples/circ
$BIN --language datalog ./examples/datalog/parse_test/one_rule.pl r1cs --action count || true
$BIN --language datalog ./examples/datalog/inv.pl r1cs --action count || true
$BIN --language datalog ./examples/datalog/call.pl r1cs --action count || true
$BIN --language datalog ./examples/datalog/arr.pl r1cs --action count || true
function getconstraints {
grep -E "Final r1cs: .* constraints" -o | grep -E -o "\\b[0-9]+"
}
# Small R1cs b/c too little recursion.
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 4 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 4 r1cs --action count || true) | getconstraints)
[ "$size" -lt 10 ]
# Big R1cs b/c enough recursion
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 5 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 5 r1cs --action count || true) | getconstraints)
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 10 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl --datalog-rec-limit 10 r1cs --action count || true) | getconstraints)
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dec.pl --datalog-rec-limit 2 r1cs --action count || true) | grep -E "Final R1cs size:" | grep -E -o "\\b[0-9]+")
size=$(($BIN --language datalog ./examples/datalog/dec.pl --datalog-rec-limit 2 r1cs --action count || true) | getconstraints)
[ "$size" -gt 250 ]
# Test prim-rec test

View File

@@ -29,7 +29,7 @@ function r1cs_test_count {
zpath=$1
threshold=$2
o=$($BIN $zpath r1cs --action count)
n_constraints=$(echo $o | grep 'Final R1cs size:' | grep -Eo '\b[0-9]+\b')
n_constraints=$(echo $o | grep -E 'Final r1cs: [0-9]+' -o | grep -Eo '\b[0-9]+\b')
[[ $n_constraints -lt $threshold ]] || (echo "Got $n_constraints, expected < $threshold" && exit 1)
}

View File

@@ -20,8 +20,8 @@ use std::collections::HashMap;
use std::fmt::Display;
use std::path::PathBuf;
use std::str::FromStr;
use zokrates_pest_ast as ast;
use std::time;
use zokrates_pest_ast as ast;
use term::*;
use zvisit::{ZConstLiteralRewriter, ZGenericInf, ZStatementWalker, ZVisitorMut};
@@ -37,7 +37,6 @@ pub struct Inputs {
pub mode: Mode,
}
#[allow(dead_code)]
fn const_value_simple(term: &Term) -> Option<Value> {
match term.op() {
@@ -48,10 +47,10 @@ fn const_value_simple(term: &Term) -> Option<Value> {
#[allow(dead_code)]
fn const_bool_simple(t: T) -> Option<bool> {
match const_value_simple(&t.term) {
Some(Value::Bool(b)) => Some(b),
_ => None
}
match const_value_simple(&t.term) {
Some(Value::Bool(b)) => Some(b),
_ => None,
}
}
#[allow(dead_code)]
@@ -132,7 +131,7 @@ struct ZGen<'ast> {
}
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
struct FnCallImplInput(bool, Vec<T>, Vec<(String,T)>, PathBuf, String);
struct FnCallImplInput(bool, Vec<T>, Vec<(String, T)>, PathBuf, String);
impl<'ast> Drop for ZGen<'ast> {
fn drop(&mut self) {
@@ -271,7 +270,7 @@ impl<'ast> ZGen<'ast> {
args.len(),
f_name
))
} else if generics.len() != 0 {
} else if !generics.is_empty() {
Err(format!(
"Got {} generic args to EMBED/{}, expected 0",
generics.len(),
@@ -288,7 +287,7 @@ impl<'ast> ZGen<'ast> {
args.len(),
f_name
))
} else if generics.len() != 0 {
} else if !generics.is_empty() {
Err(format!(
"Got {} generic args to EMBED/{}, expected 0",
generics.len(),
@@ -575,10 +574,9 @@ impl<'ast> ZGen<'ast> {
.map_err(|e| format!("{e}"))?
.unwrap_term()
};
let new =
loc_store(old, &zaccs[..], val)
.map(const_fold)
.and_then(|n| if strict { const_val_simple(n) } else { Ok(n) })?;
let new = loc_store(old, &zaccs[..], val)
.map(const_fold)
.and_then(|n| if strict { const_val_simple(n) } else { Ok(n) })?;
debug!("Assign: {}", name);
if IS_CNST {
self.cvar_assign(name, new)
@@ -715,9 +713,9 @@ impl<'ast> ZGen<'ast> {
egv.iter()
.map(|cgv| match cgv {
ast::ConstantGenericValue::Value(l) => self.literal_(l),
ast::ConstantGenericValue::Identifier(i) => {
self.identifier_impl_::<IS_CNST>(i).and_then(const_val_simple)
}
ast::ConstantGenericValue::Identifier(i) => self
.identifier_impl_::<IS_CNST>(i)
.and_then(const_val_simple),
ast::ConstantGenericValue::Underscore(_) => Err(
"explicit_generic_values got non-monomorphized generic argument".to_string(),
),
@@ -727,7 +725,6 @@ impl<'ast> ZGen<'ast> {
.collect()
}
fn function_call_impl_<const IS_CNST: bool>(
&self,
args: Vec<T>,
@@ -752,21 +749,34 @@ impl<'ast> ZGen<'ast> {
.unify_generic(egv, exp_ty, arg_tys)?;
let mut generic_vec = generics.clone().into_iter().collect::<Vec<_>>();
generic_vec.sort_by(|(a,_), (b,_)| a.cmp(&b));
generic_vec.sort_by(|(a, _), (b, _)| a.cmp(b));
let before = time::Instant::now();
let input = FnCallImplInput(IS_CNST, args.clone(), generic_vec.clone(), f_path.clone(), f_name.clone());
let input = FnCallImplInput(
IS_CNST,
args.clone(),
generic_vec.clone(),
f_path.clone(),
f_name.clone(),
);
let cached_value = self.fn_call_memoization.borrow().get(&input).cloned();
let ret = if let Some(value) = cached_value {
Ok(value)
} else {
debug!("successfully memoized {} {:?}", f_name, f_path);
self.function_call_impl_inner_::<IS_CNST>(f, args, generics, f_path.clone(), f_name.clone())
.map(|v| {
self.fn_call_memoization.borrow_mut().insert(input, v.clone());
v
})
self.function_call_impl_inner_::<IS_CNST>(
f,
args,
generics,
f_path.clone(),
f_name.clone(),
)
.inspect(|v| {
self.fn_call_memoization
.borrow_mut()
.insert(input, v.clone());
})
};
let dur = (time::Instant::now() - before).as_millis();
if dur > 50 {
@@ -779,7 +789,7 @@ impl<'ast> ZGen<'ast> {
&self,
f: &ast::FunctionDefinition<'ast>,
args: Vec<T>,
generics: HashMap<String,T>,
generics: HashMap<String, T>,
f_path: PathBuf,
f_name: String,
) -> Result<T, String> {
@@ -1246,16 +1256,18 @@ impl<'ast> ZGen<'ast> {
}
}
fn expr_impl_<const IS_CNST: bool>(&self, e: &ast::Expression<'ast>) -> Result<T,String> {
fn expr_impl_<const IS_CNST: bool>(&self, e: &ast::Expression<'ast>) -> Result<T, String> {
self.expr_impl_inner_::<IS_CNST>(e)
.map(const_fold)
.and_then(|v| if IS_CNST {const_val_simple(v)} else {Ok(v)})
.and_then(|v| if IS_CNST { const_val_simple(v) } else { Ok(v) })
.map_err(|err| format!("{}; context:\n{}", err, span_to_string(e.span())))
}
// XXX(rsw) make Result<T, (String, Span)> to give more precise error messages?
fn expr_impl_inner_<const IS_CNST: bool>(&self, e: &ast::Expression<'ast>) -> Result<T, String> {
fn expr_impl_inner_<const IS_CNST: bool>(
&self,
e: &ast::Expression<'ast>,
) -> Result<T, String> {
if IS_CNST {
debug!("Const expr: {}", e.span().as_str());
} else {
@@ -1264,7 +1276,11 @@ impl<'ast> ZGen<'ast> {
match e {
ast::Expression::Ternary(u) => {
match self.expr_impl_::<false>(&u.first).ok().and_then(const_bool_simple) {
match self
.expr_impl_::<IS_CNST>(&u.first)
.ok()
.and_then(const_bool_simple)
{
Some(true) => self.expr_impl_::<IS_CNST>(&u.second),
Some(false) => self.expr_impl_::<IS_CNST>(&u.third),
None if IS_CNST => Err("ternary condition not const bool".to_string()),
@@ -1436,8 +1452,8 @@ impl<'ast> ZGen<'ast> {
.map_err(|e| format!("{e}"))
}
ast::Statement::Assertion(e) => {
let expr = self.expr_impl_::<false>(&e.expression);
match expr.clone().ok().and_then(const_bool_simple) {
let expr = self.expr_impl_::<IS_CNST>(&e.expression)?;
match const_bool_simple(expr.clone()) {
Some(true) => Ok(()),
Some(false) => Err(format!(
"Const assert failed: {} at\n{}",
@@ -1448,11 +1464,11 @@ impl<'ast> ZGen<'ast> {
span_to_string(e.expression.span()),
)),
None if IS_CNST => Err(format!(
"Const assert expression eval failed at\n{}",
"Const assert failed (non-const expression) at\n{}",
span_to_string(e.expression.span()),
)),
_ => {
let b = bool(expr?)?;
let b = bool(expr)?;
self.assert(b)?;
Ok(())
}

View File

@@ -226,8 +226,8 @@ impl T {
}
pub fn new_integer<I>(v: I) -> Self
where
Integer: From<I>
where
Integer: From<I>,
{
T::new(Ty::Integer, int_lit(v))
}
@@ -335,7 +335,7 @@ fn wrap_bin_pred(
b: T,
) -> Result<T, String> {
match (&a.ty, &b.ty, fu, ff, fb, fi) {
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _,_) if na == nb => {
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _, _) if na == nb => {
Ok(T::new(Ty::Bool, fu(a.term.clone(), b.term.clone())))
}
(Ty::Bool, Ty::Bool, _, _, Some(fb), _) => {
@@ -364,7 +364,15 @@ fn add_integer(a: Term, b: Term) -> Term {
}
pub fn add(a: T, b: T) -> Result<T, String> {
wrap_bin_op("+", Some(add_uint), Some(add_field), None, Some(add_integer), a, b)
wrap_bin_op(
"+",
Some(add_uint),
Some(add_field),
None,
Some(add_integer),
a,
b,
)
}
fn sub_uint(a: Term, b: Term) -> Term {
@@ -380,7 +388,15 @@ fn sub_integer(a: Term, b: Term) -> Term {
}
pub fn sub(a: T, b: T) -> Result<T, String> {
wrap_bin_op("-", Some(sub_uint), Some(sub_field), None, Some(sub_integer), a, b)
wrap_bin_op(
"-",
Some(sub_uint),
Some(sub_field),
None,
Some(sub_integer),
a,
b,
)
}
fn mul_uint(a: Term, b: Term) -> Term {
@@ -396,7 +412,15 @@ fn mul_integer(a: Term, b: Term) -> Term {
}
pub fn mul(a: T, b: T) -> Result<T, String> {
wrap_bin_op("*", Some(mul_uint), Some(mul_field), None, Some(mul_integer), a, b)
wrap_bin_op(
"*",
Some(mul_uint),
Some(mul_field),
None,
Some(mul_integer),
a,
b,
)
}
fn div_uint(a: Term, b: Term) -> Term {
@@ -412,7 +436,15 @@ fn div_integer(a: Term, b: Term) -> Term {
}
pub fn div(a: T, b: T) -> Result<T, String> {
wrap_bin_op("/", Some(div_uint), Some(div_field), None, Some(div_integer), a, b)
wrap_bin_op(
"/",
Some(div_uint),
Some(div_field),
None,
Some(div_integer),
a,
b,
)
}
fn to_dflt_f(t: Term) -> Term {
@@ -435,7 +467,15 @@ fn rem_integer(a: Term, b: Term) -> Term {
}
pub fn rem(a: T, b: T) -> Result<T, String> {
wrap_bin_op("%", Some(rem_uint), Some(rem_field), None, Some(rem_integer), a, b)
wrap_bin_op(
"%",
Some(rem_uint),
Some(rem_field),
None,
Some(rem_integer),
a,
b,
)
}
fn bitand_uint(a: Term, b: Term) -> Term {
@@ -515,12 +555,20 @@ fn ult_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Ult)
}
fn ult_integer(a: Term, b:Term) -> Term {
fn ult_integer(a: Term, b: Term) -> Term {
term![Op::IntBinPred(IntBinPred::Lt); a,b]
}
pub fn ult(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("<", Some(ult_uint), Some(ult_field), None, Some(ult_integer), a, b)
wrap_bin_pred(
"<",
Some(ult_uint),
Some(ult_field),
None,
Some(ult_integer),
a,
b,
)
}
fn ule_uint(a: Term, b: Term) -> Term {
@@ -531,12 +579,20 @@ fn ule_field(a: Term, b: Term) -> Term {
field_comp(a, b, BvBinPred::Ule)
}
fn ule_integer(a: Term, b:Term) -> Term {
fn ule_integer(a: Term, b: Term) -> Term {
term![Op::IntBinPred(IntBinPred::Le); a, b]
}
pub fn ule(a: T, b: T) -> Result<T, String> {
wrap_bin_pred("<=", Some(ule_uint), Some(ule_field), None, Some(ule_integer), a, b)
wrap_bin_pred(
"<=",
Some(ule_uint),
Some(ule_field),
None,
Some(ule_integer),
a,
b,
)
}
fn ugt_uint(a: Term, b: Term) -> Term {
@@ -552,7 +608,15 @@ fn ugt_integer(a: Term, b: Term) -> Term {
}
pub fn ugt(a: T, b: T) -> Result<T, String> {
wrap_bin_pred(">", Some(ugt_uint), Some(ugt_field), None, Some(ugt_integer), a, b)
wrap_bin_pred(
">",
Some(ugt_uint),
Some(ugt_field),
None,
Some(ugt_integer),
a,
b,
)
}
fn uge_uint(a: Term, b: Term) -> Term {
@@ -568,18 +632,31 @@ fn uge_integer(a: Term, b: Term) -> Term {
}
pub fn uge(a: T, b: T) -> Result<T, String> {
wrap_bin_pred(">=", Some(uge_uint), Some(uge_field), None, Some(uge_integer), a, b)
wrap_bin_pred(
">=",
Some(uge_uint),
Some(uge_field),
None,
Some(uge_integer),
a,
b,
)
}
pub fn pow(a: T, b: T) -> Result<T, String> {
if (a.ty != Ty::Field && a.ty != Ty::Integer) || b.ty != Ty::Uint(32) {
return Err(format!("Cannot compute {a} ** {b} : must be Field/Integer ** U32"));
return Err(format!(
"Cannot compute {a} ** {b} : must be Field/Integer ** U32"
));
}
let b = const_int(b)?;
if b == 0 {
return Ok((if a.ty == Ty::Field {T::new_field} else {T::new_integer})(1))
return Ok((if a.ty == Ty::Field {
T::new_field
} else {
T::new_integer
})(1));
}
Ok((0..b.significant_bits() - 1)
@@ -625,7 +702,14 @@ fn neg_integer(a: Term) -> Term {
// Missing from ZoKrates.
pub fn neg(a: T) -> Result<T, String> {
wrap_un_op("unary-", Some(neg_uint), Some(neg_field), None, Some(neg_integer), a)
wrap_un_op(
"unary-",
Some(neg_uint),
Some(neg_field),
None,
Some(neg_integer),
a,
)
}
fn not_bool(a: Term) -> Term {
@@ -658,7 +742,7 @@ pub fn const_bool(a: T) -> Option<bool> {
pub fn const_fold(t: T) -> T {
let folded = constant_fold(&t.term, &[]);
return T::new(t.ty, folded)
T::new(t.ty, folded)
}
pub fn const_val(a: T) -> Result<T, String> {
@@ -743,7 +827,6 @@ where
T::new(Ty::Uint(bits), bv_lit(v, bits))
}
pub fn slice(arr: T, start: Option<usize>, end: Option<usize>) -> Result<T, String> {
match &arr.ty {
Ty::Array(size, _) => {
@@ -893,7 +976,10 @@ pub fn uint_to_field(u: T) -> Result<T, String> {
pub fn integer_to_field(u: T) -> Result<T, String> {
match &u.ty {
Ty::Integer => Ok(T::new(Ty::Field, term![Op::IntToPf(default_field()); u.term])),
Ty::Integer => Ok(T::new(
Ty::Field,
term![Op::IntToPf(default_field()); u.term],
)),
u => Err(format!("Cannot do int-to-field on {u}")),
}
}
@@ -905,8 +991,7 @@ pub fn field_to_integer(u: T) -> Result<T, String> {
}
}
pub fn int_to_bits(i: T, n: usize) -> Result<T,String> {
pub fn int_to_bits(i: T, n: usize) -> Result<T, String> {
match &i.ty {
Ty::Integer => uint_to_bits(T::new(Ty::Uint(n), term![Op::IntToBv(n); i.term])),
u => Err(format!("Cannot do uint-to-bits on {u}")),
@@ -922,7 +1007,10 @@ pub fn int_size(i: T) -> Result<T, String> {
pub fn int_modinv(i: T, m: T) -> Result<T, String> {
match (&i.ty, &m.ty) {
(Ty::Integer, Ty::Integer) => Ok(T::new(Ty::Integer, term![Op::IntBinOp(IntBinOp::ModInv); i.term, m.term])),
(Ty::Integer, Ty::Integer) => Ok(T::new(
Ty::Integer,
term![Op::IntBinOp(IntBinOp::ModInv); i.term, m.term],
)),
u => Err(format!("Cannot do modinv on {:?}", u)),
}
}

View File

@@ -414,7 +414,9 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> {
},
Pow => match &bt {
// XXX does POW operator really require U32 RHS?
Field(_) | Integer(_) => Ok((Basic(bt), Basic(U32(ast::U32Type { span: be.span })))),
Field(_) | Integer(_) => {
Ok((Basic(bt), Basic(U32(ast::U32Type { span: be.span }))))
}
_ => Err(ZVisitorError(
"ZStatementWalker: pow operator must take Field LHS and U32 RHS".to_owned(),
)),

View File

@@ -204,7 +204,9 @@ impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk>
DS::Field(s) => self
.ty
.replace(Basic(Field(ast::FieldType { span: s.span }))),
DS::Integer(s) => self.ty.replace(Basic(Integer(ast::IntegerType {span: s.span }))),
DS::Integer(s) => self
.ty
.replace(Basic(Integer(ast::IntegerType { span: s.span }))),
};
Ok(())
}

View File

@@ -363,34 +363,26 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T
(Some(arr), Some(idx)) => Some(const_(arr.select(idx))),
_ => None,
},
Op::Tuple => {
Some(new_tuple(t
.cs()
.iter()
.map(c_get).collect::<Vec<_>>()))
},
Op::Field(n) => {
match get(0).op() {
Op::Tuple => {
let term = get(0).cs()[*n].clone();
Some(term.as_value_opt()
Op::Tuple => Some(new_tuple(t.cs().iter().map(c_get).collect::<Vec<_>>())),
Op::Field(n) => match get(0).op() {
Op::Tuple => {
let term = get(0).cs()[*n].clone();
Some(
term.as_value_opt()
.cloned()
.map(|t| leaf_term(Op::new_const(t)))
.unwrap_or(term))
}
_ => None
.unwrap_or(term),
)
}
_ => None,
},
Op::Update(n) => {
match get(0).op() {
Op::Tuple => {
let mut children = get(0).cs().to_vec();
children[*n] = get(1).clone();
Some(new_tuple(children))
}
_ => None
Op::Update(n) => match get(0).op() {
Op::Tuple => {
let mut children = get(0).cs().to_vec();
children[*n] = get(1).clone();
Some(new_tuple(children))
}
_ => None,
},
Op::BvConcat => t
.cs()

View File

@@ -191,11 +191,43 @@ pub fn dump_op_stats() {
}
}
/// An iterator over a term's children.
/// It goes in reverse order.
pub struct TermCsRevIter {
term: Term,
last_cs: usize,
}
impl TermCsRevIter {
/// Create an iterator over this term's children.
pub fn new(term: Term) -> Self {
Self {
last_cs: term.cs().len(),
term,
}
}
/// Get the term.
pub fn term(self) -> Term {
self.term
}
}
impl Iterator for TermCsRevIter {
type Item = Term;
fn next(&mut self) -> Option<Self::Item> {
if self.last_cs > 0 {
self.last_cs -= 1;
Some(self.term.cs()[self.last_cs].clone())
} else {
None
}
}
}
/// Iterator over descendents in child-first order.
pub struct PostOrderSkipIter<'a, F: Fn(&Term) -> bool + 'a> {
// (cs stacked, term)
stack: Vec<(bool, Term)>,
visited: TermSet,
stack: Vec<TermCsRevIter>,
outputed: TermSet,
skip_if: &'a F,
}
@@ -203,8 +235,8 @@ impl<'a, F: Fn(&Term) -> bool + 'a> PostOrderSkipIter<'a, F> {
/// Make an iterator over the descendents of `root`.
pub fn new(root: Term, skip_if: &'a F) -> Self {
Self {
stack: vec![(false, root)],
visited: TermSet::default(),
stack: vec![TermCsRevIter::new(root)],
outputed: TermSet::default(),
skip_if,
}
}
@@ -213,22 +245,27 @@ impl<'a, F: Fn(&Term) -> bool + 'a> PostOrderSkipIter<'a, F> {
impl<'a, F: Fn(&Term) -> bool + 'a> std::iter::Iterator for PostOrderSkipIter<'a, F> {
type Item = Term;
fn next(&mut self) -> Option<Term> {
while let Some((children_pushed, t)) = self.stack.last() {
if self.visited.contains(t) || (self.skip_if)(t) {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
let last = self.stack.last().unwrap().1.clone();
self.stack
.extend(last.cs().iter().map(|c| (false, c.clone())));
} else {
break;
#[allow(clippy::while_let_on_iterator)]
while let Some(iter) = self.stack.last_mut() {
let mut empty = true;
while let Some(n) = iter.next() {
if !self.outputed.contains(&n) && !(self.skip_if)(&n) {
self.stack.push(TermCsRevIter::new(n));
empty = false;
break;
}
}
if empty {
let term = self.stack.pop().unwrap().term;
if !(self.skip_if)(&term) {
// If it is newly inserted
if self.outputed.insert(term.clone()) {
return Some(term);
}
}
}
}
self.stack.pop().map(|(_, t)| {
self.visited.insert(t.clone());
t
})
None
}
}
@@ -252,6 +289,92 @@ pub fn node_cs_iter(node: Term) -> impl Iterator<Item = Term> {
(0..node.cs().len()).map(move |i| node.cs()[i].clone())
}
#[cfg(test)]
mod test {
use super::*;
use crate::ir::term::dist::test::*;
use itertools::Itertools;
use quickcheck_macros::quickcheck;
/// Iterator over descendents in child-first order.
pub struct PostOrderSkipIterNaive<'a, F: Fn(&Term) -> bool + 'a> {
// (cs stacked, term)
stack: Vec<(bool, Term)>,
visited: TermSet,
skip_if: &'a F,
}
impl<'a, F: Fn(&Term) -> bool + 'a> PostOrderSkipIterNaive<'a, F> {
/// Make an iterator over the descendents of `root`.
pub fn new(root: Term, skip_if: &'a F) -> Self {
Self {
stack: vec![(false, root)],
visited: TermSet::default(),
skip_if,
}
}
}
impl<'a, F: Fn(&Term) -> bool + 'a> std::iter::Iterator for PostOrderSkipIterNaive<'a, F> {
type Item = Term;
fn next(&mut self) -> Option<Term> {
while let Some((children_pushed, t)) = self.stack.last() {
if self.visited.contains(t) || (self.skip_if)(t) {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
let last = self.stack.last().unwrap().1.clone();
self.stack
.extend(last.cs().iter().map(|c| (false, c.clone())));
} else {
break;
}
}
self.stack.pop().map(|(_, t)| {
self.visited.insert(t.clone());
t
})
}
}
fn traversal_crosscheck_skipif(t: Term, skips: Vec<Term>) {
let skip_set: TermSet = skips.into_iter().collect();
let skip_if = |t: &Term| skip_set.contains(t);
let list1 = PostOrderSkipIter::new(t.clone(), &skip_if).collect_vec();
let list2 = PostOrderSkipIterNaive::new(t.clone(), &skip_if).collect_vec();
assert_eq!(list1, list2, "term: {}\nskips: {:#?}", t, skip_set);
}
#[test]
fn traverse_const() {
traversal_crosscheck_skipif(bool_lit(true), vec![]);
}
#[test]
fn traverse_const_empty() {
traversal_crosscheck_skipif(bool_lit(true), vec![bool_lit(true)]);
}
#[test]
fn traverse_bool() {
traversal_crosscheck_skipif(
text::parse_term(
b"
(declare (
(a bool)
(b bool)
) (and a b a (or a b a true) false (and a b)))
",
),
vec![bool_lit(true)],
);
}
#[quickcheck]
fn random_bool_opt(ArbitraryTermEnv(t, _values): ArbitraryTermEnv) {
traversal_crosscheck_skipif(t, vec![]);
}
}
// impl ChildrenIter {
// fn new(node: Term) -> Self {
// Self {node, next_child: 0}

View File

@@ -1456,10 +1456,7 @@ impl Value {
/// Is this value a scalar (non-composite) type?
pub fn is_scalar(&self) -> bool {
match self {
Value::Array(..) | Value::Map(..) | Value::Tuple(..) => false,
_ => true,
}
!matches!(self, Value::Array(..) | Value::Map(..) | Value::Tuple(..))
}
}
@@ -1611,17 +1608,16 @@ pub(super) const TERM_CACHE_LIMIT: usize = 65536;
/// Iterator over descendents in child-first order.
pub struct PostOrderIter {
// (cs stacked, term)
stack: Vec<(bool, Term)>,
visited: TermSet,
stack: Vec<extras::TermCsRevIter>,
outputed: TermSet,
}
impl PostOrderIter {
/// Make an iterator over the descendents of `root`.
pub fn new(root: Term) -> Self {
Self {
stack: vec![(false, root)],
visited: TermSet::default(),
stack: vec![extras::TermCsRevIter::new(root)],
outputed: TermSet::default(),
}
}
/// Make an iterator over the descendents of `roots`, stopping at `skips`.
@@ -1630,9 +1626,9 @@ impl PostOrderIter {
stack: roots
.into_iter()
.filter(|t| !skips.contains(t))
.map(|t| (false, t))
.map(extras::TermCsRevIter::new)
.collect(),
visited: skips,
outputed: skips,
}
}
}
@@ -1640,22 +1636,25 @@ impl PostOrderIter {
impl std::iter::Iterator for PostOrderIter {
type Item = Term;
fn next(&mut self) -> Option<Term> {
while let Some((children_pushed, t)) = self.stack.last() {
if self.visited.contains(t) {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
let last = self.stack.last().unwrap().1.clone();
self.stack
.extend(last.cs().iter().map(|c| (false, c.clone())));
} else {
break;
#[allow(clippy::while_let_on_iterator)]
while let Some(iter) = self.stack.last_mut() {
let mut empty = true;
while let Some(n) = iter.next() {
if !self.outputed.contains(&n) {
self.stack.push(extras::TermCsRevIter::new(n));
empty = false;
break;
}
}
if empty {
let term = self.stack.pop().unwrap().term();
// If it is newly inserted
if self.outputed.insert(term.clone()) {
return Some(term);
}
}
}
self.stack.pop().map(|(_, t)| {
self.visited.insert(t.clone());
t
})
None
}
}

View File

@@ -443,12 +443,12 @@ impl R1cs {
let s = &mut self.stats;
s.n_constraints = self.constraints.len() as u32;
for (a, b, c) in &self.constraints {
let n_a = a.monomials.len() + !a.constant.is_zero() as usize;
let n_b = b.monomials.len() + !b.constant.is_zero() as usize;
let n_c = c.monomials.len() + !c.constant.is_zero() as usize;
s.n_a_entries += n_a as u32;
s.n_b_entries += n_b as u32;
s.n_c_entries += n_c as u32;
let n_a = a.monomials.len() as u32 + !a.constant.is_zero() as u32;
let n_b = b.monomials.len() as u32 + !b.constant.is_zero() as u32;
let n_c = c.monomials.len() as u32 + !c.constant.is_zero() as u32;
s.n_a_entries += n_a;
s.n_b_entries += n_b;
s.n_c_entries += n_c;
}
}
}

View File

@@ -29,9 +29,8 @@ impl<T: Eq + Hash + Clone> OnceQueue<T> {
}
/// Remove the oldest element from the queue.
pub fn pop(&mut self) -> Option<T> {
self.queue.pop_front().map(|t| {
self.set.remove(&t);
t
self.queue.pop_front().inspect(|t| {
self.set.remove(t);
})
}
/// Make an empty queue.

View File

@@ -16,14 +16,15 @@ pub use ast::{
Expression, FieldSuffix, FieldType, File, FromExpression, FromImportDirective,
FunctionDefinition, HexLiteralExpression, HexNumberExpression, IdentifierExpression,
ImportDirective, ImportSymbol, InlineArrayExpression, InlineStructExpression,
InlineStructMember, IterationStatement, LiteralExpression, MainImportDirective, MemberAccess,
NegOperator, NotOperator, Parameter, PosOperator, PostfixExpression, Pragma, PrivateNumber,
PrivateVisibility, PublicVisibility, Range, RangeOrExpression, ReturnStatement, Span, Spread,
SpreadOrExpression, Statement, StrOperator, StructDefinition, StructField, StructType,
SymbolDeclaration, TernaryExpression, ToExpression, Type, TypeDefinition, TypedIdentifier,
TypedIdentifierOrAssignee, U16NumberExpression, U16Suffix, U16Type, U32NumberExpression,
U32Suffix, U32Type, IntegerSuffix, IntegerType, U64NumberExpression, U64Suffix, U64Type, U8NumberExpression, U8Suffix,
U8Type, UnaryExpression, UnaryOperator, Underscore, Visibility, WitnessStatement, EOI,
InlineStructMember, IntegerSuffix, IntegerType, IterationStatement, LiteralExpression,
MainImportDirective, MemberAccess, NegOperator, NotOperator, Parameter, PosOperator,
PostfixExpression, Pragma, PrivateNumber, PrivateVisibility, PublicVisibility, Range,
RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StrOperator,
StructDefinition, StructField, StructType, SymbolDeclaration, TernaryExpression, ToExpression,
Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee, U16NumberExpression,
U16Suffix, U16Type, U32NumberExpression, U32Suffix, U32Type, U64NumberExpression, U64Suffix,
U64Type, U8NumberExpression, U8Suffix, U8Type, UnaryExpression, UnaryOperator, Underscore,
Visibility, WitnessStatement, EOI,
};
mod ast {
@@ -271,7 +272,7 @@ mod ast {
U16(U16Type<'ast>),
U32(U32Type<'ast>),
U64(U64Type<'ast>),
Integer(IntegerType<'ast>)
Integer(IntegerType<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
@@ -955,7 +956,7 @@ mod ast {
U32(U32Suffix<'ast>),
U64(U64Suffix<'ast>),
Field(FieldSuffix<'ast>),
Integer(IntegerSuffix<'ast>)
Integer(IntegerSuffix<'ast>),
}
#[derive(Debug, FromPest, PartialEq, Clone)]