diff --git a/examples/ZoKrates/pf/mem/arr_arr_of_str_of_arr.zok b/examples/ZoKrates/pf/mem/arr_arr_of_str_of_arr.zok new file mode 100644 index 00000000..23a2dbdb --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_arr_of_str_of_arr.zok @@ -0,0 +1,18 @@ +const u32 LEN = 2 +const u32 LEN2 = 100 +const u32 ACCESSES = 37 +const u32 P_ = 8 + +struct Pt { + field[P_] x + field[P_] x2 +} +const Pt [LEN][LEN2] array = [[Pt {x: [0; P_], x2: [0; P_]}; LEN2], ...[[Pt {x: [100; P_], x2: [100; P_]}; LEN2] ; LEN-1]] // 638887 when LEN = 8190 // 63949 when LEN = 819 + +def main(private field[ACCESSES][2] idx) -> field: + field sum = 0 + for u32 i in 0..ACCESSES do + field[2] access = idx[i] + sum = sum + array[access[1]][access[0]].x[0] + endfor + return sum diff --git a/examples/ZoKrates/pf/mem/arr_of_str.zok b/examples/ZoKrates/pf/mem/arr_of_str.zok new file mode 100644 index 00000000..a021dc29 --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str.zok @@ -0,0 +1,18 @@ +const u32 LEN = 6 +const u32 ACCESSES = 3 + +struct Pt { + field x + field y + field z +} +const Pt [LEN] array = [Pt {x: 4, y: 5, z: 6}, ...[Pt {x: 0, y: 1, z: 2}; LEN - 1]] + +def main(private field[ACCESSES] idx) -> field: + field prod = 1 + for u32 i in 0..ACCESSES do + field access = idx[i] + Pt pt = array[access] + prod = prod * pt.x * pt.y * pt.z + endfor + return prod diff --git a/examples/ZoKrates/pf/mem/arr_of_str.zok.pin b/examples/ZoKrates/pf/mem/arr_of_str.zok.pin new file mode 100644 index 00000000..b49ce478 --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str.zok.pin @@ -0,0 +1,7 @@ +(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 +(let ( + (idx.0 #f0) + (idx.1 #f1) + (idx.2 #f2) +) false ; ignored +)) diff --git a/examples/ZoKrates/pf/mem/arr_of_str.zok.vin b/examples/ZoKrates/pf/mem/arr_of_str.zok.vin new file mode 100644 index 00000000..b98ac1cf --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str.zok.vin @@ -0,0 +1,5 @@ +(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 +(let ( + (return #f0) +) false ; ignored +)) diff --git a/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok new file mode 100644 index 00000000..58fa3172 --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok @@ -0,0 +1,21 @@ +const u32 LEN = 4 +const u32 INNER_LEN = 2 +const u32 ACCESSES = 2 + +struct Pt { + field[INNER_LEN] x + field[INNER_LEN] y +} +const Pt [LEN] array = [Pt {x: [0; INNER_LEN], y: [5; INNER_LEN]}, ...[Pt {x: [1; INNER_LEN], y: [2; INNER_LEN]}; LEN - 1]] + +def main(private field[ACCESSES] idx) -> field: + field prod = 1 + for u32 i in 0..ACCESSES do + field access = idx[i] + Pt pt = array[access] + for u32 j in 0..INNER_LEN do + prod = prod * pt.x[j] * pt.y[j] + endfor + endfor + return prod + diff --git a/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.pin b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.pin new file mode 100644 index 00000000..018f4efe --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.pin @@ -0,0 +1,7 @@ +(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 +(let ( + (idx.0 #f0) + (idx.1 #f1) +) false ; ignored +)) + diff --git a/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.vin b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.vin new file mode 100644 index 00000000..492e7010 --- /dev/null +++ b/examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok.vin @@ -0,0 +1,6 @@ +(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 +(let ( + (return #f0) +) false ; ignored +)) + diff --git a/examples/ZoKrates/pf/mem/large_arr_of_str_of_arr.zok b/examples/ZoKrates/pf/mem/large_arr_of_str_of_arr.zok new file mode 100644 index 00000000..a9fc5673 --- /dev/null +++ b/examples/ZoKrates/pf/mem/large_arr_of_str_of_arr.zok @@ -0,0 +1,22 @@ +const u32 LEN = 256 +const u32 INNER_LEN = 8 +const u32 ACCESSES = 10 + +struct Pt { + field[INNER_LEN] x + field[INNER_LEN] y +} +const Pt [LEN] array = [Pt {x: [0; INNER_LEN], y: [5; INNER_LEN]}, ...[Pt {x: [1; INNER_LEN], y: [2; INNER_LEN]}; LEN - 1]] + +def main(private field[ACCESSES] idx) -> field: + field prod = 1 + for u32 i in 0..ACCESSES do + field access = idx[i] + Pt pt = array[access] + for u32 j in 0..INNER_LEN do + prod = prod * pt.x[j] * pt.y[j] + endfor + endfor + return prod + + diff --git a/examples/ZoKrates/pf/mem/two_level_ptr.zok b/examples/ZoKrates/pf/mem/two_level_ptr.zok index fe55c502..f6071a7c 100644 --- a/examples/ZoKrates/pf/mem/two_level_ptr.zok +++ b/examples/ZoKrates/pf/mem/two_level_ptr.zok @@ -1,17 +1,13 @@ const u32 LEN = 4 const u32 ACCESSES = 2 -struct Pt { - field x - field y -} -const Pt [LEN] array = [Pt {x: 0, y:0}, ...[Pt {x: 100, y: 0} ; LEN-1]] +const field[LEN] array = [0, ...[100; LEN-1]] def main(private field[ACCESSES] y) -> field: field result = 0 for u32 i in 0..ACCESSES do - assert(array[y[i]].x == 0) + assert(array[y[i]] == 0) endfor return result diff --git a/examples/circ.rs b/examples/circ.rs index 9a24a909..0f98d135 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -267,16 +267,20 @@ fn main() { opts.push(Opt::ConstantFold(Box::new([]))); opts.push(Opt::ParseCondStores); // Tuples must be eliminated before oblivious array elim - opts.push(Opt::Tuple); opts.push(Opt::ConstantFold(Box::new([]))); - opts.push(Opt::Tuple); opts.push(Opt::Obliv); // The obliv elim pass produces more tuples, that must be eliminated - opts.push(Opt::Tuple); if options.circ.ram.enabled { + // Waksman can only route scalars, so tuple first! + if options.circ.ram.permutation == circ_opt::PermutationStrategy::Waksman { + opts.push(Opt::Tuple); + } opts.push(Opt::PersistentRam); opts.push(Opt::VolatileRam); opts.push(Opt::SkolemizeChallenges); + opts.push(Opt::ScalarizeVars); + opts.push(Opt::ConstantFold(Box::new([]))); + opts.push(Opt::Obliv); } opts.push(Opt::LinearScan); // The linear scan pass produces more tuples, that must be eliminated diff --git a/scripts/compiler_asymptotics.zsh b/scripts/compiler_asymptotics.zsh new file mode 100755 index 00000000..7a238187 --- /dev/null +++ b/scripts/compiler_asymptotics.zsh @@ -0,0 +1,40 @@ +#!/usr/bin/env zsh +set -ex + +function usage { + echo "Usage: $0 COMPILER_COMMAND TEMPLATE PATTERN REPLACEMENTS..." + exit 2 +} + +compiler_command=($(eval echo $1)) +template_file=$2 +pattern=$3 +replacements=(${@:4}) + +[[ ! -z $compiler_command ]] || (echo "Empty compiler command" && usage) +if [[ ! -a $template_file ]] +then + for arg in $compiler_command + do + if [[ $arg =~ .*.zok ]] + then + echo "template $arg" + template_file=$arg + fi + done +fi +[[ -a $template_file ]] || (echo "No file at $template_file" && usage) +[[ ! -z $pattern ]] || (echo "Empty pattern" && usage) +[[ ! -z $replacements ]] || (echo "Empty replacements" && usage) + +echo $replacements + +for replacement in $replacements +do + t=$(mktemp compiler_asymptotics_XXXXXXXX.zok) + cat $template_file | sed "s/$pattern/$replacement/g" > $t + instantiated_command=$(echo $compiler_command | sed "s/$template_file/$t/") + echo $instantiated_command + rm $t +done + diff --git a/scripts/ram_test.zsh b/scripts/ram_test.zsh index 3b45d324..70424cf5 100755 --- a/scripts/ram_test.zsh +++ b/scripts/ram_test.zsh @@ -25,7 +25,10 @@ function ram_test { ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" +ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok mirage "" ram_test ./examples/ZoKrates/pf/mem/volatile.zok mirage "" ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok mirage "" +ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok mirage "" +ram_test ./examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok mirage "" diff --git a/scripts/test_c_r1cs.zsh b/scripts/test_c_r1cs.zsh new file mode 100755 index 00000000..ca002438 --- /dev/null +++ b/scripts/test_c_r1cs.zsh @@ -0,0 +1,21 @@ +#!/usr/bin/env zsh + +set -ex + +# cargo build --release --features lp,r1cs,smt,zok --example circ + +MODE=debug # release or debug +BIN=./target/$MODE/examples/circ +ZK_BIN=./target/$MODE/examples/zk + +# Test prove workflow, given an example name +function c_pf_test { + proof_impl=groth16 + ex_name=$1 + $BIN examples/C/r1cs/$ex_name.c r1cs --action setup --proof-impl $proof_impl + $ZK_BIN --inputs examples/C/r1cs/$ex_name.c.pin --action prove --proof-impl $proof_impl + $ZK_BIN --inputs examples/C/r1cs/$ex_name.c.vin --action verify --proof-impl $proof_impl + rm -rf P V pi +} + +c_pf_test add diff --git a/src/ir/opt/mem/lin.rs b/src/ir/opt/mem/lin.rs index f6ccdbcf..c6920a17 100644 --- a/src/ir/opt/mem/lin.rs +++ b/src/ir/opt/mem/lin.rs @@ -17,6 +17,12 @@ fn arr_val_to_tup(v: &Value) -> Value { } vec }), + Value::Tuple(vs) => Value::Tuple( + vs.iter() + .map(arr_val_to_tup) + .collect::>() + .into_boxed_slice(), + ), v => v.clone(), } } @@ -29,7 +35,7 @@ impl RewritePass for Linearizer { rewritten_children: F, ) -> Option { match &orig.op() { - Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))), + Op::Const(v) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))), Op::Var(name, Sort::Array(..)) => { let precomp = extras::array_to_tuple(orig); let new_name = format!("{name}.tup"); diff --git a/src/ir/opt/mem/obliv.rs b/src/ir/opt/mem/obliv.rs index c3560c85..aa0e54ea 100644 --- a/src/ir/opt/mem/obliv.rs +++ b/src/ir/opt/mem/obliv.rs @@ -13,11 +13,14 @@ //! //! So, essentially, what's going on is that T maps each term t to an (approximate) analysis of t //! that indicates which accesses can be perfectly resolved. +//! +//! We could make the analysis more precise (and/or efficient) with a better data structure for +//! tracking information about value locations. use crate::ir::term::extras::as_uint_constant; use crate::ir::term::*; -use log::{debug, trace}; +use log::trace; #[derive(Default)] struct OblivRewriter { @@ -30,6 +33,7 @@ fn suitable_const(t: &Term) -> bool { } impl OblivRewriter { + /// Get, prefering tuple if possible. fn get_t(&self, t: &Term) -> &Term { self.tups.get(t).unwrap_or(self.terms.get(t).unwrap()) } @@ -57,7 +61,7 @@ impl OblivRewriter { ( if let Some(aa) = self.tups.get(a) { if suitable_const(i) { - debug!("simplify store {}", i); + trace!("simplify store {}", i); Some(term![Op::Update(get_const(i)); aa.clone(), self.get_t(v).clone()]) } else { None @@ -73,7 +77,7 @@ impl OblivRewriter { let i = &t.cs()[1]; if let Some(aa) = self.tups.get(a) { if suitable_const(i) { - debug!("simplify select {}", i); + trace!("simplify select {}", i); let tt = term![Op::Field(get_const(i)); aa.clone()]; ( Some(tt.clone()), @@ -115,7 +119,37 @@ impl OblivRewriter { }, ) } - Op::Tuple => panic!("Tuple in obliv"), + Op::Tuple => ( + if t.cs().iter().all(|c| self.tups.contains_key(c)) { + Some(term( + Op::Tuple, + t.cs() + .iter() + .map(|c| self.tups.get(c).unwrap().clone()) + .collect(), + )) + } else { + None + }, + None, + ), + Op::Field(i) => ( + if t.cs().iter().all(|c| self.tups.contains_key(c)) { + Some(term_c![Op::Field(*i); self.get_t(&t.cs()[0])]) + } else { + None + }, + None, + ), + Op::Update(i) => ( + if t.cs().iter().all(|c| self.tups.contains_key(c)) { + Some(term_c![Op::Update(*i); self.get_t(&t.cs()[0]), self.get_t(&t.cs()[1])]) + } else { + None + }, + None, + ), + //Op::Tuple => panic!("Tuple in obliv"), _ => (None, None), }; if let Some(tup) = tup_opt { diff --git a/src/ir/opt/mem/ram.rs b/src/ir/opt/mem/ram.rs index ce648bce..3d7b6e1e 100644 --- a/src/ir/opt/mem/ram.rs +++ b/src/ir/opt/mem/ram.rs @@ -26,6 +26,8 @@ pub mod volatile; struct Access { /// The value read or (conditionally) written. pub val: Term, + /// A (field) hash of the value read or (conditionally) written. + pub val_hash: Option, /// The index/address. pub idx: Term, /// The time of this access. @@ -108,19 +110,40 @@ impl AccessCfg { true, ) } - fn len(&self) -> usize { - if self.create { - 6 - } else { - 5 + fn val_sort_len(s: &Sort) -> usize { + match s { + Sort::Tuple(t) => t.iter().map(Self::val_sort_len).sum(), + Sort::Array(_, v, size) => *size * Self::val_sort_len(v), + _ => 1, } } + fn len(&self, s: &Sort) -> usize { + (if self.create { 5 } else { 4 }) + Self::val_sort_len(s) + } fn bool2pf(&self, t: Term) -> Term { term![Op::Ite; t, self.one.clone(), self.zero.clone()] } fn pf_neg(&self, t: Term) -> Term { term![PF_ADD; self.one.clone(), term![PF_NEG; t]] } + fn pf_lit(&self, i: usize) -> Term { + pf_lit(self.field.new_v(i)) + } +} + +fn scalar_to_field(scalar: &Term, c: &AccessCfg) -> Term { + match check(scalar) { + Sort::Field(f) => { + if f == c.field { + scalar.clone() + } else { + panic!("Cannot convert scalar of field {} to field {}", f, c.field) + } + } + Sort::Bool => c.bool2pf(scalar.clone()), + Sort::BitVector(_) => term![Op::UbvToPf(c.field.clone()); scalar.clone()], + s => panic!("non-scalar sort {}", s), + } } /// A bit encoded in the field. @@ -163,6 +186,7 @@ impl Access { fn new_read(f: &AccessCfg, idx: Term, val: Term, time: Term) -> Self { Self { val, + val_hash: None, idx, time, write: FieldBit::from_bool_lit(f, false), @@ -173,6 +197,7 @@ impl Access { fn new_write(f: &AccessCfg, idx: Term, val: Term, active: Term, time: Term) -> Self { Self { val, + val_hash: None, idx, time, write: FieldBit::from_bool_lit(f, true), @@ -183,6 +208,7 @@ impl Access { fn new_init(f: &AccessCfg, idx: Term, val: Term) -> Self { Self { val, + val_hash: None, idx, time: f.zero.clone(), write: FieldBit::from_bool_lit(f, true), @@ -191,56 +217,127 @@ impl Access { } } - fn field_names(c: &AccessCfg, order: Order) -> &'static [&'static str] { + fn field_names(c: &AccessCfg, sort: &Sort, order: Order) -> Vec { + let mut out = Vec::new(); match order { Order::Hash => { + Self::sort_subnames(sort, "v", &mut out); + out.push("i".into()); + out.push("t".into()); + out.push("w".into()); + out.push("a".into()); if c.create { - &["v", "i", "t", "w", "a", "c"] - } else { - &["v", "i", "t", "w", "a"] + out.push("c".into()); } } // dead code, but for clarity... Order::Sort => { + out.push("i".into()); + out.push("t".into()); if c.create { - &["i", "t", "c", "v", "w", "a"] - } else { - &["i", "t", "v", "w", "a"] + out.push("c".into()); + } + Self::sort_subnames(sort, "v", &mut out); + out.push("w".into()); + out.push("a".into()); + } + } + out + } + + fn sort_subnames(sort: &Sort, prefix: &str, out: &mut Vec) { + match sort { + Sort::Field(_) | Sort::Bool | Sort::BitVector(_) => out.push(prefix.into()), + Sort::Tuple(ss) => { + for (i, s) in ss.iter().enumerate() { + Self::sort_subnames(s, &format!("{}_{}", prefix, i), out); } } + Sort::Array(_, v, size) => { + for i in 0..*size { + Self::sort_subnames(v, &format!("{}_{}", prefix, i), out); + } + } + _ => unreachable!(), + } + } + + fn val_to_field_elements(val: &Term, c: &AccessCfg, out: &mut Vec) { + match check(val) { + Sort::Field(_) | Sort::Bool | Sort::BitVector(_) => out.push(scalar_to_field(val, c)), + Sort::Tuple(ss) => { + for i in 0..ss.len() { + Self::val_to_field_elements(&term![Op::Field(i); val.clone()], c, out); + } + } + Sort::Array(_, _, size) => { + for i in 0..size { + Self::val_to_field_elements( + &term![Op::Select; val.clone(), c.pf_lit(i)], + c, + out, + ); + } + } + _ => unreachable!(), + } + } + + fn val_from_field_elements_trusted(sort: &Sort, next: &mut impl FnMut() -> Term) -> Term { + match sort { + Sort::Field(_) => next().clone(), + Sort::Bool => term![Op::PfToBoolTrusted; next().clone()], + Sort::BitVector(w) => term![Op::PfToBv(*w); next().clone()], + Sort::Tuple(ss) => term( + Op::Tuple, + ss.iter() + .map(|s| Self::val_from_field_elements_trusted(s, next)) + .collect(), + ), + Sort::Array(k, v, size) => term( + Op::Array(*k.clone(), *v.clone()), + (0..*size) + .map(|_| Self::val_from_field_elements_trusted(v, next)) + .collect(), + ), + _ => unreachable!(), } } fn to_field_elems(&self, c: &AccessCfg, order: Order) -> Vec { + let mut out = Vec::new(); match order { Order::Hash => { - let mut out = vec![ - self.val.clone(), - self.idx.clone(), - self.time.clone(), - self.write.f.clone(), - self.active.f.clone(), - ]; - if c.create { - out.push(self.create.f.clone()) - } - out - } - Order::Sort => { - let mut out = vec![self.idx.clone(), self.time.clone()]; - if c.create { - out.push(self.create.f.clone()) - } - out.push(self.val.clone()); + Self::val_to_field_elements(&self.val, c, &mut out); + out.push(self.idx.clone()); + out.push(self.time.clone()); + out.push(self.write.f.clone()); + out.push(self.active.f.clone()); + if c.create { + out.push(self.create.f.clone()) + } + } + Order::Sort => { + out.push(self.idx.clone()); + out.push(self.time.clone()); + if c.create { + out.push(self.create.f.clone()) + } + Self::val_to_field_elements(&self.val, c, &mut out); out.push(self.write.f.clone()); out.push(self.active.f.clone()); - out } } + out } - fn from_field_elems_trusted(elems: Vec, c: &AccessCfg, order: Order) -> Self { - debug_assert_eq!(elems.len(), c.len()); + fn from_field_elems_trusted( + elems: Vec, + val_sort: &Sort, + c: &AccessCfg, + order: Order, + ) -> Self { + debug_assert_eq!(elems.len(), c.len(val_sort)); let mut elems = elems.into_iter(); let mut next = || { let t = elems.next().unwrap(); @@ -249,7 +346,8 @@ impl Access { }; match order { Order::Hash => Self { - val: next(), + val: Self::val_from_field_elements_trusted(val_sort, &mut next), + val_hash: None, idx: next(), time: next(), write: FieldBit::from_trusted_field(c, next()), @@ -261,6 +359,7 @@ impl Access { }, }, Order::Sort => Self { + val_hash: None, idx: next(), time: next(), create: if c.create { @@ -268,16 +367,26 @@ impl Access { } else { FieldBit::from_bool_lit(c, false) }, - val: next(), + val: Self::val_from_field_elements_trusted(val_sort, &mut next), write: FieldBit::from_trusted_field(c, next()), active: FieldBit::from_trusted_field(c, next()), }, } } - fn universal_hash(&self, c: &AccessCfg, hasher: &hash::UniversalHasher) -> Term { - assert_eq!(hasher.len(), c.len()); - hasher.hash(self.to_field_elems(c, Order::Hash)) + fn universal_hash( + &self, + c: &AccessCfg, + val_sort: &Sort, + hasher: &hash::UniversalHasher, + ) -> (Term, Term) { + assert_eq!(hasher.len(), c.len(val_sort)); + let mut val_elems = Vec::new(); + Self::val_to_field_elements(&self.val, c, &mut val_elems); + ( + hasher.hash(self.to_field_elems(c, Order::Hash)), + hasher.hash(val_elems), + ) } fn to_field_tuple(&self, c: &AccessCfg) -> Term { @@ -287,16 +396,17 @@ impl Access { fn declare_trusted( c: &AccessCfg, mut declare_var: impl FnMut(&str, Term) -> Term, + val_sort: &Sort, value_tuple: Term, ) -> Self { let mut declare_field = |name: &str, idx: usize| declare_var(name, term![Op::Field(idx); value_tuple.clone()]); - let elems = Self::field_names(c, Order::Sort) + let elems = Self::field_names(c, val_sort, Order::Sort) .iter() .enumerate() .map(|(idx, name)| declare_field(name, idx)) .collect(); - Self::from_field_elems_trusted(elems, c, Order::Sort) + Self::from_field_elems_trusted(elems, val_sort, c, Order::Sort) } } @@ -318,8 +428,10 @@ pub struct Ram { boundary_conditions: BoundaryConditions, /// The unique id of this RAM id: usize, - /// The sort for times, indices, and values. + /// The sort for times and indices. sort: Sort, + /// The sort for values. + val_sort: Sort, /// The size size: usize, /// The list of accesses (in access order) @@ -332,11 +444,25 @@ pub struct Ram { cfg: AccessCfg, } +#[allow(dead_code)] +/// Are terms of sort `s` hashable using a UHF keyed by field type `f`. +fn hashable(s: &Sort, f: &FieldT) -> bool { + match s { + Sort::Field(ff) => f == ff, + Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)), + Sort::BitVector(_) => true, + Sort::Bool => true, + Sort::Array(_k, v, size) => *size < 20 && hashable(v, f), + _ => false, + } +} + impl Ram { fn new( id: usize, size: usize, cfg: AccessCfg, + val_sort: Sort, boundary_conditions: BoundaryConditions, ) -> Self { assert!(!matches!( @@ -347,6 +473,7 @@ impl Ram { boundary_conditions, id, sort: Sort::Field(cfg.field.clone()), + val_sort, cfg, accesses: Default::default(), size, @@ -369,6 +496,21 @@ impl Ram { } } } + #[track_caller] + #[allow(unused_variables)] + /// Assert that `other` is hashable using the field of `self`. + fn assert_hashable(&self, other: &Term) { + #[cfg(debug_assertions)] + { + let s = check(other); + if !hashable(&s, &self.cfg.field) { + panic!( + "RAM field of sort {} is not hashable with field {}", + s, self.cfg.field + ); + } + } + } fn next_time_term(&mut self) -> Term { let t = self.sort.nth_elem(self.next_time); if !self.end_of_time { @@ -379,12 +521,12 @@ impl Ram { fn new_read(&mut self, idx: Term, computation: &mut Computation, read_value: Term) -> Term { let val_name = format!("__ram{}_read_v{}", self.id, self.accesses.len()); self.assert_field(&idx); - self.assert_field(&read_value); + self.assert_hashable(&read_value); debug_assert!(!self.end_of_time); let var = computation.new_var( &val_name, - self.sort.clone(), + self.val_sort.clone(), Some(crate::ir::proof::PROVER_ID), Some(read_value), ); @@ -396,7 +538,7 @@ impl Ram { } fn new_final_read(&mut self, idx: Term, val: Term) { self.assert_field(&idx); - self.assert_field(&val); + self.assert_hashable(&val); self.end_of_time = true; let time = self.next_time_term(); trace!( @@ -411,7 +553,7 @@ impl Ram { fn new_write(&mut self, idx: Term, val: Term, guard: Term) { debug_assert!(!self.end_of_time); self.assert_field(&idx); - self.assert_field(&val); + self.assert_hashable(&val); debug_assert_eq!(&check(&guard), &Sort::Bool); let time = self.next_time_term(); trace!( @@ -426,7 +568,7 @@ impl Ram { } fn new_init(&mut self, idx: Term, val: Term) { self.assert_field(&idx); - self.assert_field(&val); + self.assert_hashable(&val); self.end_of_time = true; trace!("init: ops: idx {}, val {}", idx.op(), val.op()); self.accesses diff --git a/src/ir/opt/mem/ram/checker.rs b/src/ir/opt/mem/ram/checker.rs index 5dcaabff..501db174 100644 --- a/src/ir/opt/mem/ram/checker.rs +++ b/src/ir/opt/mem/ram/checker.rs @@ -25,6 +25,7 @@ pub fn check_ram(c: &mut Computation, ram: Ram) { let id = ram.id; let ns = Namespace::new().subspace(&format!("ram{id}")); let f_s = Sort::Field(f.clone()); + let v_s = ram.val_sort.clone(); let mut new_var = |name: &str, val: Term| c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val)); @@ -33,9 +34,16 @@ pub fn check_ram(c: &mut Computation, ram: Ram) { let sorted_accesses = if ram.cfg.waksman { let mut new_bit_var = |name: &str, val: Term| c.new_var(&ns.fqn(name), Sort::Bool, PROVER_VIS, Some(val)); - permutation::waksman(&ram.accesses, &ram.cfg, &mut new_bit_var) + permutation::waksman(&ram.accesses, &ram.cfg, &v_s, &mut new_bit_var) } else { - permutation::msh(&ram.accesses, &ns, &ram.cfg, &mut new_var, &mut assertions) + permutation::msh( + &ram.accesses, + &ns, + &ram.cfg, + &mut new_var, + &v_s, + &mut assertions, + ) }; // (2) check the sorted transcript @@ -64,22 +72,21 @@ pub fn check_ram(c: &mut Computation, ram: Ram) { } let mut deltas = Vec::new(); + // To: check some condition on the start? for j in 0..(n - 1) { // previous entry let i = &accs[j].idx; let t = &accs[j].time; - let v = &accs[j].val; + let v = accs[j].val_hash.as_ref().expect("missing value hash"); // this entry let i_n = &accs[j + 1].idx; let t_n = &accs[j + 1].time; - let v_n = &accs[j + 1].val; + let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash"); let c_n = &accs[j + 1].create; let w_n = &accs[j + 1].write; let v_p = if only_init { v.clone() - } else if j == 0 { - default.clone() } else { term![ITE; c_n.b.clone(), default.clone(), v.clone()] }; diff --git a/src/ir/opt/mem/ram/checker/permutation.rs b/src/ir/opt/mem/ram/checker/permutation.rs index 47a22d4b..b88fbdf4 100644 --- a/src/ir/opt/mem/ram/checker/permutation.rs +++ b/src/ir/opt/mem/ram/checker/permutation.rs @@ -8,6 +8,7 @@ use std::collections::VecDeque; pub(super) fn waksman( accesses: &VecDeque, cfg: &AccessCfg, + val_sort: &Sort, new_var: &mut impl FnMut(&str, Term) -> Term, ) -> Vec { let f = &cfg.field; @@ -37,7 +38,14 @@ pub(super) fn waksman( let elems = (0..len) .map(|idx| term![Op::Field(idx); v.clone()]) .collect(); - Access::from_field_elems_trusted(elems, cfg, Order::Sort) + let mut access = Access::from_field_elems_trusted(elems, val_sort, cfg, Order::Sort); + assert!( + check(&access.val).is_scalar(), + "Waksman only supports scalar values; got {}", + check(&access.val) + ); + access.val_hash = Some(super::scalar_to_field(&access.val, cfg)); + access }) .collect(); sorted_accesses @@ -57,6 +65,7 @@ pub(super) fn msh( ns: &Namespace, cfg: &AccessCfg, new_var: &mut impl FnMut(&str, Term) -> Term, + val_sort: &Sort, assertions: &mut Vec, ) -> Vec { let f = &cfg.field; @@ -66,13 +75,14 @@ pub(super) fn msh( let sorted_field_tuple_values: Vec = unmake_array( term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())], ); - let sorted_accesses: Vec = sorted_field_tuple_values + let mut sorted_accesses: Vec = sorted_field_tuple_values .into_iter() .enumerate() .map(|(i, v)| { Access::declare_trusted( cfg, |name: &str, term: Term| new_var(&format!("sort_a{i}_{name}"), term), + val_sort, v, ) }) @@ -81,20 +91,23 @@ pub(super) fn msh( .into_iter() .chain(sorted_accesses.iter().map(|a| a.to_field_tuple(cfg))) .collect(); - let uhf = UniversalHasher::new(ns.fqn("uhf_key"), f, uhf_inputs.clone(), cfg.len()); + let uhf = UniversalHasher::new(ns.fqn("uhf_key"), f, uhf_inputs.clone(), cfg.len(val_sort)); let msh = MsHasher::new(ns.fqn("ms_hash_key"), f, uhf_inputs); // (2) permutation argument let univ_hashes_unsorted: Vec = accesses .iter() - .map(|a| a.universal_hash(cfg, &uhf)) + .map(|a| a.universal_hash(cfg, val_sort, &uhf).0) .collect(); - let univ_hashes_sorted: Vec = sorted_accesses + let (univ_hashes_sorted, val_hashes_sorted): (Vec, Vec) = sorted_accesses .iter() - .map(|a| a.universal_hash(cfg, &uhf)) - .collect(); + .map(|a| a.universal_hash(cfg, val_sort, &uhf)) + .unzip(); let ms_hash_passes = term![EQ; msh.hash(univ_hashes_unsorted), msh.hash(univ_hashes_sorted)]; assertions.push(ms_hash_passes); + for (access, hash) in sorted_accesses.iter_mut().zip(val_hashes_sorted) { + access.val_hash = Some(hash); + } sorted_accesses } diff --git a/src/ir/opt/mem/ram/persistent.rs b/src/ir/opt/mem/ram/persistent.rs index ddc32249..e1cb4d0b 100644 --- a/src/ir/opt/mem/ram/persistent.rs +++ b/src/ir/opt/mem/ram/persistent.rs @@ -55,7 +55,13 @@ pub fn persistent_to_ram(c: &mut Computation, cfg: &AccessCfg) -> Vec { c.metadata.add_commitment(final_names); let boundary_conditions = BoundaryConditions::Persistent(terms, final_terms); - let ram = Ram::new(i, size, cfg.clone(), boundary_conditions); + let ram = Ram::new( + i, + size, + cfg.clone(), + Sort::Field(cfg.field.clone()), + boundary_conditions, + ); term_rams.insert(init_term, i); rams.push(ram); diff --git a/src/ir/opt/mem/ram/volatile.rs b/src/ir/opt/mem/ram/volatile.rs index b915390b..690c69a8 100644 --- a/src/ir/opt/mem/ram/volatile.rs +++ b/src/ir/opt/mem/ram/volatile.rs @@ -34,12 +34,24 @@ struct ArrayGraph { ram_terms: TermSet, } +/// Are terms of sort `s` hashable using a UHF keyed by field type `f`. +fn hashable(s: &Sort, f: &FieldT) -> bool { + match s { + Sort::Field(ff) => f == ff, + Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)), + Sort::BitVector(_) => true, + Sort::Bool => true, + Sort::Array(_k, v, size) => *size < 20 && hashable(v, f), + _ => false, + } +} + /// Does this array have a sort compatible with our RAM machinery? fn right_sort(t: &Term, f: &FieldT) -> bool { let s = check(t); if let Sort::Array(k, v, _) = &s { - if let (Sort::Field(k), Sort::Field(v)) = (&**k, &**v) { - v == f && k == f + if let Sort::Field(k) = &**k { + k == f && hashable(v, f) } else { false } @@ -168,9 +180,9 @@ impl Extactor { // create a default RAM from `t`'s sort. let id = self.rams.len(); let t_sort = check(t); - let (key_sort, _, size) = t_sort.as_array(); + let (key_sort, val_sort, size) = t_sort.as_array(); let def = BoundaryConditions::Default(key_sort.default_term()); - let mut ram = Ram::new(id, size, self.cfg.clone(), def); + let mut ram = Ram::new(id, size, self.cfg.clone(), val_sort.clone(), def); // update with details specific to `t`. match &t.op() { diff --git a/src/ir/opt/tuple.rs b/src/ir/opt/tuple.rs index 72de420e..b0c4935b 100644 --- a/src/ir/opt/tuple.rs +++ b/src/ir/opt/tuple.rs @@ -224,9 +224,13 @@ fn untuple_value(v: &Value) -> Value { } } +fn find_tuple_term(t: Term) -> Option { + PostOrderIter::new(t).find(|c| matches!(check(c), Sort::Tuple(..))) +} + #[allow(dead_code)] fn tuple_free(t: Term) -> bool { - PostOrderIter::new(t).all(|c| !matches!(check(&c), Sort::Tuple(..))) + find_tuple_term(t).is_none() } /// Run the tuple elimination pass. @@ -301,6 +305,8 @@ pub fn eliminate_tuples(cs: &mut Computation) { .collect(); #[cfg(debug_assertions)] for o in &cs.outputs { - assert!(tuple_free(o.clone())); + if let Some(t) = find_tuple_term(o.clone()) { + panic!("Tuple term {}", t) + } } } diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 6fdf4e42..c96d4153 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -2102,7 +2102,7 @@ impl Computation { ); self.metadata.new_input(name.to_owned(), party, s.clone()); if let Some(p) = precompute { - assert_eq!(&s, &check(&p)); + assert_eq!(&s, &check(&p), "precompute {} doesn't match sort {}", p, s); self.precomputes.add_output(name.to_owned(), p); } leaf_term(Op::Var(name.to_owned(), s))