mirror of
https://github.com/circify/circ.git
synced 2026-01-09 13:48:02 -05:00
Eliminate tuples in preprocessing (#202)
This started as an optimization patch, but my first optimization revealed a bug. In chasing the bug, I found more optimization. Changes: 1. Eliminate tuples in preprocessing. (opt) 2. Handle CStore in tuple elimination pass. (bugfix) 3. Use tuples instead of arrays in a few more extension ops: (opt) * GCD for vanishing polynomials and their derivatives * sorting in transcript checking 4. A few logging revisions
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
|
||||
(let (
|
||||
(x #f6)
|
||||
(return #f6)
|
||||
(return #f0)
|
||||
) false ; ignored
|
||||
))
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ fn main() {
|
||||
Backend::Smt { .. } => Mode::Proof,
|
||||
};
|
||||
let language = determine_language(&options.frontend.language, &options.path);
|
||||
println!("Running frontend");
|
||||
let cs = match language {
|
||||
#[cfg(all(feature = "smt", feature = "zok"))]
|
||||
DeterminedLanguage::Zsharp => {
|
||||
@@ -233,6 +234,7 @@ fn main() {
|
||||
panic!("Missing feature: c");
|
||||
}
|
||||
};
|
||||
println!("Running IR optimizations");
|
||||
let cs = match mode {
|
||||
Mode::Opt => opt(
|
||||
cs,
|
||||
@@ -295,8 +297,7 @@ fn main() {
|
||||
opt(cs, opts)
|
||||
}
|
||||
};
|
||||
println!("Done with IR optimization");
|
||||
|
||||
println!("Running backend");
|
||||
match options.backend {
|
||||
#[cfg(feature = "r1cs")]
|
||||
Backend::R1cs {
|
||||
@@ -306,7 +307,6 @@ fn main() {
|
||||
proof_impl,
|
||||
..
|
||||
} => {
|
||||
println!("Converting to r1cs");
|
||||
let cs = cs.get("main");
|
||||
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
|
||||
let mut r1cs = to_r1cs(cs, cfg());
|
||||
@@ -314,7 +314,7 @@ fn main() {
|
||||
println!("R1CS stats: {:#?}", r1cs.stats());
|
||||
}
|
||||
|
||||
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
|
||||
println!("Running r1cs optimizations ");
|
||||
r1cs = reduce_linearities(r1cs, cfg());
|
||||
|
||||
println!("Final R1cs size: {}", r1cs.constraints().len());
|
||||
@@ -326,7 +326,7 @@ fn main() {
|
||||
ProofAction::Count => (),
|
||||
#[cfg(feature = "bellman")]
|
||||
ProofAction::Setup => {
|
||||
println!("Generating Parameters");
|
||||
println!("Running Setup");
|
||||
match proof_impl {
|
||||
ProofImpl::Groth16 => Bellman::<Bls12>::setup_fs(
|
||||
prover_data,
|
||||
@@ -348,7 +348,7 @@ fn main() {
|
||||
ProofAction::Setup => panic!("Missing feature: bellman"),
|
||||
#[cfg(feature = "bellman")]
|
||||
ProofAction::CpSetup => {
|
||||
println!("Generating Parameters");
|
||||
println!("Running CpSetup");
|
||||
match proof_impl {
|
||||
ProofImpl::Groth16 => panic!("Groth16 is not CP"),
|
||||
ProofImpl::Mirage => Mirage::<Bls12>::cp_setup_fs(
|
||||
|
||||
@@ -72,8 +72,6 @@ function pf_test_isolate {
|
||||
done
|
||||
}
|
||||
|
||||
pf_test 2024_05_24_benny_bug
|
||||
pf_test 2024_05_31_benny_bug
|
||||
r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
|
||||
@@ -110,4 +108,7 @@ pf_test var_idx_arr_str_arr_str
|
||||
pf_test mm
|
||||
pf_test unused_var
|
||||
|
||||
pf_test 2024_05_24_benny_bug
|
||||
pf_test 2024_05_31_benny_bug
|
||||
|
||||
scripts/zx_tests/run_tests.sh
|
||||
|
||||
@@ -213,9 +213,7 @@ fn range_check(
|
||||
debug_assert!(values.iter().all(|v| check(v) == f_sort));
|
||||
let mut ms_hash_inputs = values.clone();
|
||||
values.extend(f_sort.elems_iter().take(n));
|
||||
let sorted_term = unmake_array(
|
||||
term![Op::ExtOp(ExtOp::Sort); make_array(f_sort.clone(), f_sort.clone(), values.clone())],
|
||||
);
|
||||
let sorted_term = tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, values.clone())]);
|
||||
let sorted: Vec<Term> = sorted_term
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
@@ -302,10 +300,7 @@ fn derivative_gcd(
|
||||
let ns = ns.subspace("uniq");
|
||||
let fs = Sort::Field(f.clone());
|
||||
let pairs = term(
|
||||
Op::Array(Box::new(ArrayOp {
|
||||
key: fs.clone(),
|
||||
val: Sort::new_tuple(vec![fs.clone(), Sort::Bool]),
|
||||
})),
|
||||
Op::Tuple,
|
||||
values
|
||||
.clone()
|
||||
.into_iter()
|
||||
@@ -314,8 +309,8 @@ fn derivative_gcd(
|
||||
.collect(),
|
||||
);
|
||||
let two_polys = term![Op::ExtOp(ExtOp::UniqDeriGcd); pairs];
|
||||
let s_coeffs = unmake_array(term![Op::Field(0); two_polys.clone()]);
|
||||
let t_coeffs = unmake_array(term![Op::Field(1); two_polys]);
|
||||
let s_coeffs = tuple_terms(term![Op::Field(0); two_polys.clone()]);
|
||||
let t_coeffs = tuple_terms(term![Op::Field(1); two_polys]);
|
||||
let mut decl_poly = |coeffs: Vec<Term>, poly_name: &str| -> Vec<Term> {
|
||||
coeffs
|
||||
.into_iter()
|
||||
|
||||
@@ -11,19 +11,15 @@ pub(super) fn waksman(
|
||||
val_sort: &Sort,
|
||||
new_var: &mut impl FnMut(&str, Term) -> Term,
|
||||
) -> Vec<Access> {
|
||||
let f = &cfg.field;
|
||||
let f_s = Sort::Field(f.clone());
|
||||
// (1) sort the transcript
|
||||
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
|
||||
let switch_settings_tuple = term![Op::ExtOp(ExtOp::Waksman); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())];
|
||||
let n = check(&switch_settings_tuple).as_tuple().len();
|
||||
let mut switch_settings: VecDeque<Term> = (0..n)
|
||||
.map(|i| {
|
||||
new_var(
|
||||
&format!("sw{}", i),
|
||||
term![Op::Field(i); switch_settings_tuple.clone()],
|
||||
)
|
||||
})
|
||||
let switch_settings_tuple =
|
||||
term![Op::ExtOp(ExtOp::Waksman); term(Op::Tuple, field_tuples.clone())];
|
||||
let switch_settings_terms = tuple_terms(switch_settings_tuple);
|
||||
let mut switch_settings: VecDeque<Term> = switch_settings_terms
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| new_var(&format!("sw{}", i), t))
|
||||
.collect();
|
||||
|
||||
let sorted_field_tuple_values: Vec<Term> =
|
||||
@@ -69,12 +65,10 @@ pub(super) fn msh(
|
||||
assertions: &mut Vec<Term>,
|
||||
) -> Vec<Access> {
|
||||
let f = &cfg.field;
|
||||
let f_s = Sort::Field(f.clone());
|
||||
// (1) sort the transcript
|
||||
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
|
||||
let sorted_field_tuple_values: Vec<Term> = unmake_array(
|
||||
term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())],
|
||||
);
|
||||
let sorted_field_tuple_values: Vec<Term> =
|
||||
tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, field_tuples.clone())]);
|
||||
let mut sorted_accesses: Vec<Access> = sorted_field_tuple_values
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
|
||||
@@ -106,7 +106,7 @@ impl ArrayGraph {
|
||||
.collect();
|
||||
while let Some(top) = stack.pop() {
|
||||
if ram_terms.insert(top.clone()) {
|
||||
trace!("Maybe RAM: {}", top);
|
||||
trace!("Maybe RAM: {}", top.op());
|
||||
for p in ps.get(&top).unwrap() {
|
||||
if right_sort(p, field) {
|
||||
stack.push(p.clone());
|
||||
|
||||
@@ -143,6 +143,7 @@ impl TupleTree {
|
||||
}
|
||||
}
|
||||
}
|
||||
#[track_caller]
|
||||
fn unwrap_non_tuple(self) -> Term {
|
||||
match self {
|
||||
TupleTree::NonTuple(t) => t,
|
||||
@@ -245,9 +246,13 @@ fn tuple_free(t: Term) -> bool {
|
||||
/// Run the tuple elimination pass.
|
||||
pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
let mut lifted: TermMap<TupleTree> = TermMap::default();
|
||||
let terms =
|
||||
PostOrderIter::from_roots_and_skips(cs.outputs().iter().cloned(), Default::default());
|
||||
// .chain(cs.precomputes.outputs().values().cloned()),
|
||||
let terms = PostOrderIter::from_roots_and_skips(
|
||||
cs.outputs()
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(cs.precomputes.outputs().values().cloned()),
|
||||
Default::default(),
|
||||
);
|
||||
for t in terms {
|
||||
let mut cs: Vec<TupleTree> = t
|
||||
.cs()
|
||||
@@ -270,6 +275,14 @@ pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]);
|
||||
TupleTree::NonTuple(term(AND, eqs.collect()))
|
||||
}
|
||||
Op::CStore => {
|
||||
let c = cs.pop().unwrap().unwrap_non_tuple();
|
||||
let v = cs.pop().unwrap();
|
||||
let i = cs.pop().unwrap().unwrap_non_tuple();
|
||||
let a = cs.pop().unwrap();
|
||||
debug_assert!(cs.is_empty());
|
||||
a.bimap(|a, v| term![Op::CStore; a, i.clone(), v, c.clone()], &v)
|
||||
}
|
||||
Op::Store => {
|
||||
let v = cs.pop().unwrap();
|
||||
let i = cs.pop().unwrap().unwrap_non_tuple();
|
||||
@@ -321,11 +334,11 @@ pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
.into_iter()
|
||||
.flat_map(|o| lifted.get(&o).unwrap().clone().flatten())
|
||||
.collect();
|
||||
// let os = cs.precomputes.outputs().clone();
|
||||
// for (name, old_term) in os {
|
||||
// let new_term = lifted.get(&old_term).unwrap().clone().as_term();
|
||||
// cs.precomputes.change_output(&name, new_term);
|
||||
// }
|
||||
let os = cs.precomputes.outputs().clone();
|
||||
for (name, old_term) in os {
|
||||
let new_term = lifted.get(&old_term).unwrap().clone().as_term();
|
||||
cs.precomputes.change_output(&name, new_term);
|
||||
}
|
||||
#[cfg(debug_assertions)]
|
||||
for o in &cs.outputs {
|
||||
if let Some(t) = find_tuple_term(o.clone()) {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
use crate::ir::term::ty::*;
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// Type-check [super::ExtOp::UniqDeriGcd].
|
||||
/// Type-check [super::ExtOp::Haboeck].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
let &[haystack, needles] = ty::count_or_ref(arg_sorts)?;
|
||||
let (_n, value0) = ty::homogenous_tuple_or(haystack, "haystack must be a tuple")?;
|
||||
@@ -21,7 +21,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
Ok(haystack.clone())
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::UniqDeriGcd].
|
||||
/// Evaluate [super::ExtOp::Haboeck].
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
let haystack: Vec<FieldV> = args[0]
|
||||
.as_tuple()
|
||||
|
||||
@@ -11,40 +11,24 @@ use crate::ir::term::*;
|
||||
|
||||
/// Type-check [super::ExtOp::UniqDeriGcd].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
if let &[pairs] = arg_sorts {
|
||||
let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?;
|
||||
let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?;
|
||||
let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?;
|
||||
if let &[root, cond] = &value_tup {
|
||||
eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?;
|
||||
eq_or(
|
||||
cond,
|
||||
&Sort::Bool,
|
||||
"UniqDeriGcd pairs: second element must be a bool",
|
||||
)?;
|
||||
let arr = Sort::new_array(f.clone(), f.clone(), size);
|
||||
Ok(Sort::new_tuple(vec![arr.clone(), arr]))
|
||||
} else {
|
||||
// non-pair entries value
|
||||
Err(TypeErrorReason::Custom(
|
||||
"UniqDeriGcd: pairs value must be a pair".into(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
// wrong arg count
|
||||
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
|
||||
}
|
||||
let [pairs] = ty::count_or_ref(arg_sorts)?;
|
||||
let (size, value) = ty::homogenous_tuple_or(pairs, "UniqDeriGcd")?;
|
||||
let [root, cond] = ty::count_or(ty::tuple_or(value, "UniqDeriGcd")?)?;
|
||||
let f = pf_or(root, "UniqDeriGcd: first is field")?;
|
||||
eq_or(cond, &Sort::Bool, "UniqDeriGcd pairs: second is bool")?;
|
||||
let coeffs = Sort::new_tuple(vec![f.clone(); size]);
|
||||
Ok(Sort::new_tuple(vec![coeffs.clone(), coeffs]))
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::UniqDeriGcd].
|
||||
#[cfg(feature = "poly")]
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
use rug_polynomial::ModPoly;
|
||||
let sort = args[0].sort().as_array().0.clone();
|
||||
let sort = args[0].sort().as_tuple()[0].as_tuple()[0].clone();
|
||||
let field = sort.as_pf().clone();
|
||||
let mut roots: Vec<Integer> = Vec::new();
|
||||
let deg = args[0].as_array().size;
|
||||
for t in args[0].as_array().values() {
|
||||
let deg = args[0].as_tuple().len();
|
||||
for t in args[0].as_tuple() {
|
||||
let tuple = t.as_tuple();
|
||||
let cond = tuple[1].as_bool();
|
||||
if cond {
|
||||
@@ -60,7 +44,7 @@ pub fn eval(args: &[&Value]) -> Value {
|
||||
let v: Vec<Value> = (0..deg)
|
||||
.map(|i| Value::Field(field.new_v(s.get_coefficient(i))))
|
||||
.collect();
|
||||
Value::Array(Array::from_vec(sort.clone(), sort.clone(), v))
|
||||
Value::Tuple(v.into())
|
||||
};
|
||||
let s_cs = coeff_arr(s);
|
||||
let t_cs = coeff_arr(t);
|
||||
|
||||
@@ -17,12 +17,21 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
/// Evaluate [super::ExtOp::Sort].
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
let sort = args[0].sort();
|
||||
let (key_sort, value_sort, _) = sort.as_array();
|
||||
let mut values: Vec<Value> = args[0].as_array().values();
|
||||
let is_array = sort.is_array();
|
||||
let mut values: Vec<Value> = if is_array {
|
||||
args[0].as_array().values()
|
||||
} else {
|
||||
args[0].as_tuple().to_vec()
|
||||
};
|
||||
values.sort();
|
||||
Value::Array(Array::from_vec(
|
||||
key_sort.clone(),
|
||||
value_sort.clone(),
|
||||
values,
|
||||
))
|
||||
if is_array {
|
||||
let (key_sort, value_sort, _) = sort.as_array();
|
||||
Value::Array(Array::from_vec(
|
||||
key_sort.clone(),
|
||||
value_sort.clone(),
|
||||
values,
|
||||
))
|
||||
} else {
|
||||
Value::Tuple(values.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ fn uniq_deri_gcd_eval() {
|
||||
let t = text::parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
|
||||
(pairs (tuple 5 (tuple (mod 17) bool)))
|
||||
)
|
||||
(uniq_deri_gcd pairs))",
|
||||
);
|
||||
@@ -35,7 +35,7 @@ fn uniq_deri_gcd_eval() {
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
|
||||
(pairs (#t (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) ))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
@@ -46,8 +46,8 @@ fn uniq_deri_gcd_eval() {
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage
|
||||
(#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage
|
||||
(#t #f16 #f0 #f0 #f0 #f0 ) ; s, from sage
|
||||
(#t #f7 #f9 #f0 #f0 #f0 ) ; t, from sage
|
||||
))
|
||||
) false))
|
||||
",
|
||||
@@ -59,7 +59,7 @@ fn uniq_deri_gcd_eval() {
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
|
||||
(pairs (#t (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true)))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
@@ -70,8 +70,8 @@ fn uniq_deri_gcd_eval() {
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage
|
||||
(#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage
|
||||
(#t #f8 #f9 #f16 #f0 #f0 ) ; s, from sage
|
||||
(#t #f2 #f16 #f9 #f13 #f0 ) ; t, from sage
|
||||
))
|
||||
) false))
|
||||
",
|
||||
|
||||
@@ -9,15 +9,16 @@ use std::iter::FromIterator;
|
||||
|
||||
/// Type-check [super::ExtOp::Waksman].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
array_or(arg_sorts[0], "sort argument")
|
||||
.map(|(_, _, n_flows)| Sort::Tuple(vec![Sort::Bool; n_switches(n_flows)].into()))
|
||||
let &[values] = ty::count_or_ref(arg_sorts)?;
|
||||
let (n_flows, _v_sort) = ty::homogenous_tuple_or(values, "Waksman argument")?;
|
||||
Ok(Sort::Tuple(vec![Sort::Bool; n_switches(n_flows)].into()))
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::Waksman].
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
let len = args[0].as_array().size;
|
||||
let cfg = Config::for_sorting(args[0].as_array().values());
|
||||
let values = args[0].as_tuple();
|
||||
let cfg = Config::for_sorting(values.to_vec());
|
||||
let switch_bools = Vec::from_iter(cfg.switches().into_iter().map(Value::Bool));
|
||||
assert_eq!(switch_bools.len(), n_switches(len));
|
||||
assert_eq!(switch_bools.len(), n_switches(values.len()));
|
||||
Value::Tuple(switch_bools.into())
|
||||
}
|
||||
|
||||
@@ -1381,7 +1381,7 @@ mod test {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
|
||||
(pairs (tuple 5 (tuple (mod 17) bool)))
|
||||
)
|
||||
(uniq_deri_gcd pairs))",
|
||||
);
|
||||
|
||||
@@ -463,6 +463,7 @@ where
|
||||
) -> Self::Proof {
|
||||
assert_eq!(rand.len(), pk.data.num_commitments());
|
||||
let rng = &mut rand::thread_rng();
|
||||
#[cfg(debug_assertions)]
|
||||
pk.data.check_all(witness);
|
||||
let rands: Vec<E::Fr> = rand.iter().map(|r| r.0).collect();
|
||||
let mut rng = &mut rand::thread_rng();
|
||||
|
||||
@@ -215,7 +215,7 @@ impl<'a> StagedWitCompEvaluator<'a> {
|
||||
rows.sort_by_key(|t| t.1);
|
||||
println!("time,op,nanos,counts,nanos_per,arg_sorts");
|
||||
for (op, nanos, counts, nanos_per, arg_sorts) in &rows {
|
||||
println!("time,{op},{nanos},{counts},{nanos_per},{arg_sorts}");
|
||||
println!("time,{op},{nanos},{counts},{nanos_per},\"{arg_sorts}\"");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user