RAM for non-scalar values (#174)

It is very naive. It assumes that any top-level array should be represented as a RAM, and that all internal structure should be unfolded.
This commit is contained in:
Alex Ozdemir
2023-10-17 22:04:38 -07:00
committed by GitHub
parent 4c5dafee95
commit 805a7f424f
22 changed files with 475 additions and 81 deletions

View File

@@ -0,0 +1,18 @@
const u32 LEN = 2
const u32 LEN2 = 100
const u32 ACCESSES = 37
const u32 P_ = 8
struct Pt {
field[P_] x
field[P_] x2
}
const Pt [LEN][LEN2] array = [[Pt {x: [0; P_], x2: [0; P_]}; LEN2], ...[[Pt {x: [100; P_], x2: [100; P_]}; LEN2] ; LEN-1]] // 638887 when LEN = 8190 // 63949 when LEN = 819
def main(private field[ACCESSES][2] idx) -> field:
field sum = 0
for u32 i in 0..ACCESSES do
field[2] access = idx[i]
sum = sum + array[access[1]][access[0]].x[0]
endfor
return sum

View File

@@ -0,0 +1,18 @@
const u32 LEN = 6
const u32 ACCESSES = 3
struct Pt {
field x
field y
field z
}
const Pt [LEN] array = [Pt {x: 4, y: 5, z: 6}, ...[Pt {x: 0, y: 1, z: 2}; LEN - 1]]
def main(private field[ACCESSES] idx) -> field:
field prod = 1
for u32 i in 0..ACCESSES do
field access = idx[i]
Pt pt = array[access]
prod = prod * pt.x * pt.y * pt.z
endfor
return prod

View File

@@ -0,0 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(idx.0 #f0)
(idx.1 #f1)
(idx.2 #f2)
) false ; ignored
))

View File

@@ -0,0 +1,5 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(return #f0)
) false ; ignored
))

View File

@@ -0,0 +1,21 @@
const u32 LEN = 4
const u32 INNER_LEN = 2
const u32 ACCESSES = 2
struct Pt {
field[INNER_LEN] x
field[INNER_LEN] y
}
const Pt [LEN] array = [Pt {x: [0; INNER_LEN], y: [5; INNER_LEN]}, ...[Pt {x: [1; INNER_LEN], y: [2; INNER_LEN]}; LEN - 1]]
def main(private field[ACCESSES] idx) -> field:
field prod = 1
for u32 i in 0..ACCESSES do
field access = idx[i]
Pt pt = array[access]
for u32 j in 0..INNER_LEN do
prod = prod * pt.x[j] * pt.y[j]
endfor
endfor
return prod

View File

@@ -0,0 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(idx.0 #f0)
(idx.1 #f1)
) false ; ignored
))

View File

@@ -0,0 +1,6 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(return #f0)
) false ; ignored
))

View File

@@ -0,0 +1,22 @@
const u32 LEN = 256
const u32 INNER_LEN = 8
const u32 ACCESSES = 10
struct Pt {
field[INNER_LEN] x
field[INNER_LEN] y
}
const Pt [LEN] array = [Pt {x: [0; INNER_LEN], y: [5; INNER_LEN]}, ...[Pt {x: [1; INNER_LEN], y: [2; INNER_LEN]}; LEN - 1]]
def main(private field[ACCESSES] idx) -> field:
field prod = 1
for u32 i in 0..ACCESSES do
field access = idx[i]
Pt pt = array[access]
for u32 j in 0..INNER_LEN do
prod = prod * pt.x[j] * pt.y[j]
endfor
endfor
return prod

View File

@@ -1,17 +1,13 @@
const u32 LEN = 4
const u32 ACCESSES = 2
struct Pt {
field x
field y
}
const Pt [LEN] array = [Pt {x: 0, y:0}, ...[Pt {x: 100, y: 0} ; LEN-1]]
const field[LEN] array = [0, ...[100; LEN-1]]
def main(private field[ACCESSES] y) -> field:
field result = 0
for u32 i in 0..ACCESSES do
assert(array[y[i]].x == 0)
assert(array[y[i]] == 0)
endfor
return result

View File

@@ -267,16 +267,20 @@ fn main() {
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::ParseCondStores);
// Tuples must be eliminated before oblivious array elim
opts.push(Opt::Tuple);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Tuple);
opts.push(Opt::Obliv);
// The obliv elim pass produces more tuples, that must be eliminated
opts.push(Opt::Tuple);
if options.circ.ram.enabled {
// Waksman can only route scalars, so tuple first!
if options.circ.ram.permutation == circ_opt::PermutationStrategy::Waksman {
opts.push(Opt::Tuple);
}
opts.push(Opt::PersistentRam);
opts.push(Opt::VolatileRam);
opts.push(Opt::SkolemizeChallenges);
opts.push(Opt::ScalarizeVars);
opts.push(Opt::ConstantFold(Box::new([])));
opts.push(Opt::Obliv);
}
opts.push(Opt::LinearScan);
// The linear scan pass produces more tuples, that must be eliminated

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env zsh
set -ex
function usage {
echo "Usage: $0 COMPILER_COMMAND TEMPLATE PATTERN REPLACEMENTS..."
exit 2
}
compiler_command=($(eval echo $1))
template_file=$2
pattern=$3
replacements=(${@:4})
[[ ! -z $compiler_command ]] || (echo "Empty compiler command" && usage)
if [[ ! -a $template_file ]]
then
for arg in $compiler_command
do
if [[ $arg =~ .*.zok ]]
then
echo "template $arg"
template_file=$arg
fi
done
fi
[[ -a $template_file ]] || (echo "No file at $template_file" && usage)
[[ ! -z $pattern ]] || (echo "Empty pattern" && usage)
[[ ! -z $replacements ]] || (echo "Empty replacements" && usage)
echo $replacements
for replacement in $replacements
do
t=$(mktemp compiler_asymptotics_XXXXXXXX.zok)
cat $template_file | sed "s/$pattern/$replacement/g" > $t
instantiated_command=$(echo $compiler_command | sed "s/$template_file/$t/")
echo $instantiated_command
rm $t
done

View File

@@ -25,7 +25,10 @@ function ram_test {
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split"
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/volatile.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/volatile_struct.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/arr_of_str.zok mirage ""
ram_test ./examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok mirage ""

21
scripts/test_c_r1cs.zsh Executable file
View File

@@ -0,0 +1,21 @@
#!/usr/bin/env zsh
set -ex
# cargo build --release --features lp,r1cs,smt,zok --example circ
MODE=debug # release or debug
BIN=./target/$MODE/examples/circ
ZK_BIN=./target/$MODE/examples/zk
# Test prove workflow, given an example name
function c_pf_test {
proof_impl=groth16
ex_name=$1
$BIN examples/C/r1cs/$ex_name.c r1cs --action setup --proof-impl $proof_impl
$ZK_BIN --inputs examples/C/r1cs/$ex_name.c.pin --action prove --proof-impl $proof_impl
$ZK_BIN --inputs examples/C/r1cs/$ex_name.c.vin --action verify --proof-impl $proof_impl
rm -rf P V pi
}
c_pf_test add

View File

@@ -17,6 +17,12 @@ fn arr_val_to_tup(v: &Value) -> Value {
}
vec
}),
Value::Tuple(vs) => Value::Tuple(
vs.iter()
.map(arr_val_to_tup)
.collect::<Vec<_>>()
.into_boxed_slice(),
),
v => v.clone(),
}
}
@@ -29,7 +35,7 @@ impl RewritePass for Linearizer {
rewritten_children: F,
) -> Option<Term> {
match &orig.op() {
Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))),
Op::Const(v) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))),
Op::Var(name, Sort::Array(..)) => {
let precomp = extras::array_to_tuple(orig);
let new_name = format!("{name}.tup");

View File

@@ -13,11 +13,14 @@
//!
//! So, essentially, what's going on is that T maps each term t to an (approximate) analysis of t
//! that indicates which accesses can be perfectly resolved.
//!
//! We could make the analysis more precise (and/or efficient) with a better data structure for
//! tracking information about value locations.
use crate::ir::term::extras::as_uint_constant;
use crate::ir::term::*;
use log::{debug, trace};
use log::trace;
#[derive(Default)]
struct OblivRewriter {
@@ -30,6 +33,7 @@ fn suitable_const(t: &Term) -> bool {
}
impl OblivRewriter {
/// Get, prefering tuple if possible.
fn get_t(&self, t: &Term) -> &Term {
self.tups.get(t).unwrap_or(self.terms.get(t).unwrap())
}
@@ -57,7 +61,7 @@ impl OblivRewriter {
(
if let Some(aa) = self.tups.get(a) {
if suitable_const(i) {
debug!("simplify store {}", i);
trace!("simplify store {}", i);
Some(term![Op::Update(get_const(i)); aa.clone(), self.get_t(v).clone()])
} else {
None
@@ -73,7 +77,7 @@ impl OblivRewriter {
let i = &t.cs()[1];
if let Some(aa) = self.tups.get(a) {
if suitable_const(i) {
debug!("simplify select {}", i);
trace!("simplify select {}", i);
let tt = term![Op::Field(get_const(i)); aa.clone()];
(
Some(tt.clone()),
@@ -115,7 +119,37 @@ impl OblivRewriter {
},
)
}
Op::Tuple => panic!("Tuple in obliv"),
Op::Tuple => (
if t.cs().iter().all(|c| self.tups.contains_key(c)) {
Some(term(
Op::Tuple,
t.cs()
.iter()
.map(|c| self.tups.get(c).unwrap().clone())
.collect(),
))
} else {
None
},
None,
),
Op::Field(i) => (
if t.cs().iter().all(|c| self.tups.contains_key(c)) {
Some(term_c![Op::Field(*i); self.get_t(&t.cs()[0])])
} else {
None
},
None,
),
Op::Update(i) => (
if t.cs().iter().all(|c| self.tups.contains_key(c)) {
Some(term_c![Op::Update(*i); self.get_t(&t.cs()[0]), self.get_t(&t.cs()[1])])
} else {
None
},
None,
),
//Op::Tuple => panic!("Tuple in obliv"),
_ => (None, None),
};
if let Some(tup) = tup_opt {

View File

@@ -26,6 +26,8 @@ pub mod volatile;
struct Access {
/// The value read or (conditionally) written.
pub val: Term,
/// A (field) hash of the value read or (conditionally) written.
pub val_hash: Option<Term>,
/// The index/address.
pub idx: Term,
/// The time of this access.
@@ -108,19 +110,40 @@ impl AccessCfg {
true,
)
}
fn len(&self) -> usize {
if self.create {
6
} else {
5
fn val_sort_len(s: &Sort) -> usize {
match s {
Sort::Tuple(t) => t.iter().map(Self::val_sort_len).sum(),
Sort::Array(_, v, size) => *size * Self::val_sort_len(v),
_ => 1,
}
}
fn len(&self, s: &Sort) -> usize {
(if self.create { 5 } else { 4 }) + Self::val_sort_len(s)
}
fn bool2pf(&self, t: Term) -> Term {
term![Op::Ite; t, self.one.clone(), self.zero.clone()]
}
fn pf_neg(&self, t: Term) -> Term {
term![PF_ADD; self.one.clone(), term![PF_NEG; t]]
}
fn pf_lit(&self, i: usize) -> Term {
pf_lit(self.field.new_v(i))
}
}
fn scalar_to_field(scalar: &Term, c: &AccessCfg) -> Term {
match check(scalar) {
Sort::Field(f) => {
if f == c.field {
scalar.clone()
} else {
panic!("Cannot convert scalar of field {} to field {}", f, c.field)
}
}
Sort::Bool => c.bool2pf(scalar.clone()),
Sort::BitVector(_) => term![Op::UbvToPf(c.field.clone()); scalar.clone()],
s => panic!("non-scalar sort {}", s),
}
}
/// A bit encoded in the field.
@@ -163,6 +186,7 @@ impl Access {
fn new_read(f: &AccessCfg, idx: Term, val: Term, time: Term) -> Self {
Self {
val,
val_hash: None,
idx,
time,
write: FieldBit::from_bool_lit(f, false),
@@ -173,6 +197,7 @@ impl Access {
fn new_write(f: &AccessCfg, idx: Term, val: Term, active: Term, time: Term) -> Self {
Self {
val,
val_hash: None,
idx,
time,
write: FieldBit::from_bool_lit(f, true),
@@ -183,6 +208,7 @@ impl Access {
fn new_init(f: &AccessCfg, idx: Term, val: Term) -> Self {
Self {
val,
val_hash: None,
idx,
time: f.zero.clone(),
write: FieldBit::from_bool_lit(f, true),
@@ -191,56 +217,127 @@ impl Access {
}
}
fn field_names(c: &AccessCfg, order: Order) -> &'static [&'static str] {
fn field_names(c: &AccessCfg, sort: &Sort, order: Order) -> Vec<String> {
let mut out = Vec::new();
match order {
Order::Hash => {
Self::sort_subnames(sort, "v", &mut out);
out.push("i".into());
out.push("t".into());
out.push("w".into());
out.push("a".into());
if c.create {
&["v", "i", "t", "w", "a", "c"]
} else {
&["v", "i", "t", "w", "a"]
out.push("c".into());
}
}
// dead code, but for clarity...
Order::Sort => {
out.push("i".into());
out.push("t".into());
if c.create {
&["i", "t", "c", "v", "w", "a"]
} else {
&["i", "t", "v", "w", "a"]
out.push("c".into());
}
Self::sort_subnames(sort, "v", &mut out);
out.push("w".into());
out.push("a".into());
}
}
out
}
fn sort_subnames(sort: &Sort, prefix: &str, out: &mut Vec<String>) {
match sort {
Sort::Field(_) | Sort::Bool | Sort::BitVector(_) => out.push(prefix.into()),
Sort::Tuple(ss) => {
for (i, s) in ss.iter().enumerate() {
Self::sort_subnames(s, &format!("{}_{}", prefix, i), out);
}
}
Sort::Array(_, v, size) => {
for i in 0..*size {
Self::sort_subnames(v, &format!("{}_{}", prefix, i), out);
}
}
_ => unreachable!(),
}
}
fn val_to_field_elements(val: &Term, c: &AccessCfg, out: &mut Vec<Term>) {
match check(val) {
Sort::Field(_) | Sort::Bool | Sort::BitVector(_) => out.push(scalar_to_field(val, c)),
Sort::Tuple(ss) => {
for i in 0..ss.len() {
Self::val_to_field_elements(&term![Op::Field(i); val.clone()], c, out);
}
}
Sort::Array(_, _, size) => {
for i in 0..size {
Self::val_to_field_elements(
&term![Op::Select; val.clone(), c.pf_lit(i)],
c,
out,
);
}
}
_ => unreachable!(),
}
}
fn val_from_field_elements_trusted(sort: &Sort, next: &mut impl FnMut() -> Term) -> Term {
match sort {
Sort::Field(_) => next().clone(),
Sort::Bool => term![Op::PfToBoolTrusted; next().clone()],
Sort::BitVector(w) => term![Op::PfToBv(*w); next().clone()],
Sort::Tuple(ss) => term(
Op::Tuple,
ss.iter()
.map(|s| Self::val_from_field_elements_trusted(s, next))
.collect(),
),
Sort::Array(k, v, size) => term(
Op::Array(*k.clone(), *v.clone()),
(0..*size)
.map(|_| Self::val_from_field_elements_trusted(v, next))
.collect(),
),
_ => unreachable!(),
}
}
fn to_field_elems(&self, c: &AccessCfg, order: Order) -> Vec<Term> {
let mut out = Vec::new();
match order {
Order::Hash => {
let mut out = vec![
self.val.clone(),
self.idx.clone(),
self.time.clone(),
self.write.f.clone(),
self.active.f.clone(),
];
if c.create {
out.push(self.create.f.clone())
}
out
}
Order::Sort => {
let mut out = vec![self.idx.clone(), self.time.clone()];
if c.create {
out.push(self.create.f.clone())
}
out.push(self.val.clone());
Self::val_to_field_elements(&self.val, c, &mut out);
out.push(self.idx.clone());
out.push(self.time.clone());
out.push(self.write.f.clone());
out.push(self.active.f.clone());
if c.create {
out.push(self.create.f.clone())
}
}
Order::Sort => {
out.push(self.idx.clone());
out.push(self.time.clone());
if c.create {
out.push(self.create.f.clone())
}
Self::val_to_field_elements(&self.val, c, &mut out);
out.push(self.write.f.clone());
out.push(self.active.f.clone());
out
}
}
out
}
fn from_field_elems_trusted(elems: Vec<Term>, c: &AccessCfg, order: Order) -> Self {
debug_assert_eq!(elems.len(), c.len());
fn from_field_elems_trusted(
elems: Vec<Term>,
val_sort: &Sort,
c: &AccessCfg,
order: Order,
) -> Self {
debug_assert_eq!(elems.len(), c.len(val_sort));
let mut elems = elems.into_iter();
let mut next = || {
let t = elems.next().unwrap();
@@ -249,7 +346,8 @@ impl Access {
};
match order {
Order::Hash => Self {
val: next(),
val: Self::val_from_field_elements_trusted(val_sort, &mut next),
val_hash: None,
idx: next(),
time: next(),
write: FieldBit::from_trusted_field(c, next()),
@@ -261,6 +359,7 @@ impl Access {
},
},
Order::Sort => Self {
val_hash: None,
idx: next(),
time: next(),
create: if c.create {
@@ -268,16 +367,26 @@ impl Access {
} else {
FieldBit::from_bool_lit(c, false)
},
val: next(),
val: Self::val_from_field_elements_trusted(val_sort, &mut next),
write: FieldBit::from_trusted_field(c, next()),
active: FieldBit::from_trusted_field(c, next()),
},
}
}
fn universal_hash(&self, c: &AccessCfg, hasher: &hash::UniversalHasher) -> Term {
assert_eq!(hasher.len(), c.len());
hasher.hash(self.to_field_elems(c, Order::Hash))
fn universal_hash(
&self,
c: &AccessCfg,
val_sort: &Sort,
hasher: &hash::UniversalHasher,
) -> (Term, Term) {
assert_eq!(hasher.len(), c.len(val_sort));
let mut val_elems = Vec::new();
Self::val_to_field_elements(&self.val, c, &mut val_elems);
(
hasher.hash(self.to_field_elems(c, Order::Hash)),
hasher.hash(val_elems),
)
}
fn to_field_tuple(&self, c: &AccessCfg) -> Term {
@@ -287,16 +396,17 @@ impl Access {
fn declare_trusted(
c: &AccessCfg,
mut declare_var: impl FnMut(&str, Term) -> Term,
val_sort: &Sort,
value_tuple: Term,
) -> Self {
let mut declare_field =
|name: &str, idx: usize| declare_var(name, term![Op::Field(idx); value_tuple.clone()]);
let elems = Self::field_names(c, Order::Sort)
let elems = Self::field_names(c, val_sort, Order::Sort)
.iter()
.enumerate()
.map(|(idx, name)| declare_field(name, idx))
.collect();
Self::from_field_elems_trusted(elems, c, Order::Sort)
Self::from_field_elems_trusted(elems, val_sort, c, Order::Sort)
}
}
@@ -318,8 +428,10 @@ pub struct Ram {
boundary_conditions: BoundaryConditions,
/// The unique id of this RAM
id: usize,
/// The sort for times, indices, and values.
/// The sort for times and indices.
sort: Sort,
/// The sort for values.
val_sort: Sort,
/// The size
size: usize,
/// The list of accesses (in access order)
@@ -332,11 +444,25 @@ pub struct Ram {
cfg: AccessCfg,
}
#[allow(dead_code)]
/// Are terms of sort `s` hashable using a UHF keyed by field type `f`.
fn hashable(s: &Sort, f: &FieldT) -> bool {
match s {
Sort::Field(ff) => f == ff,
Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)),
Sort::BitVector(_) => true,
Sort::Bool => true,
Sort::Array(_k, v, size) => *size < 20 && hashable(v, f),
_ => false,
}
}
impl Ram {
fn new(
id: usize,
size: usize,
cfg: AccessCfg,
val_sort: Sort,
boundary_conditions: BoundaryConditions,
) -> Self {
assert!(!matches!(
@@ -347,6 +473,7 @@ impl Ram {
boundary_conditions,
id,
sort: Sort::Field(cfg.field.clone()),
val_sort,
cfg,
accesses: Default::default(),
size,
@@ -369,6 +496,21 @@ impl Ram {
}
}
}
#[track_caller]
#[allow(unused_variables)]
/// Assert that `other` is hashable using the field of `self`.
fn assert_hashable(&self, other: &Term) {
#[cfg(debug_assertions)]
{
let s = check(other);
if !hashable(&s, &self.cfg.field) {
panic!(
"RAM field of sort {} is not hashable with field {}",
s, self.cfg.field
);
}
}
}
fn next_time_term(&mut self) -> Term {
let t = self.sort.nth_elem(self.next_time);
if !self.end_of_time {
@@ -379,12 +521,12 @@ impl Ram {
fn new_read(&mut self, idx: Term, computation: &mut Computation, read_value: Term) -> Term {
let val_name = format!("__ram{}_read_v{}", self.id, self.accesses.len());
self.assert_field(&idx);
self.assert_field(&read_value);
self.assert_hashable(&read_value);
debug_assert!(!self.end_of_time);
let var = computation.new_var(
&val_name,
self.sort.clone(),
self.val_sort.clone(),
Some(crate::ir::proof::PROVER_ID),
Some(read_value),
);
@@ -396,7 +538,7 @@ impl Ram {
}
fn new_final_read(&mut self, idx: Term, val: Term) {
self.assert_field(&idx);
self.assert_field(&val);
self.assert_hashable(&val);
self.end_of_time = true;
let time = self.next_time_term();
trace!(
@@ -411,7 +553,7 @@ impl Ram {
fn new_write(&mut self, idx: Term, val: Term, guard: Term) {
debug_assert!(!self.end_of_time);
self.assert_field(&idx);
self.assert_field(&val);
self.assert_hashable(&val);
debug_assert_eq!(&check(&guard), &Sort::Bool);
let time = self.next_time_term();
trace!(
@@ -426,7 +568,7 @@ impl Ram {
}
fn new_init(&mut self, idx: Term, val: Term) {
self.assert_field(&idx);
self.assert_field(&val);
self.assert_hashable(&val);
self.end_of_time = true;
trace!("init: ops: idx {}, val {}", idx.op(), val.op());
self.accesses

View File

@@ -25,6 +25,7 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
let id = ram.id;
let ns = Namespace::new().subspace(&format!("ram{id}"));
let f_s = Sort::Field(f.clone());
let v_s = ram.val_sort.clone();
let mut new_var =
|name: &str, val: Term| c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val));
@@ -33,9 +34,16 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
let sorted_accesses = if ram.cfg.waksman {
let mut new_bit_var =
|name: &str, val: Term| c.new_var(&ns.fqn(name), Sort::Bool, PROVER_VIS, Some(val));
permutation::waksman(&ram.accesses, &ram.cfg, &mut new_bit_var)
permutation::waksman(&ram.accesses, &ram.cfg, &v_s, &mut new_bit_var)
} else {
permutation::msh(&ram.accesses, &ns, &ram.cfg, &mut new_var, &mut assertions)
permutation::msh(
&ram.accesses,
&ns,
&ram.cfg,
&mut new_var,
&v_s,
&mut assertions,
)
};
// (2) check the sorted transcript
@@ -64,22 +72,21 @@ pub fn check_ram(c: &mut Computation, ram: Ram) {
}
let mut deltas = Vec::new();
// To: check some condition on the start?
for j in 0..(n - 1) {
// previous entry
let i = &accs[j].idx;
let t = &accs[j].time;
let v = &accs[j].val;
let v = accs[j].val_hash.as_ref().expect("missing value hash");
// this entry
let i_n = &accs[j + 1].idx;
let t_n = &accs[j + 1].time;
let v_n = &accs[j + 1].val;
let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash");
let c_n = &accs[j + 1].create;
let w_n = &accs[j + 1].write;
let v_p = if only_init {
v.clone()
} else if j == 0 {
default.clone()
} else {
term![ITE; c_n.b.clone(), default.clone(), v.clone()]
};

View File

@@ -8,6 +8,7 @@ use std::collections::VecDeque;
pub(super) fn waksman(
accesses: &VecDeque<Access>,
cfg: &AccessCfg,
val_sort: &Sort,
new_var: &mut impl FnMut(&str, Term) -> Term,
) -> Vec<Access> {
let f = &cfg.field;
@@ -37,7 +38,14 @@ pub(super) fn waksman(
let elems = (0..len)
.map(|idx| term![Op::Field(idx); v.clone()])
.collect();
Access::from_field_elems_trusted(elems, cfg, Order::Sort)
let mut access = Access::from_field_elems_trusted(elems, val_sort, cfg, Order::Sort);
assert!(
check(&access.val).is_scalar(),
"Waksman only supports scalar values; got {}",
check(&access.val)
);
access.val_hash = Some(super::scalar_to_field(&access.val, cfg));
access
})
.collect();
sorted_accesses
@@ -57,6 +65,7 @@ pub(super) fn msh(
ns: &Namespace,
cfg: &AccessCfg,
new_var: &mut impl FnMut(&str, Term) -> Term,
val_sort: &Sort,
assertions: &mut Vec<Term>,
) -> Vec<Access> {
let f = &cfg.field;
@@ -66,13 +75,14 @@ pub(super) fn msh(
let sorted_field_tuple_values: Vec<Term> = unmake_array(
term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())],
);
let sorted_accesses: Vec<Access> = sorted_field_tuple_values
let mut sorted_accesses: Vec<Access> = sorted_field_tuple_values
.into_iter()
.enumerate()
.map(|(i, v)| {
Access::declare_trusted(
cfg,
|name: &str, term: Term| new_var(&format!("sort_a{i}_{name}"), term),
val_sort,
v,
)
})
@@ -81,20 +91,23 @@ pub(super) fn msh(
.into_iter()
.chain(sorted_accesses.iter().map(|a| a.to_field_tuple(cfg)))
.collect();
let uhf = UniversalHasher::new(ns.fqn("uhf_key"), f, uhf_inputs.clone(), cfg.len());
let uhf = UniversalHasher::new(ns.fqn("uhf_key"), f, uhf_inputs.clone(), cfg.len(val_sort));
let msh = MsHasher::new(ns.fqn("ms_hash_key"), f, uhf_inputs);
// (2) permutation argument
let univ_hashes_unsorted: Vec<Term> = accesses
.iter()
.map(|a| a.universal_hash(cfg, &uhf))
.map(|a| a.universal_hash(cfg, val_sort, &uhf).0)
.collect();
let univ_hashes_sorted: Vec<Term> = sorted_accesses
let (univ_hashes_sorted, val_hashes_sorted): (Vec<Term>, Vec<Term>) = sorted_accesses
.iter()
.map(|a| a.universal_hash(cfg, &uhf))
.collect();
.map(|a| a.universal_hash(cfg, val_sort, &uhf))
.unzip();
let ms_hash_passes = term![EQ; msh.hash(univ_hashes_unsorted), msh.hash(univ_hashes_sorted)];
assertions.push(ms_hash_passes);
for (access, hash) in sorted_accesses.iter_mut().zip(val_hashes_sorted) {
access.val_hash = Some(hash);
}
sorted_accesses
}

View File

@@ -55,7 +55,13 @@ pub fn persistent_to_ram(c: &mut Computation, cfg: &AccessCfg) -> Vec<Ram> {
c.metadata.add_commitment(final_names);
let boundary_conditions = BoundaryConditions::Persistent(terms, final_terms);
let ram = Ram::new(i, size, cfg.clone(), boundary_conditions);
let ram = Ram::new(
i,
size,
cfg.clone(),
Sort::Field(cfg.field.clone()),
boundary_conditions,
);
term_rams.insert(init_term, i);
rams.push(ram);

View File

@@ -34,12 +34,24 @@ struct ArrayGraph {
ram_terms: TermSet,
}
/// Are terms of sort `s` hashable using a UHF keyed by field type `f`.
fn hashable(s: &Sort, f: &FieldT) -> bool {
match s {
Sort::Field(ff) => f == ff,
Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)),
Sort::BitVector(_) => true,
Sort::Bool => true,
Sort::Array(_k, v, size) => *size < 20 && hashable(v, f),
_ => false,
}
}
/// Does this array have a sort compatible with our RAM machinery?
fn right_sort(t: &Term, f: &FieldT) -> bool {
let s = check(t);
if let Sort::Array(k, v, _) = &s {
if let (Sort::Field(k), Sort::Field(v)) = (&**k, &**v) {
v == f && k == f
if let Sort::Field(k) = &**k {
k == f && hashable(v, f)
} else {
false
}
@@ -168,9 +180,9 @@ impl Extactor {
// create a default RAM from `t`'s sort.
let id = self.rams.len();
let t_sort = check(t);
let (key_sort, _, size) = t_sort.as_array();
let (key_sort, val_sort, size) = t_sort.as_array();
let def = BoundaryConditions::Default(key_sort.default_term());
let mut ram = Ram::new(id, size, self.cfg.clone(), def);
let mut ram = Ram::new(id, size, self.cfg.clone(), val_sort.clone(), def);
// update with details specific to `t`.
match &t.op() {

View File

@@ -224,9 +224,13 @@ fn untuple_value(v: &Value) -> Value {
}
}
fn find_tuple_term(t: Term) -> Option<Term> {
PostOrderIter::new(t).find(|c| matches!(check(c), Sort::Tuple(..)))
}
#[allow(dead_code)]
fn tuple_free(t: Term) -> bool {
PostOrderIter::new(t).all(|c| !matches!(check(&c), Sort::Tuple(..)))
find_tuple_term(t).is_none()
}
/// Run the tuple elimination pass.
@@ -301,6 +305,8 @@ pub fn eliminate_tuples(cs: &mut Computation) {
.collect();
#[cfg(debug_assertions)]
for o in &cs.outputs {
assert!(tuple_free(o.clone()));
if let Some(t) = find_tuple_term(o.clone()) {
panic!("Tuple term {}", t)
}
}
}

View File

@@ -2102,7 +2102,7 @@ impl Computation {
);
self.metadata.new_input(name.to_owned(), party, s.clone());
if let Some(p) = precompute {
assert_eq!(&s, &check(&p));
assert_eq!(&s, &check(&p), "precompute {} doesn't match sort {}", p, s);
self.precomputes.add_output(name.to_owned(), p);
}
leaf_term(Op::Var(name.to_owned(), s))