From 8005aa2e373b7370ea2d00a30ea340948449326e Mon Sep 17 00:00:00 2001 From: Evan Laufer Date: Tue, 12 Jul 2022 15:52:39 -0700 Subject: [PATCH] Intermediary testing commit --- Cargo.lock | 24 ++- Cargo.toml | 7 +- examples/circ.rs | 32 +-- examples/zk.rs | 8 +- src/circify/mod.rs | 7 +- src/front/datalog/mod.rs | 8 +- src/front/datalog/term.rs | 3 +- src/front/mod.rs | 2 + src/front/zsharp/mod.rs | 45 ++-- src/front/zsharp/term.rs | 7 +- src/ir/opt/flat.rs | 54 +++-- src/ir/opt/mem/ram.rs | 196 +++++++++++------- src/ir/proof.rs | 4 +- src/ir/term/mod.rs | 129 +++++++++--- src/ir/term/text/mod.rs | 8 +- src/target/r1cs/marlin.rs | 24 ++- src/target/r1cs/trans.rs | 4 + .../zokrates_parser/src/zokrates.pest | 4 +- .../ZoKrates/zokrates_pest_ast/src/lib.rs | 37 ++++ 19 files changed, 426 insertions(+), 177 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e0e5002d..881ab302 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,18 +136,15 @@ dependencies = [ [[package]] name = "ark-marlin" version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caa8510faa8e64f0a6841ee4b58efe2d56f7a80d86fa0ce9891bbb3aa20166d9" dependencies = [ "ark-ff", "ark-poly", "ark-poly-commit", - "ark-relations", + "ark-relations 0.3.0", "ark-serialize", "ark-std", "derivative", "digest 0.9.0", - "rand_chacha", "rayon", ] @@ -160,7 +157,7 @@ dependencies = [ "ark-ec", "ark-ff", "ark-r1cs-std", - "ark-relations", + "ark-relations 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "ark-std", "derivative", "num-bigint 0.4.3", @@ -193,7 +190,7 @@ dependencies = [ "ark-ff", "ark-nonnative-field", "ark-poly", - "ark-relations", + "ark-relations 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "ark-serialize", "ark-std", "derivative", @@ -210,7 +207,7 @@ checksum = "22e8fdacb1931f238a0d866ced1e916a49d36de832fd8b83dc916b718ae72893" dependencies = [ "ark-ec", "ark-ff", - "ark-relations", + "ark-relations 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "ark-std", "derivative", "num-bigint 0.4.3", @@ -218,6 +215,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "ark-relations" +version = "0.3.0" +dependencies = [ + "ark-ff", + "ark-std", + "tracing", + "tracing-subscriber", +] + [[package]] name = "ark-relations" version = "0.3.0" @@ -464,7 +471,7 @@ dependencies = [ "ark-marlin", "ark-poly", "ark-poly-commit", - "ark-relations", + "ark-relations 0.3.0", "ark-serialize", "bellman", "bincode", @@ -496,6 +503,7 @@ dependencies = [ "quickcheck", "quickcheck_macros", "rand", + "rand_chacha", "rsmt2", "rug", "serde", diff --git a/Cargo.toml b/Cargo.toml index a83edd9f..92a52d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,14 +42,15 @@ itertools = "0.10" petgraph = "0.6" paste = "1.0" im = "15" -ark-marlin = { version = "0.3.0", optional = true } -ark-relations = { version = "0.3.0", optional = true } +ark-marlin = { path = "../marlin", optional = true } +ark-relations = { path = "../snark/relations", optional = true } ark-ff = { version = "0.3.0", optional = true } ark-poly-commit = { version = "0.3.0", optional = true } ark-poly = { version = "0.3.0", optional = true } ark-serialize = { version = "0.3.0", optional = true } ark-bls12-381 = { version = "0.3.0", optional = true } sha2 = { version = "0.9.0", optional = true } +rand_chacha = { version = "0.3.1", optional = true } digest = { version = "0.9.0", optional = true } [dev-dependencies] @@ -69,7 +70,7 @@ lp = ["good_lp", "lp-solvers"] r1cs = ["bellman"] smt = ["rsmt2"] zok = ["zokrates_parser", "zokrates_pest_ast"] -marlin = ["ark-marlin", "ark-relations", "ark-ff", "ark-poly-commit", "ark-poly", "ark-serialize", "ark-bls12-381", "sha2", "digest"] +marlin = ["ark-marlin", "ark-relations", "ark-ff", "ark-poly-commit", "ark-poly", "ark-serialize", "ark-bls12-381", "sha2", "rand_chacha", "digest"] [[example]] name = "circ" diff --git a/examples/circ.rs b/examples/circ.rs index 905bac08..26b923dc 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -32,16 +32,18 @@ use circ::target::r1cs::bellman::{gen_params, prove, verify}; use circ::target::r1cs::opt::reduce_linearities; use circ::target::r1cs::trans::to_r1cs; -#[cfg(feature = "marlin")] -use circ::target::r1cs::marlin; #[cfg(feature = "marlin")] use ark_bls12_381::{Bls12_381, Fr as BlsFr}; #[cfg(feature = "marlin")] -use ark_poly_commit::marlin::marlin_pc::MarlinKZG10; +use ark_marlin::SimpleHashFiatShamirRng; #[cfg(feature = "marlin")] use ark_poly::univariate::DensePolynomial; #[cfg(feature = "marlin")] -use ark_marlin::rng::FiatShamirRng; +use ark_poly_commit::marlin::marlin_pc::MarlinKZG10; +#[cfg(feature = "marlin")] +use circ::target::r1cs::marlin; +#[cfg(feature = "marlin")] +use rand_chacha::ChaChaRng; #[cfg(feature = "marlin")] use sha2::Sha256; @@ -274,13 +276,14 @@ fn main() { Mode::Proof | Mode::ProofOfHighValue(_) => opt( cs, vec![ - Opt::RamExt, Opt::ScalarizeVars, Opt::Flatten, Opt::Sha, Opt::ConstantFold(Box::new([])), Opt::Flatten, Opt::Inline, + Opt::Tuple, + Opt::RamExt, // Tuples must be eliminated before oblivious array elim Opt::Tuple, Opt::ConstantFold(Box::new([])), @@ -328,21 +331,24 @@ fn main() { &verifier_data, ) .unwrap(); - }, + } #[cfg(feature = "marlin")] ProofSystem::Marlin => { - marlin::gen_params::>, Sha256, _, _>( - prover_key, - verifier_key, - &prover_data, - &verifier_data, + marlin::gen_params::< + BlsFr, + MarlinKZG10>, + SimpleHashFiatShamirRng, + _, + _, + >( + prover_key, verifier_key, &prover_data, &verifier_data ) .unwrap(); - }, + } #[cfg(not(feature = "marlin"))] ProofSystem::Marlin => { panic!("Missing feature: marlin"); - }, + } } } } diff --git a/examples/zk.rs b/examples/zk.rs index 2560f12e..98a067ab 100644 --- a/examples/zk.rs +++ b/examples/zk.rs @@ -14,9 +14,11 @@ use ark_poly_commit::marlin::marlin_pc::MarlinKZG10; #[cfg(feature = "marlin")] use ark_poly::univariate::DensePolynomial; #[cfg(feature = "marlin")] -use ark_marlin::rng::FiatShamirRng; +use ark_marlin::SimpleHashFiatShamirRng; #[cfg(feature = "marlin")] use sha2::Sha256; +#[cfg(feature = "marlin")] +use rand_chacha::ChaChaRng; #[derive(Debug, StructOpt)] #[structopt(name = "circ", about = "CirC: the circuit compiler")] @@ -67,7 +69,7 @@ fn main() { }, #[cfg(feature = "marlin")] ProofSystem::Marlin => { - marlin::prove::>, Sha256, _, _>(opts.prover_key, opts.proof, &input_map).unwrap(); + marlin::prove::>, SimpleHashFiatShamirRng, _, _>(opts.prover_key, opts.proof, &input_map).unwrap(); } #[cfg(not(feature = "marlin"))] ProofSystem::Marlin => { @@ -83,7 +85,7 @@ fn main() { }, #[cfg(feature = "marlin")] ProofSystem::Marlin => { - marlin::verify::>, Sha256, _, _>(opts.verifier_key, opts.proof, &input_map).unwrap(); + marlin::verify::>, SimpleHashFiatShamirRng, _, _>(opts.verifier_key, opts.proof, &input_map).unwrap(); } #[cfg(not(feature = "marlin"))] ProofSystem::Marlin => { diff --git a/src/circify/mod.rs b/src/circify/mod.rs index b3ec4252..bad086cf 100644 --- a/src/circify/mod.rs +++ b/src/circify/mod.rs @@ -354,6 +354,7 @@ pub trait Embeddable { ty: &Self::Ty, name: String, visibility: Option, + epoch: Epoch, precompute: Option, ) -> Self::T; @@ -473,6 +474,7 @@ impl Circify { nice_name: VarName, ty: &E::Ty, visibility: Option, + epoch: Epoch, precomputed_value: Option, mangle_name: bool, ) -> Result { @@ -484,7 +486,7 @@ impl Circify { }; let t = self .e - .declare_input(&mut self.cir_ctx, ty, name, visibility, precomputed_value); + .declare_input(&mut self.cir_ctx, ty, name, visibility, epoch, precomputed_value); assert!(self.vals.insert(ssa_name, Val::Term(t.clone())).is_none()); Ok(t) } @@ -916,6 +918,7 @@ mod test { ty: &Self::Ty, name: String, visibility: Option, + epoch: Epoch, precompute: Option, ) -> Self::T { match ty { @@ -940,6 +943,7 @@ mod test { &**a, format!("{}.0", name), visibility, + epoch, p_1, )), Box::new(self.declare_input( @@ -947,6 +951,7 @@ mod test { &**b, format!("{}.1", name), visibility, + epoch, p_2, )), ) diff --git a/src/front/datalog/mod.rs b/src/front/datalog/mod.rs index 490606b8..1c8e2104 100644 --- a/src/front/datalog/mod.rs +++ b/src/front/datalog/mod.rs @@ -126,7 +126,7 @@ impl<'ast> Gen<'ast> { let (ty, public) = self.ty(&d.ty); let vis = if public { PUBLIC_VIS } else { PROVER_VIS }; self.circ - .declare_input(d.ident.value.into(), &ty, vis, None, false)?; + .declare_input(d.ident.value.into(), &ty, vis, 0, None, false)?; } let r = self.rule_cases(rule)?; self.exit_function(name); @@ -146,7 +146,7 @@ impl<'ast> Gen<'ast> { for d in &decls.declarations { let (ty, _public) = self.ty(&d.ty); self.circ - .declare_input(d.ident.value.into(), &ty, PROVER_VIS, None, true)?; + .declare_input(d.ident.value.into(), &ty, PROVER_VIS, 0, None, true)?; } } c.exprs.iter().try_fold(term::bool_lit(true), |x, y| { @@ -316,7 +316,7 @@ impl<'ast> Gen<'ast> { let (ty, public) = self.ty(&d.ty); let vis = if public { PUBLIC_VIS } else { PROVER_VIS }; self.circ - .declare_input(d.ident.value.into(), &ty, vis, None, false)?; + .declare_input(d.ident.value.into(), &ty, vis, 0, None, false)?; } let mut bug_in_rule_if_any = Vec::new(); for cond in &rule.conds { @@ -327,7 +327,7 @@ impl<'ast> Gen<'ast> { for d in &decls.declarations { let (ty, _public) = self.ty(&d.ty); self.circ - .declare_input(d.ident.value.into(), &ty, None, None, true)?; + .declare_input(d.ident.value.into(), &ty, None, 0, None, true)?; } } let mut bad_recursion = Vec::new(); diff --git a/src/front/datalog/term.rs b/src/front/datalog/term.rs index a0a6fa3e..3d75172a 100644 --- a/src/front/datalog/term.rs +++ b/src/front/datalog/term.rs @@ -359,12 +359,13 @@ impl Embeddable for Datalog { ty: &Self::Ty, name: String, visibility: Option, + epoch: Epoch, precompute: Option, ) -> Self::T { T::new( ctx.cs .borrow_mut() - .new_var(&name, ty.sort(), visibility, precompute.map(|v| v.ir)), + .new_var(&name, ty.sort(), visibility, epoch, precompute.map(|v| v.ir)), ty.clone(), ) } diff --git a/src/front/mod.rs b/src/front/mod.rs index 3db25720..8c8341fc 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -13,6 +13,8 @@ use std::fmt::{self, Display, Formatter}; /// The prover visibility pub const PROVER_VIS: Option = Some(proof::PROVER_ID); +/// The verifier visibility +pub const VERIFIER_VIS: Option = Some(proof::VERIFIER_ID); /// Public visibility pub const PUBLIC_VIS: Option = None; diff --git a/src/front/zsharp/mod.rs b/src/front/zsharp/mod.rs index 12cf5776..fc2ecf5a 100644 --- a/src/front/zsharp/mod.rs +++ b/src/front/zsharp/mod.rs @@ -6,7 +6,7 @@ pub mod zvisit; use super::{FrontEnd, Mode}; use crate::circify::{CircError, Circify, Loc, Val}; -use crate::front::{PROVER_VIS, PUBLIC_VIS}; +use crate::front::{PROVER_VIS, VERIFIER_VIS, PUBLIC_VIS}; use crate::ir::proof::ConstraintMetadata; use crate::ir::term::*; use crate::util::field::DFL_T; @@ -679,8 +679,8 @@ impl<'ast> ZGen<'ast> { for p in f.parameters.iter() { let ty = self.type_(&p.ty); debug!("Entry param: {}: {}", p.id.value, ty); - let vis = self.interpret_visibility(&p.visibility); - let r = self.circ_declare_input(p.id.value.clone(), &ty, vis, None, false); + let (vis, epoch) = self.interpret_visibility(&p.visibility); + let r = self.circ_declare_input(p.id.value.clone(), &ty, vis, epoch, None, false); self.unwrap(r, &p.span); } for s in &f.statements { @@ -704,7 +704,7 @@ impl<'ast> ZGen<'ast> { let name = "return".to_owned(); let ret_val = r.unwrap_term(); let ret_var_val = self - .circ_declare_input(name, ty, PUBLIC_VIS, Some(ret_val.clone()), false) + .circ_declare_input(name, ty, PUBLIC_VIS, 0, Some(ret_val.clone()), false) .expect("circ_declare return"); let ret_eq = eq(ret_val, ret_var_val).unwrap().term; let mut assertions = std::mem::take(&mut *self.assertions.borrow_mut()); @@ -753,21 +753,33 @@ impl<'ast> ZGen<'ast> { } } } - fn interpret_visibility(&self, visibility: &Option>) -> Option { + fn interpret_visibility(&self, visibility: &Option>) -> (Option, Epoch) { match visibility { - None | Some(ast::Visibility::Public(_)) => PUBLIC_VIS, + None | Some(ast::Visibility::Public(_)) => (PUBLIC_VIS, 0), Some(ast::Visibility::Private(private)) => match self.mode { Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => { if private.number.is_some() { - self.err( - format!( - "Party number found, but we're generating a {} circuit", - self.mode - ), - &private.span, - ); + //self.err( + // format!( + // "Party number found, but we're generating a {} circuit", + // self.mode + // ), + // &private.span, + //); + let num_str = private.number.as_ref().unwrap(); + use std::convert::TryInto; + let num: u8 = num_str.try_into().unwrap_or_else(|e| self.err(e, &private.span)); + let epoch = private.epoch.as_ref().map(|e| e.try_into()).unwrap_or(Ok(0u8)).unwrap(); + if num == 0 { + (PROVER_VIS, epoch) + } else if num == 1 { + (VERIFIER_VIS, epoch) + } else { + self.err(format!("Bad party number: {}, can only be 0 or 1 for proofs!", num), &private.span); + } + } else { + (PROVER_VIS, 0) } - PROVER_VIS } Mode::Mpc(n_parties) => { let num_str = private @@ -780,7 +792,7 @@ impl<'ast> ZGen<'ast> { self.err(format!("Bad party number: {}", e), &private.span) }); if num_val <= n_parties { - Some(num_val - 1) + (Some(num_val - 1), 0) } else { self.err( format!( @@ -1858,12 +1870,13 @@ impl<'ast> ZGen<'ast> { name: String, ty: &Ty, vis: Option, + epoch: u8, precomputed_value: Option, mangle_name: bool, ) -> Result { self.circ .borrow_mut() - .declare_input(name, ty, vis, precomputed_value, mangle_name) + .declare_input(name, ty, vis, epoch, precomputed_value, mangle_name) } fn circ_declare_init(&self, name: String, ty: Ty, val: Val) -> Result, CircError> { diff --git a/src/front/zsharp/term.rs b/src/front/zsharp/term.rs index c8393dcd..8af83221 100644 --- a/src/front/zsharp/term.rs +++ b/src/front/zsharp/term.rs @@ -909,6 +909,7 @@ impl Embeddable for ZSharp { ty: &Self::Ty, name: String, visibility: Option, + epoch: Epoch, precompute: Option, ) -> Self::T { match ty { @@ -918,6 +919,7 @@ impl Embeddable for ZSharp { &name, Sort::Bool, visibility, + epoch, precompute.map(|p| p.term), ), ), @@ -927,6 +929,7 @@ impl Embeddable for ZSharp { &name, Sort::Field(DFL_T.clone()), visibility, + epoch, precompute.map(|p| p.term), ), ), @@ -936,6 +939,7 @@ impl Embeddable for ZSharp { &name, Sort::BitVector(*w), visibility, + epoch, precompute.map(|p| p.term), ), ), @@ -948,7 +952,7 @@ impl Embeddable for ZSharp { debug_assert_eq!(*n, ps.len()); array( ps.into_iter().enumerate().map(|(i, p)| { - self.declare_input(ctx, &*ty, idx_name(&name, i), visibility, p) + self.declare_input(ctx, &*ty, idx_name(&name, i), visibility, epoch, p) }), ) .unwrap() @@ -964,6 +968,7 @@ impl Embeddable for ZSharp { f_ty, field_name(&name, f_name), visibility, + epoch, precompute.as_ref().map(|_| unimplemented!("precomputations for declared inputs that are Z# structures")), ), ) diff --git a/src/ir/opt/flat.rs b/src/ir/opt/flat.rs index 5a018c9e..e5c7685b 100644 --- a/src/ir/opt/flat.rs +++ b/src/ir/opt/flat.rs @@ -3,6 +3,7 @@ use crate::ir::term::*; use std::rc::Rc; +#[derive(Clone)] enum Entry { Term(Rc), NaryTerm(Op, L, Option), @@ -74,9 +75,11 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache // what does a term rewrite to? let mut stack = vec![(term_.clone(), false)]; + let mut term_equals = TermSet::new(); + // Maps terms to their rewritten versions. while let Some((t, children_pushed)) = stack.pop() { - if rewritten.contains_key(&t) { + if rewritten.contains_key(&t) || term_equals.contains(&t) { continue; } if !children_pushed { @@ -89,18 +92,23 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache Op::BoolNaryOp(_) | Op::BvNaryOp(_) | Op::PfNaryOp(PfNaryOp::Add) => { let mut children = Vec::new(); for c in &t.cs { - match rewritten.get_mut(c).unwrap() { - Entry::Term(t) => { - children.push(Rc::new(PersistentConcatList::Leaf(t.clone()))) - } - Entry::NaryTerm(o, ts, _) - if &t.op == o && parent_counts.get(c).unwrap_or(&0) <= &1 => - { - children.push(ts.clone()); - } - e => { - children - .push(Rc::new(PersistentConcatList::Leaf(Rc::new(e.as_term())))); + if term_equals.contains(c) { + children.push(Rc::new(PersistentConcatList::Leaf(Rc::new(c.clone())))) + } else { + match rewritten.get_mut(c).unwrap() { + Entry::Term(t) => { + children.push(Rc::new(PersistentConcatList::Leaf(t.clone()))) + } + Entry::NaryTerm(o, ts, _) + if &t.op == o && parent_counts.get(c).unwrap_or(&0) <= &1 => + { + children.push(ts.clone()); + } + e => { + children.push(Rc::new(PersistentConcatList::Leaf(Rc::new( + e.as_term(), + )))); + } } } } @@ -113,13 +121,27 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache _ => Entry::Term(Rc::new(term( t.op.clone(), t.cs.iter() - .map(|c| rewritten.get_mut(c).unwrap().as_term()) + .map(|c| { + if term_equals.contains(c) { + c.clone() + } else { + rewritten.get_mut(c).unwrap().as_term() + } + }) .collect(), ))), }; - rewritten.insert(t, entry); + if t == entry.clone().as_term() { + term_equals.insert(t); + } else { + rewritten.insert(t, entry); + } + } + if term_equals.contains(&term_) { + term_ + } else { + rewritten.get_mut(&term_).unwrap().as_term() } - rewritten.get_mut(&term_).unwrap().as_term() } #[cfg(test)] diff --git a/src/ir/opt/mem/ram.rs b/src/ir/opt/mem/ram.rs index 8147525a..2b41de75 100644 --- a/src/ir/opt/mem/ram.rs +++ b/src/ir/opt/mem/ram.rs @@ -33,48 +33,75 @@ impl Access { is_write: guard, } } - fn universal_hash(&self, alpha: &Term, beta: &Term) -> Term { - let field = match check(alpha) { - Sort::Field(field) => field, - _ => panic!("Alpha value for universal hash isn't a field!"), - }; - assert_eq!( - check(beta), - Sort::Field(field.clone()), - "Beta value for universal hash isn't the same sort as alpha!" - ); +} - // universal hash of the value...handling tuples as necessary - let val_hash = match check(&self.val) { +fn multiset_hash(terms: impl IntoIterator, alpha: &Term) -> Term { + let field = match check(alpha) { + Sort::Field(field) => field, + _ => panic!("Alpha value for universal hash isn't a field!"), + }; + + let factors = terms + .into_iter() + .map(|t| { + // construct (alpha - t) for each term + assert_eq!( + check(&t), + Sort::Field(field.clone()), + "Term in multiset hash doesn't have correct field type!" + ); + term![PF_ADD; alpha.clone(), term![PF_NEG; t.clone()]] + }) + .collect(); + + term(PF_MUL, factors) +} + +/// Constructs a term representing the unviersal hash of all terms passed in. +/// Tuples and Arrays are handled recursively +fn universal_hash(terms: impl IntoIterator, beta: &Term) -> Term { + let field = match check(beta) { + Sort::Field(field) => field, + _ => panic!("Beta value for universal hash isn't a field!"), + }; + + // TODO: is extending the iterator here better? + let mut stack: Vec = terms.into_iter().collect(); + let mut results: Vec = Vec::new(); + let mut factor_index = 0; + while !stack.is_empty() { + let curr_term = stack.pop().unwrap(); + match check(&curr_term) { Sort::Tuple(sorts) => { - let tuple_factors = sorts - .iter() - .enumerate() - .map(|(i, _)| { - let mut factors = vec![beta.clone(); i + 2]; - factors.push(cast_to_field( - &term![Op::Field(i); self.val.clone()], - &field, - )); - term![PF_NEG; term(PF_MUL, factors)] - }) - .collect(); - term(PF_ADD, tuple_factors) + stack.extend((0..sorts.len()).map(|i| term![Op::Field(i); curr_term.clone()])); } - _ => term![PF_MUL; cast_to_field(&self.val, &field), beta.clone(), beta.clone()], - }; - //let vals = match self.val { - // Sort::Field(_) => vec![access.val.clone()], - // term![Op::UbvToPf(idx_field.clone()); access.val.clone()] - //} - term![PF_ADD; - alpha.clone(), - // TODO: ignoring is_write for now... - //term![PF_NEG; pf_lit(idx_field_typ.new_v::(access.is_write.get().as_bool_opt().unwrap()))], - term![PF_NEG; term![PF_MUL; self.idx.clone(), beta.clone()]], - term![PF_NEG; val_hash] - ] + Sort::Array(k, _, n) => { + stack.extend((0..n).map(|i| { + let idx = match k.as_ref() { + Sort::Field(field_typ) => Value::Field(field_typ.new_v(i)), + Sort::BitVector(width) => { + Value::BitVector(BitVector::new(rug::Integer::from(i), *width)) + } + _ => panic!("RAM: We don't support arrays with indices other than Field or Bitvector"), + }; + term![Op::Select; curr_term.clone(), term![Op::Const(idx)]] + })); + } + _ => { + //let mut factors = vec![beta.clone(); factor_index]; + if factor_index != 0 { + let beta_power = term(PF_MUL, vec![beta.clone(); factor_index]); + results.push(term![PF_MUL; beta_power, cast_to_field(&curr_term, &field)]); + } else { + results.push(cast_to_field(&curr_term, &field)); + } + //factors.push(cast_to_field(&curr_term, &field)); + factor_index += 1; + } + } } + //println!("Universal hash factors: {:?}", results); + term(PF_ADD, results) } fn cast_to_field(term: &Term, field: &circ_fields::FieldT) -> Term { @@ -121,6 +148,7 @@ impl Ram { &val_name, check(&read_value), Some(crate::ir::proof::PROVER_ID), + 0, // TODO: correct? Some(read_value), ); self.accesses.push(Access::new_read(idx, var.clone())); @@ -171,12 +199,14 @@ impl Ram { &idx_name, self.idx_sort.clone(), Some(crate::ir::proof::PROVER_ID), + 0, // TODO Some(term(Op::NthSmallest(i), idx_terms.clone())), ); let val = computation.new_var( &val_name, self.val_sort.clone(), Some(crate::ir::proof::PROVER_ID), + 0, // TODO Some( term![Op::Select; val_array_term.clone(), term(Op::NthSmallest(i), idx_terms.clone())], ), @@ -435,21 +465,21 @@ impl Encoder { // construct a term to check the __ram_srow values are the same as the // __ram values - let orig_sum_terms = ram1 - .accesses - .iter() - .map(|access| access.universal_hash(&alpha, &beta)) - .collect(); - let orig_prod_term = term(PF_MUL, orig_sum_terms); + let orig_ms_hash = multiset_hash( + ram1.accesses + .iter() + .map(|access| universal_hash(vec![access.idx.clone(), access.val.clone()], &beta)), + &alpha, + ); - let perm_sum_terms = ram2 - .accesses - .iter() - .map(|access| access.universal_hash(&alpha, &beta)) - .collect(); - let perm_prod_term = term(PF_MUL, perm_sum_terms); + let perm_ms_hash = multiset_hash( + ram2.accesses + .iter() + .map(|access| universal_hash(vec![access.idx.clone(), access.val.clone()], &beta)), + &alpha, + ); - term![EQ; orig_prod_term, perm_prod_term] + term![EQ; orig_ms_hash, perm_ms_hash] } fn construct_sorted_check(sorted_ram: &Ram) -> Term { @@ -486,7 +516,7 @@ impl Encoder { term(AND, check_terms) } - fn encode(&self, computation: &mut Computation) { + fn encode(&mut self, computation: &mut Computation) { if self.rams.len() < 2 { return; } @@ -501,33 +531,44 @@ impl Encoder { "Proofs using RAMs must have a boolean output!" ); + // remove ram writes + // TODO: This ONLY works for read-only rams. + for ram in self.rams.iter_mut() { + while eval(&ram.accesses[0].is_write, &fxhash::FxHashMap::default()).as_bool() { + ram.accesses.remove(0); + } + } + + // TODO: actually check that the sorts for ram1 and ram2 are equal... + // TODO: should this actually be verifier_id? + // TODO: assumes that idx_sort is a field + let alpha = computation.new_var( + &format!("__alpha"), + self.rams[0].idx_sort.clone(), + Some(crate::ir::proof::PROVER_ID), + 0, // TODO + None, + ); + let beta = computation.new_var( + &format!("__beta"), + self.rams[0].idx_sort.clone(), + Some(crate::ir::proof::PROVER_ID), + 0, // TODO + None, + ); + for ram in self.rams.iter() { let sorted_ram = ram.sorted_by_index(computation); // TODO: better error handling // TODO: is there a cleaner way to get the field type? let sorted_check = Encoder::construct_sorted_check(&sorted_ram); - // TODO: actually check that the sorts for ram1 and ram2 are equal... - // TODO: should this actually be verifier_id? - // TODO: assumes that idx_sort is a field - let alpha = computation.new_var( - &format!("__alpha_{}", ram.id), - ram.idx_sort.clone(), - Some(crate::ir::proof::PROVER_ID), - None, - ); - let beta = computation.new_var( - &format!("__beta_{}", ram.id), - ram.idx_sort.clone(), - Some(crate::ir::proof::PROVER_ID), - None, - ); - let permutation_check = - Encoder::construct_permutation_check(ram, &sorted_ram, alpha, beta); + Encoder::construct_permutation_check(ram, &sorted_ram, alpha.clone(), beta.clone()); computation.outputs[0] = term![AND; sorted_check, permutation_check, computation.outputs()[0].clone()]; + //println!("ram check: {}", Letified(computation.outputs[0].clone())); } } } @@ -544,8 +585,23 @@ impl Encoder { /// different RAMs with the same init sequence of instructions, this pass will /// not extract **either**. pub fn extract(c: &mut Computation) -> Vec { + //println!( + // "Got computation: {}", + // Letified(term(Op::Tuple, c.outputs.clone())) + //); let mut extractor = Extactor::new(c); extractor.traverse(c); + println!("found {:?} rams", extractor.rams.len()); + for ram in &extractor.rams { + println!("------------"); + println!("ram id: {:?}, len: {:?}", ram.id, ram.accesses.len()); + for access in ram.accesses.iter().take(3) { + println!("{:?}", access); + } + //println!("{:?}", ram.accesses[4]); + //println!("{:?}", ram.accesses[5]); + //println!("{:?}", ram.accesses[6]); + } extractor.rams } @@ -556,7 +612,7 @@ pub fn extract(c: &mut Computation) -> Vec { /// 1. doesn't do a single write, then only reads /// 2. doesn't access every single element in the ram pub fn encode(c: &mut Computation, rams: Vec) { - let encoder = Encoder::new(rams); + let mut encoder = Encoder::new(rams); encoder.encode(c); } diff --git a/src/ir/proof.rs b/src/ir/proof.rs index 8b7aa347..0c664b0a 100644 --- a/src/ir/proof.rs +++ b/src/ir/proof.rs @@ -58,7 +58,7 @@ impl Constraints for Computation { for v in public_inputs { if let Op::Var(n, s) = &v.op { - metadata.new_input(n.to_owned(), None, s.clone()); + metadata.new_input(n.to_owned(), None, 0, s.clone()); } else { panic!() } @@ -66,7 +66,7 @@ impl Constraints for Computation { for v in all_vars { if let Op::Var(n, s) = &v.op { if !public_inputs_set.contains(n) { - metadata.new_input(n.to_owned(), Some(PROVER_ID), s.clone()); + metadata.new_input(n.to_owned(), Some(PROVER_ID), 0, s.clone()); } } else { panic!() diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 68e47bbd..12c1b886 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -259,7 +259,7 @@ impl Op { Op::Update(_) => Some(2), Op::Map(op) => op.arity(), Op::Call(_, args, _) => Some(args.len()), - Op::NthSmallest(_) => None + Op::NthSmallest(_) => None, } } } @@ -1710,6 +1710,9 @@ impl std::iter::Iterator for PostOrderIter { /// A party identifier pub type PartyId = u8; +/// Epoch number for a particular input +pub type Epoch = u8; + #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] /// An IR constraint system. pub struct ComputationMetadata { @@ -1718,11 +1721,29 @@ pub struct ComputationMetadata { /// The next free id. pub next_party_id: PartyId, /// All inputs, including who knows them. If no visibility is set, the input is public. - pub input_vis: FxHashMap)>, + pub input_vis: FxHashMap, /// The inputs for the computation itself (not the precomputation). pub computation_inputs: FxHashSet, } +/// An input to the computation +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct InputMetadata { + term: Term, + visibility: Option, + epoch: Epoch, +} + +impl InputMetadata { + fn new(term: Term, visibility: Option, epoch: Epoch) -> InputMetadata { + InputMetadata { + term, + visibility, + epoch, + } + } +} + impl ComputationMetadata { /// Add a new party to the computation, getting a [PartyId] for them. pub fn add_party(&mut self, name: String) -> PartyId { @@ -1731,17 +1752,24 @@ impl ComputationMetadata { self.next_party_id - 1 } /// Add a new input to the computation, visible to `party`, or public if `party` is [None]. - pub fn new_input(&mut self, input_name: String, party: Option, sort: Sort) { + pub fn new_input( + &mut self, + input_name: String, + party: Option, + epoch: Epoch, + sort: Sort, + ) { let term = leaf_term(Op::Var(input_name.clone(), sort)); debug_assert!( !self.input_vis.contains_key(&input_name) - || self.input_vis.get(&input_name).unwrap().1 == party, + || self.input_vis.get(&input_name).unwrap().visibility == party, "Tried to create input {} (visibility {:?}), but it already existed (visibility {:?})", input_name, party, self.input_vis.get(&input_name).unwrap() ); - self.input_vis.insert(input_name.clone(), (term, party)); + self.input_vis + .insert(input_name.clone(), InputMetadata::new(term, party, epoch)); self.computation_inputs.insert(input_name); } /// Returns None if the value is public. Otherwise, the unique party that knows it. @@ -1754,7 +1782,19 @@ impl ComputationMetadata { input_name, self.input_vis ) }) - .1 + .visibility + } + /// Returns the epoch number for the input. + pub fn get_epoch(&self, input_name: &str) -> Epoch { + self.input_vis + .get(input_name) + .unwrap_or_else(|| { + panic!( + "Missing input {} in inputs{:#?}", + input_name, self.input_vis + ) + }) + .epoch } /// Is this input public? pub fn is_input(&self, input_name: &str) -> bool { @@ -1766,14 +1806,14 @@ impl ComputationMetadata { } /// What sort is this input? pub fn input_sort(&self, input_name: &str) -> Sort { - check(&self.input_vis.get(input_name).unwrap().0) + check(&self.input_vis.get(input_name).unwrap().term) } /// Get all public inputs to the computation itself. /// /// Excludes pre-computation inputs pub fn public_input_names<'a>(&'a self) -> impl Iterator + 'a { self.input_vis.iter().filter_map(move |(name, party)| { - if party.1.is_none() && self.computation_inputs.contains(name) { + if party.visibility.is_none() && self.computation_inputs.contains(name) { Some(name.as_str()) } else { None @@ -1788,22 +1828,20 @@ impl ComputationMetadata { #[allow(clippy::needless_lifetimes)] pub fn public_inputs<'a>(&'a self) -> impl Iterator + 'a { // TODO: check order? - self.input_vis - .iter() - .filter_map(move |(name, (term, vis))| { - if vis.is_none() && self.computation_inputs.contains(name) { - Some(term.clone()) - } else { - None - } - }) + self.input_vis.iter().filter_map(move |(name, input)| { + if input.visibility.is_none() && self.computation_inputs.contains(name) { + Some(input.term.clone()) + } else { + None + } + }) } /// Get all the inputs visible to `party`. pub fn get_inputs_for_party(&self, party: Option) -> FxHashSet { self.input_vis .iter() - .filter_map(|(name, (_, vis))| { - if vis.is_none() || vis == &party { + .filter_map(|(name, input)| { + if input.visibility.is_none() || input.visibility == party { Some(name.clone()) } else { None @@ -1831,7 +1869,7 @@ impl ComputationMetadata { .map(|i| { let vis = visibilities.get(i).map(|p| *party_ids.get(p).unwrap()); let term = inputs.remove(i).unwrap(); - (i.clone(), (term, vis)) + (i.clone(), InputMetadata::new(term, vis, 0)) }) .collect(); ComputationMetadata { @@ -1862,10 +1900,10 @@ impl Display for ComputationMetadata { write!(f, " ({} {})", input, sort)?; } write!(f, ")\n (")?; - for (input, (_, vis)) in &self.input_vis { - if let Some(id) = vis { - let party = self.party_ids.iter().find(|(_, i)| *i == id).unwrap(); - write!(f, " ({} {})", input, party.0)?; + for (input, input_meta) in &self.input_vis { + if let Some(id) = input_meta.visibility { + let party = self.party_ids.iter().find(|(_, i)| **i == id).unwrap(); + write!(f, " ({} {} {})", input, party.0, input_meta.epoch)?; } } write!(f, ")\n)") @@ -1900,10 +1938,38 @@ impl Computation { name: &str, s: Sort, party: Option, + epoch: Epoch, precompute: Option, ) -> Term { - debug!("Var: {} : {} (visibility: {:?})", name, s, party); - self.metadata.new_input(name.to_owned(), party, s.clone()); + self.new_var_epoched(name, s, party, epoch, precompute) + } + + /// Create a new variable, `name: s`, where `val_fn` can be called to get the concrete value, + /// and `public` indicates whether this variable is public in the constraint system. + /// + /// ## Arguments + /// + /// * `name`: the name of the new variable + /// * `s`: its sort + /// * `party`: its visibility: who knows it initially + /// * `epoch`: the epoch the input is chosen in + /// * `precompute`: a precomputation that can determine its value (optional). Note that the + /// precomputation may rely on information that some parties do not have. In this case, + /// those parties will have to provide a value for the variables directly. + pub fn new_var_epoched( + &mut self, + name: &str, + s: Sort, + party: Option, + epoch: Epoch, + precompute: Option, + ) -> Term { + debug!( + "Var: {} : {} (visibility: {:?}, epoch: {:?})", + name, s, epoch, party + ); + self.metadata + .new_input(name.to_owned(), party, epoch, s.clone()); if let Some(p) = precompute { assert_eq!(&s, &check(&p)); self.precomputes.add_output(name.to_owned(), p); @@ -1933,8 +1999,17 @@ impl Computation { _ => panic!("Precomputation for new var {} with term\n\t{}\ninvolves multiple input non-public visibilities:\n\t{:?}", new_input_var, precomp, input_visiblities), } }; + // TODO: Does this make sense? + // Can you ever have a precomputed value with an epoch that isn't 0? + let epoch = { + let input_epochs: FxHashSet = extras::free_variables(precomp.clone()) + .into_iter() + .map(|v| self.metadata.get_epoch(&v)) + .collect(); + *input_epochs.iter().max().unwrap_or(&0) + }; let sort = check(&precomp); - self.new_var(&new_input_var, sort, vis, Some(precomp)); + self.new_var(&new_input_var, sort, vis, epoch, Some(precomp)); } /// Change the sort of a variables diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 76b5a9eb..b8f742b6 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -505,7 +505,7 @@ impl<'src> IrInterp<'src> { fn visibility_list(&self, tt: &TokTree<'src>) -> Vec<(String, String)> { if let List(tts) = tt { - tts.iter() + let res = tts.iter() .map(|tti| match tti { List(ls) => match &ls[..] { [Leaf(Token::Ident, var), Leaf(Token::Ident, party)] => { @@ -517,7 +517,9 @@ impl<'src> IrInterp<'src> { }, _ => panic!("Expected visibility pair, found {}", tti), }) - .collect() + .collect(); + println!("Got {:?}", res); + res } else { panic!("Expected visibility list, found: {}", tt) } @@ -525,6 +527,7 @@ impl<'src> IrInterp<'src> { /// Returns a [ComputationMetadata] and a list of sort bindings to un-bind. fn metadata(&mut self, tt: &TokTree<'src>) -> (ComputationMetadata, Vec>) { + println!("IN META"); if let List(tts) = tt { if tts.is_empty() || tts[0] != Leaf(Token::Ident, b"metadata") { panic!( @@ -541,6 +544,7 @@ impl<'src> IrInterp<'src> { .map(|i| (from_utf8(i).unwrap().into(), self.get_binding(i).clone())) .collect(); let visibilities = self.visibility_list(viss); + println!("VIS: {:?}", visibilities); ( ComputationMetadata::from_parts( parties, diff --git a/src/target/r1cs/marlin.rs b/src/target/r1cs/marlin.rs index 71d44811..7910ea95 100644 --- a/src/target/r1cs/marlin.rs +++ b/src/target/r1cs/marlin.rs @@ -3,7 +3,7 @@ #![allow(unused_imports)] use ark_ff::fields::PrimeField; -use ark_marlin::{IndexProverKey, IndexVerifierKey, Marlin, Proof}; +use ark_marlin::{IndexProverKey, IndexVerifierKey, Marlin, Proof, rng::FiatShamirRng}; use ark_poly::polynomial::univariate::DensePolynomial; use ark_poly_commit::PolynomialCommitment; use ark_relations::{ @@ -80,6 +80,8 @@ fn lc_to_ark(vars: &HashMap, lc: &Lc) -> LinearCombin /// /// Optionally contains a variable value map. This must be populated to use the /// bellman prover. +/// OVERHAUL VLAUE MAP -> no longer makes sense +/// Value => closure? pub struct SynthInput<'a>(&'a R1cs, &'a Option>); impl<'a, F: Field> ConstraintSynthesizer for SynthInput<'a> { @@ -124,10 +126,13 @@ impl<'a, F: Field> ConstraintSynthesizer for SynthInput<'a> { let public = self.0.public_idxs.contains(&i); debug!("var: {}, public: {}", s, public); let v = if public { + // add epoch number cs.new_input_variable(val_f)? } else { cs.new_witness_variable(val_f)? }; + //// + //cs.new_verifier_challenge(name? epoch_num) vars.insert(i, v); } else { debug!("drop dead var: {}", s); @@ -159,7 +164,7 @@ impl<'a, F: Field> ConstraintSynthesizer for SynthInput<'a> { pub fn gen_params< F: PrimeField, PC: PolynomialCommitment>, - D: Digest, + FS: FiatShamirRng, P1: AsRef, P2: AsRef, >( @@ -169,14 +174,15 @@ pub fn gen_params< v_data: &VerifierData, ) -> Result<(), Box> { let rng = &mut rand::thread_rng(); - let srs = Marlin::::universal_setup( - p_data.r1cs.constraints.len() + p_data.r1cs.idxs_signals.len(), + let srs = Marlin::::universal_setup( + // TODO: without *2 this doesn't work for some reason... fix + p_data.r1cs.constraints.len() * 2, p_data.r1cs.idxs_signals.len(), count_non_zeros(&p_data.r1cs.constraints), rng, ) .unwrap(); - let (pk, vk) = Marlin::::index(&srs, SynthInput(&p_data.r1cs, &None)).unwrap(); + let (pk, vk) = Marlin::::index(&srs, SynthInput(&p_data.r1cs, &None)).unwrap(); write_prover_key_and_data(pk_path, &pk, p_data)?; write_verifier_key_and_data(vk_path, &vk, v_data)?; Ok(()) @@ -248,7 +254,7 @@ fn read_verifier_key_and_data< pub fn prove< F: PrimeField, PC: PolynomialCommitment>, - D: Digest, + FS: FiatShamirRng, P1: AsRef, P2: AsRef, >( @@ -272,7 +278,7 @@ pub fn prove< let new_map = prover_data.precompute.eval(inputs_map); prover_data.r1cs.check_all(&new_map); let pf = - Marlin::::prove(&pk, SynthInput(&prover_data.r1cs, &Some(new_map)), rng).unwrap(); + Marlin::::prove(&pk, SynthInput(&prover_data.r1cs, &Some(new_map)), rng).unwrap(); let mut pf_file = File::create(pf_path)?; pf.serialize(&mut pf_file)?; Ok(()) @@ -286,7 +292,7 @@ pub fn prove< pub fn verify< F: PrimeField, PC: PolynomialCommitment>, - D: Digest, + FS: FiatShamirRng, P1: AsRef, P2: AsRef, >( @@ -301,6 +307,6 @@ pub fn verify< let inputs_as_ff: Vec = inputs.into_iter().map(int_to_ff).collect(); let mut pf_file = File::open(pf_path).unwrap(); let pf = Proof::deserialize(&mut pf_file).unwrap(); - Marlin::::verify(&vk, &inputs_as_ff, &pf, rng).unwrap(); + Marlin::::verify(&vk, &inputs_as_ff, &pf, rng).unwrap(); Ok(()) } diff --git a/src/target/r1cs/trans.rs b/src/target/r1cs/trans.rs index a89add77..3a3eaa53 100644 --- a/src/target/r1cs/trans.rs +++ b/src/target/r1cs/trans.rs @@ -69,6 +69,10 @@ impl ToR1cs { /// /// `comp` is a term that computes the value. fn fresh_var(&mut self, ctx: &D, comp: Term, public: bool) -> TermLc { + // max_epoch analysis on comp + // epoch number with signal + // output is defining the computation the prover runs as a precompute + // let n = format!("{}_n{}", ctx, self.next_idx); self.next_idx += 1; debug_assert!(matches!(check(&comp), Sort::Field(_))); diff --git a/third_party/ZoKrates/zokrates_parser/src/zokrates.pest b/third_party/ZoKrates/zokrates_parser/src/zokrates.pest index f8a21f0a..34030eee 100644 --- a/third_party/ZoKrates/zokrates_parser/src/zokrates.pest +++ b/third_party/ZoKrates/zokrates_parser/src/zokrates.pest @@ -43,10 +43,12 @@ struct_field_list = _{(struct_field ~ (NEWLINE+ ~ struct_field)*)? } struct_field = { ty ~ identifier } vis_private_num = @{ "<" ~ ASCII_DIGIT* ~ ">" } -vis_private = {"private" ~ vis_private_num? } +epoch_num = @{ "<" ~ ASCII_DIGIT* ~ ">" } +vis_private = {"private" ~ vis_private_num? ~ epoch_num? } vis_public = {"public"} vis = { vis_private | vis_public } + // Statements statement = { (return_statement // does not require subsequent newline | (iteration_statement diff --git a/third_party/ZoKrates/zokrates_pest_ast/src/lib.rs b/third_party/ZoKrates/zokrates_pest_ast/src/lib.rs index d4b9fcdf..1447ff68 100644 --- a/third_party/ZoKrates/zokrates_pest_ast/src/lib.rs +++ b/third_party/ZoKrates/zokrates_pest_ast/src/lib.rs @@ -353,6 +353,29 @@ mod ast { Private(PrivateVisibility<'ast>), } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::epoch_num))] + pub struct EpochNumber<'ast> { + #[pest_ast(outer(with(span_into_str)))] + pub value: String, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + use std::convert::TryFrom; + impl TryFrom<&EpochNumber<'_>> for u8 { + type Error = String; + + fn try_from(num: &EpochNumber<'_>) -> Result { + let num_val = (&num.value[1..num.value.len() - 1]) + .parse::() + .or_else(|e| { + Err(format!("Bad party number: {}", e)) + }); + num_val + } + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::vis_private_num))] pub struct PrivateNumber<'ast> { @@ -362,6 +385,19 @@ mod ast { pub span: Span<'ast>, } + impl TryFrom<&PrivateNumber<'_>> for u8 { + type Error = String; + + fn try_from(num: &PrivateNumber<'_>) -> Result { + let num_val = (&num.value[1..num.value.len() - 1]) + .parse::() + .or_else(|e| { + Err(format!("Bad party number: {}", e)) + }); + num_val + } + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::vis_public))] pub struct PublicVisibility {} @@ -370,6 +406,7 @@ mod ast { #[pest_ast(rule(Rule::vis_private))] pub struct PrivateVisibility<'ast> { pub number: Option>, + pub epoch: Option>, #[pest_ast(outer())] pub span: Span<'ast>, }