From 152d5ad53163cf9230ddc04d25e05c6eb48c6352 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Mon, 19 Aug 2024 14:51:01 -0400 Subject: [PATCH] Opt: memory: linear for [group] const values (#207) For memories with constant values that have sorts which are linear groups, there is a way to optimize linear-scan memory-checking. This patch implements that optimization. --- examples/ZoKrates/pf/const_linear_lookup.zok | 23 ++++++ examples/circ.rs | 5 ++ examples/opa_bench.rs | 2 + scripts/zokrates_test.zsh | 1 + src/circify/mod.rs | 2 +- src/ir/opt/chall.rs | 4 - src/ir/opt/mem/lin.rs | 46 +++++++++-- src/ir/opt/mod.rs | 11 ++- src/ir/term/mod.rs | 87 ++++++++++++++++++++ src/lib.rs | 1 + src/target/aby/trans.rs | 2 - src/target/r1cs/opt.rs | 2 +- src/target/r1cs/wit_comp.rs | 10 +++ src/target/smt/mod.rs | 20 ++--- 14 files changed, 189 insertions(+), 27 deletions(-) create mode 100644 examples/ZoKrates/pf/const_linear_lookup.zok diff --git a/examples/ZoKrates/pf/const_linear_lookup.zok b/examples/ZoKrates/pf/const_linear_lookup.zok new file mode 100644 index 00000000..d31c1b75 --- /dev/null +++ b/examples/ZoKrates/pf/const_linear_lookup.zok @@ -0,0 +1,23 @@ +struct T { + field v + field w + field x + field y + field z +} + +const T[9] TABLE = [ + T { v: 1, w: 12, x: 13, y: 14, z: 15 }, + T { v: 2, w: 22, x: 23, y: 24, z: 25 }, + T { v: 3, w: 32, x: 33, y: 34, z: 35 }, + T { v: 4, w: 42, x: 43, y: 44, z: 45 }, + T { v: 5, w: 52, x: 53, y: 54, z: 55 }, + T { v: 6, w: 62, x: 63, y: 64, z: 65 }, + T { v: 7, w: 72, x: 73, y: 74, z: 75 }, + T { v: 8, w: 82, x: 83, y: 84, z: 85 }, + T { v: 9, w: 92, x: 93, y: 94, z: 95 } +] + +def main(field i) -> field: + T t = TABLE[i] + return t.v + t.w + t.x + t.y + t.z diff --git a/examples/circ.rs b/examples/circ.rs index e9ef7d58..d82e2ac8 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -326,6 +326,11 @@ fn main() { "Final R1cs rounds: {}", prover_data.precompute.stage_sizes().count() - 1 ); + println!( + "Final Witext steps: {}, arguments: {}", + prover_data.precompute.num_steps(), + prover_data.precompute.num_step_args() + ); match action { ProofAction::Count => (), #[cfg(feature = "bellman")] diff --git a/examples/opa_bench.rs b/examples/opa_bench.rs index c4fbd56c..d403eee4 100644 --- a/examples/opa_bench.rs +++ b/examples/opa_bench.rs @@ -1,3 +1,5 @@ +#![allow(clippy::mutable_key_type)] + use circ::cfg::clap::{self, Parser}; use circ::ir::term::*; use circ::target::aby::assignment::ilp; diff --git a/scripts/zokrates_test.zsh b/scripts/zokrates_test.zsh index 52aa6fa1..6f59b25d 100755 --- a/scripts/zokrates_test.zsh +++ b/scripts/zokrates_test.zsh @@ -73,6 +73,7 @@ function pf_test_isolate { } r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120 +r1cs_test_count ./examples/ZoKrates/pf/const_linear_lookup.zok 20 r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok diff --git a/src/circify/mod.rs b/src/circify/mod.rs index 19d9e09d..81dae53a 100644 --- a/src/circify/mod.rs +++ b/src/circify/mod.rs @@ -339,7 +339,7 @@ pub trait Embeddable { /// * `name`: the name /// * `visibility`: who knows it /// * `precompute`: an optional term for pre-computing the values of this input. If a party - /// knows the inputs to the precomputation, they can use the precomputation. + /// knows the inputs to the precomputation, they can use the precomputation. fn declare_input( &self, ctx: &mut CirCtx, diff --git a/src/ir/opt/chall.rs b/src/ir/opt/chall.rs index ca016809..2172326c 100644 --- a/src/ir/opt/chall.rs +++ b/src/ir/opt/chall.rs @@ -18,10 +18,6 @@ //! //! Each challenge term c that depends on t is replaced with a variable v. //! Let t' denote a rewritten term. -//! -//! Rules: -//! * round(v) >= -//! round(v use log::{debug, trace}; use std::cell::RefCell; diff --git a/src/ir/opt/mem/lin.rs b/src/ir/opt/mem/lin.rs index 51c332a5..e889c708 100644 --- a/src/ir/opt/mem/lin.rs +++ b/src/ir/opt/mem/lin.rs @@ -64,11 +64,47 @@ impl RewritePass for Linearizer { .unwrap_or_else(|| a.val.default_term()), ) } else { - let mut fields = (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]); - let first = fields.next().unwrap(); - Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| { - term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc] - })) + let value_sort = check(tup).as_tuple()[0].clone(); + if value_sort.is_group() { + // if values are a group + // then emit v0 + ite(idx == i1, v1 - v0, 0) + ... it(idx = iN, vN - v0, 0) + // where +, -, 0 are defined by the group. + // + // we do this because if the values are constant, then the above sum is + // linear, which is very nice for most backends. + let mut fields = + (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]); + let first = fields.next().unwrap(); + let zero = value_sort.group_identity(); + Some( + value_sort.group_add_nary( + std::iter::once(first.clone()) + .chain( + a.key + .elems_iter() + .take(a.size) + .skip(1) + .zip(fields) + .map(|(idx_c, field)| { + term![Op::Ite; + term![Op::Eq; idx.clone(), idx_c], + value_sort.group_sub(field, first.clone()), + zero.clone() + ] + }), + ) + .collect(), + ), + ) + } else { + // otherwise, ite(idx == iN, vN, ... ite(idx == i1, v1, v0) ... ) + let mut fields = + (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]); + let first = fields.next().unwrap(); + Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| { + term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc] + })) + } } } else { unreachable!() diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index 0f0a9868..9a3a81c9 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -65,7 +65,10 @@ pub enum Opt { pub fn opt>(mut cs: Computations, optimizations: I) -> Computations { for c in cs.comps.values() { trace!("Before all opts: {}", text::serialize_computation(c)); - info!("Before all opts: {} terms", c.stats().main.n_terms); + info!( + "Before all opts: {} terms", + c.stats().main.n_terms + c.stats().prec.n_terms + ); debug!("Before all opts: {:#?}", c.stats()); debug!("Before all opts: {:#?}", c.detailed_stats()); } @@ -167,7 +170,11 @@ pub fn opt>(mut cs: Computations, optimizations: I) fits_in_bits_ip::fits_in_bits_ip(c); } } - info!("After {:?}: {} terms", i, c.stats().main.n_terms); + info!( + "After {:?}: {} terms", + i, + c.stats().main.n_terms + c.stats().prec.n_terms + ); debug!("After {:?}: {:#?}", i, c.stats()); trace!("After {:?}: {}", i, text::serialize_computation(c)); #[cfg(debug_assertions)] diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 147807f9..084da824 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -1026,6 +1026,93 @@ impl Sort { pub fn is_scalar(&self) -> bool { !matches!(self, Sort::Tuple(..) | Sort::Array(..) | Sort::Map(..)) } + + /// Is this sort a group? + pub fn is_group(&self) -> bool { + match self { + Sort::BitVector(_) | Sort::Int | Sort::Field(_) | Sort::Bool => true, + Sort::F32 | Sort::F64 | Sort::Array(_) | Sort::Map(_) => false, + Sort::Tuple(fields) => fields.iter().all(|f| f.is_group()), + } + } + + /// The (n-ary) group operation for these terms. + pub fn group_add_nary(&self, ts: Vec) -> Term { + debug_assert!(ts.iter().all(|t| &check(t) == self)); + match self { + Sort::BitVector(_) => term(BV_ADD, ts), + Sort::Bool => term(XOR, ts), + Sort::Field(_) => term(PF_ADD, ts), + Sort::Int => term(INT_ADD, ts), + Sort::Tuple(sorts) => term( + Op::Tuple, + sorts + .iter() + .enumerate() + .map(|(i, sort)| { + sort.group_add_nary( + ts.iter() + .map(|t| term(Op::Field(i), vec![t.clone()])) + .collect(), + ) + }) + .collect(), + ), + _ => panic!("Not a group: {}", self), + } + } + + /// Group inverse + pub fn group_neg(&self, t: Term) -> Term { + debug_assert_eq!(&check(&t), self); + match self { + Sort::BitVector(_) => term(BV_NEG, vec![t]), + Sort::Bool => term(NOT, vec![t]), + Sort::Field(_) => term(PF_NEG, vec![t]), + Sort::Int => term( + INT_MUL, + vec![leaf_term(Op::new_const(Value::Int(Integer::from(-1i8)))), t], + ), + Sort::Tuple(sorts) => term( + Op::Tuple, + sorts + .iter() + .enumerate() + .map(|(i, sort)| sort.group_neg(term(Op::Field(i), vec![t.clone()]))) + .collect(), + ), + _ => panic!("Not a group: {}", self), + } + } + + /// Group identity + pub fn group_identity(&self) -> Term { + match self { + Sort::BitVector(n_bits) => bv_lit(0, *n_bits), + Sort::Bool => bool_lit(false), + Sort::Field(f) => pf_lit(f.new_v(0)), + Sort::Int => leaf_term(Op::new_const(Value::Int(Integer::from(0i8)))), + Sort::Tuple(sorts) => term( + Op::Tuple, + sorts.iter().map(|sort| sort.group_identity()).collect(), + ), + _ => panic!("Not a group: {}", self), + } + } + + /// Group operation + pub fn group_add(&self, s: Term, t: Term) -> Term { + debug_assert_eq!(&check(&s), self); + debug_assert_eq!(&check(&t), self); + self.group_add_nary(vec![s, t]) + } + + /// Group elimination + pub fn group_sub(&self, s: Term, t: Term) -> Term { + debug_assert_eq!(&check(&s), self); + debug_assert_eq!(&check(&t), self); + self.group_add(s, self.group_neg(t)) + } } mod hc { diff --git a/src/lib.rs b/src/lib.rs index ea51a008..851c264b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ #![warn(missing_docs)] #![deny(warnings)] #![allow(rustdoc::private_intra_doc_links)] +#![allow(clippy::mutable_key_type)] #[macro_use] pub mod ir; diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index fcc58500..26025f12 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -904,8 +904,6 @@ pub fn to_aby(cs: Computations, path: &Path, lang: &str, cm: &str, ss: &str) { panic!("Unsupported sharing scheme: {}", ss); } }; - #[cfg(feature = "bench")] - println!("LOG: Assignment {}: {:?}", name, now.elapsed()); s_map.insert(name.to_string(), assignments); } diff --git a/src/target/r1cs/opt.rs b/src/target/r1cs/opt.rs index aa24c5cb..8fd553e0 100644 --- a/src/target/r1cs/opt.rs +++ b/src/target/r1cs/opt.rs @@ -216,7 +216,7 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool { /// ## Parameters /// /// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used -/// for propagation. `None` means no size limit. +/// for propagation. `None` means no size limit. pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs { let mut r = LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run(); r.update_stats(); diff --git a/src/target/r1cs/wit_comp.rs b/src/target/r1cs/wit_comp.rs index 95a267ad..5e57b462 100644 --- a/src/target/r1cs/wit_comp.rs +++ b/src/target/r1cs/wit_comp.rs @@ -80,6 +80,16 @@ impl StagedWitComp { pub fn num_stage_inputs(&self, n: usize) -> usize { self.stages[n].inputs.len() } + + /// Number of steps + pub fn num_steps(&self) -> usize { + self.steps.len() + } + + /// Number of step arguments + pub fn num_step_args(&self) -> usize { + self.step_args.len() + } } /// Evaluator interface diff --git a/src/target/smt/mod.rs b/src/target/smt/mod.rs index 3ff4a658..07e2c1cb 100644 --- a/src/target/smt/mod.rs +++ b/src/target/smt/mod.rs @@ -365,9 +365,7 @@ pub fn check_sat(t: &Term) -> bool { let mut solver = make_solver((), false, false); for c in PostOrderIter::new(t.clone()) { if let Op::Var(v) = &c.op() { - solver - .declare_const(&SmtSymDisp(&*v.name), &v.sort) - .unwrap(); + solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap(); } } assert!(check(t) == Sort::Bool); @@ -380,9 +378,7 @@ fn get_model_solver(t: &Term, inc: bool) -> rsmt2::Solver { //solver.path_tee("solver_com").unwrap(); for c in PostOrderIter::new(t.clone()) { if let Op::Var(v) = &c.op() { - solver - .declare_const(&SmtSymDisp(&*v.name), &v.sort) - .unwrap(); + solver.declare_const(SmtSymDisp(&*v.name), &v.sort).unwrap(); } } assert!(check(t) == Sort::Bool); @@ -590,13 +586,13 @@ mod test { let mut solver = make_solver((), false, false); for (v, val) in vs { let s = val.sort(); - solver.declare_const(&SmtSymDisp(&v), &s).unwrap(); + solver.declare_const(SmtSymDisp(&v), &s).unwrap(); solver - .assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) + .assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) .unwrap(); } let val = eval(&t, vs); - solver.assert(&term![Op::Eq; t, const_(val)]).unwrap(); + solver.assert(term![Op::Eq; t, const_(val)]).unwrap(); solver.check_sat().unwrap() } @@ -605,14 +601,14 @@ mod test { let mut solver = make_solver((), false, false); for (v, val) in vs { let s = val.sort(); - solver.declare_const(&SmtSymDisp(&v), &s).unwrap(); + solver.declare_const(SmtSymDisp(&v), &s).unwrap(); solver - .assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) + .assert(term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) .unwrap(); } let val = eval(&t, vs); solver - .assert(&term![Op::Not; term![Op::Eq; t, const_(val)]]) + .assert(term![Op::Not; term![Op::Eq; t, const_(val)]]) .unwrap(); solver.check_sat().unwrap() }