mirror of
https://github.com/circify/circ.git
synced 2026-01-09 13:48:02 -05:00
Improve RAM: oblivious & volatile (#170)
* Improve the oblivious RAM pass by killing the hack where we treat selects as arrays. * Fix a bug where the volatile RAM pass would not place selects before stores against the same array * Improve that volatile RAM pass by placing selects against the same array literal in the same RAM. Before, they would each end up in different RAMs, which sucks. This is especially bad for ROMs.
This commit is contained in:
@@ -3,228 +3,148 @@
|
||||
//! This module attempts to identify *oblivious* arrays: those that are only accessed at constant
|
||||
//! indices. These arrays can be replaced with tuples. Then, a tuple elimination pass can be run.
|
||||
//!
|
||||
//! It operates in two passes:
|
||||
//! It operates in a single IO (inputs->outputs) pass, that computes two maps:
|
||||
//!
|
||||
//! 1. determine which arrays are oblivious
|
||||
//! 2. replace oblivious arrays with tuples
|
||||
//!
|
||||
//!
|
||||
//! ## Pass 1: Identifying oblivious arrays
|
||||
//!
|
||||
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation
|
||||
//! system, performing the following inferences:
|
||||
//!
|
||||
//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious;
|
||||
//! * If `a[i]`, `a` and `a[i]` are equi-oblivious
|
||||
//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious
|
||||
//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious
|
||||
//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious
|
||||
//! * If `a=b`, then `a` and `b` are equi-oblivious
|
||||
//!
|
||||
//! This procedure is iterated to fixpoint.
|
||||
//!
|
||||
//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select
|
||||
//! terms as non-oblivious. This makes it easy to see which selects should be replaced later.
|
||||
//!
|
||||
//! ### Sharing & Constant Arrays
|
||||
//!
|
||||
//! This pass is effective given the somewhat naive assumption that array terms in the term graph
|
||||
//! can be separated into different "threads", which are not connected. Sometimes they are,
|
||||
//! especially by constant arrays.
|
||||
//!
|
||||
//! For example, consider code like this:
|
||||
//!
|
||||
//! ```ignore
|
||||
//! x = [0, 0, 0, 0]
|
||||
//! y = [0, 0, 0, 0]
|
||||
//! // oblivious modifications to x
|
||||
//! // non-oblivious modifications to y
|
||||
//! ```
|
||||
//!
|
||||
//! In this situation, we would hope that x and its derived arrays will be identified as
|
||||
//! "oblivious" while y will not.
|
||||
//!
|
||||
//! However, because of term sharing, the constant array [0,0,0,0] happens to be the root of both
|
||||
//! x's and y's store chains. If the constant array is `c`, then the first store to x might be
|
||||
//! `c[0\v1]` while the first store to y might be `c[i2\v2]`. The "store" rule for non-oblivious
|
||||
//! analysis would say that `c` is non-oblivious (b/c of the second store) and therefore the whole
|
||||
//! x store chain would b too...
|
||||
//!
|
||||
//! The problem isn't just with constants. If any non-oblivious stores branch off an otherwise
|
||||
//! oblivious store chain, the same thing happens.
|
||||
//!
|
||||
//! Since constants are a pervasive problem, we special-case them, omitting them from the analysis.
|
||||
//!
|
||||
//! We probably want a better idea of what this pass does (and how to handle arrays) at some
|
||||
//! point...
|
||||
//!
|
||||
//! ## Pass 2: Replacing oblivious arrays with term lists.
|
||||
//!
|
||||
//! In this pass, the goal is to
|
||||
//!
|
||||
//! * map array terms to tuple terms
|
||||
//! * map array selections to tuple field gets
|
||||
//!
|
||||
//! In both cases we look at the non-oblivious array/select set to see whether to do the
|
||||
//! replacement.
|
||||
//! * `R`: the rewrite map; keys map to values of the same sort. This is the canonical rewrite map.
|
||||
//! * `T`: the map from a term, to one whose sort has arrays replaced with tuples at the top of the sort tree
|
||||
//! * if some select has a constant index and is against an entry of T, then:
|
||||
//! * we add ((field i) T_ENTRY) to T for that select
|
||||
//! * if the above has scalar sort, we add it to R
|
||||
//!
|
||||
//! So, essentially, what's going on is that T maps each term t to an (approximate) analysis of t
|
||||
//! that indicates which accesses can be perfectly resolved.
|
||||
|
||||
use super::super::visit::*;
|
||||
use crate::ir::term::extras::as_uint_constant;
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::debug;
|
||||
use log::{debug, trace};
|
||||
|
||||
struct NonOblivComputer {
|
||||
not_obliv: TermSet,
|
||||
#[derive(Default)]
|
||||
struct OblivRewriter {
|
||||
tups: TermMap<Term>,
|
||||
terms: TermMap<Term>,
|
||||
}
|
||||
|
||||
impl NonOblivComputer {
|
||||
fn mark(&mut self, a: &Term) -> bool {
|
||||
if !a.is_const() && self.not_obliv.insert(a.clone()) {
|
||||
debug!("Not obliv: {}", a);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool {
|
||||
if !a.is_const() && !b.is_const() {
|
||||
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
|
||||
(false, true) => {
|
||||
self.not_obliv.insert(a.clone());
|
||||
true
|
||||
}
|
||||
(true, false) => {
|
||||
self.not_obliv.insert(b.clone());
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
not_obliv: TermSet::default(),
|
||||
}
|
||||
}
|
||||
fn suitable_const(t: &Term) -> bool {
|
||||
t.is_const() && matches!(check(t), Sort::BitVector(_) | Sort::Field(_) | Sort::Bool)
|
||||
}
|
||||
|
||||
impl ProgressAnalysisPass for NonOblivComputer {
|
||||
fn visit(&mut self, term: &Term) -> bool {
|
||||
match &term.op() {
|
||||
Op::Store | Op::CStore => {
|
||||
let a = &term.cs()[0];
|
||||
let i = &term.cs()[1];
|
||||
let v = &term.cs()[2];
|
||||
let mut progress = false;
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
if let Op::Const(_) = i.op() {
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
} else {
|
||||
progress = self.mark(a) || progress;
|
||||
progress = self.mark(term) || progress;
|
||||
}
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Array(..) => {
|
||||
let mut progress = false;
|
||||
if !term.cs().is_empty() {
|
||||
if let Sort::Array(..) = check(&term.cs()[0]) {
|
||||
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
|
||||
for i in 0..term.cs().len() - 1 {
|
||||
progress =
|
||||
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
|
||||
impl OblivRewriter {
|
||||
fn get_t(&self, t: &Term) -> &Term {
|
||||
self.tups.get(t).unwrap_or(self.terms.get(t).unwrap())
|
||||
}
|
||||
fn get(&self, t: &Term) -> &Term {
|
||||
self.terms.get(t).unwrap()
|
||||
}
|
||||
fn visit(&mut self, t: &Term) {
|
||||
let (tup_opt, term_opt) = match t.op() {
|
||||
Op::Const(v @ Value::Array(_)) => (Some(leaf_term(Op::Const(arr_val_to_tup(v)))), None),
|
||||
Op::Array(_k, _v) => (
|
||||
Some(term(
|
||||
Op::Tuple,
|
||||
t.cs().iter().map(|c| self.get_t(c)).cloned().collect(),
|
||||
)),
|
||||
None,
|
||||
),
|
||||
Op::Fill(_k, size) => (
|
||||
Some(term(Op::Tuple, vec![self.get_t(&t.cs()[0]).clone(); *size])),
|
||||
None,
|
||||
),
|
||||
Op::Store => {
|
||||
let a = &t.cs()[0];
|
||||
let i = &t.cs()[1];
|
||||
let v = &t.cs()[2];
|
||||
(
|
||||
if let Some(aa) = self.tups.get(a) {
|
||||
if suitable_const(i) {
|
||||
debug!("simplify store {}", i);
|
||||
Some(term![Op::Update(get_const(i)); aa.clone(), self.get_t(v).clone()])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
for i in (0..term.cs().len() - 1).rev() {
|
||||
progress =
|
||||
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
|
||||
}
|
||||
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
|
||||
}
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Fill(..) => {
|
||||
let v = &term.cs()[0];
|
||||
if let Sort::Array(..) = check(v) {
|
||||
self.bi_implicate(term, &term.cs()[0])
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
None
|
||||
},
|
||||
None,
|
||||
)
|
||||
}
|
||||
Op::Select => {
|
||||
// Even though the selected value may not have array sort, we still flag it as
|
||||
// non-oblivious so we know whether to replace it or not.
|
||||
let a = &term.cs()[0];
|
||||
let i = &term.cs()[1];
|
||||
let mut progress = false;
|
||||
if let Op::Const(_) = i.op() {
|
||||
// pass
|
||||
let a = &t.cs()[0];
|
||||
let i = &t.cs()[1];
|
||||
if let Some(aa) = self.tups.get(a) {
|
||||
if suitable_const(i) {
|
||||
debug!("simplify select {}", i);
|
||||
let tt = term![Op::Field(get_const(i)); aa.clone()];
|
||||
(
|
||||
Some(tt.clone()),
|
||||
if check(&tt).is_scalar() {
|
||||
Some(tt)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
}
|
||||
} else {
|
||||
progress = self.mark(a) || progress;
|
||||
progress = self.mark(term) || progress;
|
||||
}
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
progress
|
||||
}
|
||||
Op::Var(..) => {
|
||||
if let Sort::Array(..) = check(term) {
|
||||
self.mark(term)
|
||||
} else {
|
||||
false
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
let t = &term.cs()[1];
|
||||
let f = &term.cs()[2];
|
||||
if let Sort::Array(..) = check(t) {
|
||||
let mut progress = self.bi_implicate(term, t);
|
||||
progress = self.bi_implicate(t, f) || progress;
|
||||
progress = self.bi_implicate(term, f) || progress;
|
||||
progress
|
||||
} else {
|
||||
false
|
||||
}
|
||||
let cond = &t.cs()[0];
|
||||
let case_t = &t.cs()[1];
|
||||
let case_f = &t.cs()[2];
|
||||
(
|
||||
if let (Some(tt), Some(ff)) = (self.tups.get(case_t), self.tups.get(case_f)) {
|
||||
Some(term![Op::Ite; self.get(cond).clone(), tt.clone(), ff.clone()])
|
||||
} else {
|
||||
None
|
||||
},
|
||||
None,
|
||||
)
|
||||
}
|
||||
Op::Eq => {
|
||||
let a = &term.cs()[0];
|
||||
let b = &term.cs()[1];
|
||||
if let Sort::Array(..) = check(a) {
|
||||
self.bi_implicate(a, b)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
let a = &t.cs()[0];
|
||||
let b = &t.cs()[1];
|
||||
(
|
||||
None,
|
||||
if let (Some(aa), Some(bb)) = (self.tups.get(a), self.tups.get(b)) {
|
||||
Some(term![Op::Eq; aa.clone(), bb.clone()])
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
}
|
||||
Op::Tuple => {
|
||||
panic!("Tuple in obliv")
|
||||
}
|
||||
_ => false,
|
||||
Op::Tuple => panic!("Tuple in obliv"),
|
||||
_ => (None, None),
|
||||
};
|
||||
if let Some(tup) = tup_opt {
|
||||
trace!("Tuple rw: \n{}\nto\n{}", t, tup);
|
||||
self.tups.insert(t.clone(), tup);
|
||||
}
|
||||
let new_t = term_opt.unwrap_or_else(|| {
|
||||
term(
|
||||
t.op().clone(),
|
||||
t.cs().iter().map(|c| self.get(c)).cloned().collect(),
|
||||
)
|
||||
});
|
||||
|
||||
self.terms.insert(t.clone(), new_t);
|
||||
}
|
||||
}
|
||||
|
||||
struct Replacer {
|
||||
/// The maximum size of arrays that will be replaced.
|
||||
not_obliv: TermSet,
|
||||
}
|
||||
|
||||
impl Replacer {
|
||||
fn should_replace(&self, a: &Term) -> bool {
|
||||
!self.not_obliv.contains(a)
|
||||
/// Eliminate oblivious arrays. See module documentation.
|
||||
pub fn elim_obliv(c: &mut Computation) {
|
||||
let mut pass = OblivRewriter::default();
|
||||
for t in c.terms_postorder() {
|
||||
pass.visit(&t);
|
||||
}
|
||||
for o in &mut c.outputs {
|
||||
debug_assert!(check(o).is_scalar());
|
||||
*o = pass.get(o).clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn arr_val_to_tup(v: &Value) -> Value {
|
||||
match v {
|
||||
Value::Array(Array {
|
||||
@@ -240,13 +160,6 @@ fn arr_val_to_tup(v: &Value) -> Value {
|
||||
}
|
||||
}
|
||||
|
||||
fn term_arr_val_to_tup(a: Term) -> Term {
|
||||
match &a.op() {
|
||||
Op::Const(v @ Value::Array(..)) => leaf_term(Op::Const(arr_val_to_tup(v))),
|
||||
_ => a,
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn get_const(t: &Term) -> usize {
|
||||
as_uint_constant(t)
|
||||
@@ -255,108 +168,6 @@ fn get_const(t: &Term) -> usize {
|
||||
.expect("oversize")
|
||||
}
|
||||
|
||||
impl RewritePass for Replacer {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
//debug!("Visit {}", extras::Letified(orig.clone()));
|
||||
let get_cs = || -> Vec<Term> {
|
||||
rewritten_children()
|
||||
.into_iter()
|
||||
.map(term_arr_val_to_tup)
|
||||
.collect()
|
||||
};
|
||||
match &orig.op() {
|
||||
Op::Var(name, Sort::Array(..)) => {
|
||||
if self.should_replace(orig) {
|
||||
let precomp = extras::array_to_tuple(orig);
|
||||
let new_name = format!("{name}.tup");
|
||||
let new_sort = check(&precomp);
|
||||
computation.extend_precomputation(new_name.clone(), precomp);
|
||||
Some(leaf_term(Op::Var(new_name, new_sort)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Select => {
|
||||
// we mark the selected term as non-obliv...
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 2);
|
||||
let k_const = get_const(&cs.pop().unwrap());
|
||||
Some(term(Op::Field(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Store => {
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
Some(term(Op::Update(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::CStore => {
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 4);
|
||||
let cond = cs.remove(3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
let orig = cs[0].clone();
|
||||
Some(term![ITE; cond, term(Op::Update(k_const), cs), orig])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Array(..) => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Tuple, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Fill(_, size) => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Tuple, vec![get_cs().pop().unwrap(); *size]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Eq => {
|
||||
if self.should_replace(&orig.cs()[0]) {
|
||||
Some(term(Op::Eq, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Eliminate oblivious arrays. See module documentation.
|
||||
pub fn elim_obliv(t: &mut Computation) {
|
||||
let mut prop_pass = NonOblivComputer::new();
|
||||
prop_pass.traverse(t);
|
||||
let mut replace_pass = Replacer {
|
||||
not_obliv: prop_pass.not_obliv,
|
||||
};
|
||||
<Replacer as RewritePass>::traverse_full(&mut replace_pass, t, false, false)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -374,6 +185,12 @@ mod test {
|
||||
true
|
||||
}
|
||||
|
||||
fn count_selects(t: &Term) -> usize {
|
||||
PostOrderIter::new(t.clone())
|
||||
.filter(|t| matches!(t.op(), Op::Select))
|
||||
.count()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn obliv() {
|
||||
let z = term![Op::Const(Value::Array(Array::new(
|
||||
@@ -471,4 +288,263 @@ mod test {
|
||||
assert!(!array_free(&c.outputs[0]));
|
||||
assert!(array_free(&c.outputs[1]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_stores_branching_selects() {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs ) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#a (mod 11) #f0 4 ()))
|
||||
(a1 (store a0 #f0 #f1))
|
||||
(x0 (select a1 #f0))
|
||||
(x1 (select a1 #f1))
|
||||
(a2 (store a1 #f0 #f1))
|
||||
(x2 (select a2 #f2))
|
||||
(x3 (select a2 #f3))
|
||||
(a3 (store a2 #f1 #f1))
|
||||
(x4 (select a3 #f0))
|
||||
(x5 (select a3 #f1))
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_stores_branching_selects_partial() {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs (i (mod 11))) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#a (mod 11) #f0 4 ()))
|
||||
(a1 (store a0 #f0 #f1))
|
||||
(x0 (select a1 #f0))
|
||||
(x1 (select a1 #f1))
|
||||
(a2 (store a1 #f0 #f1))
|
||||
(x2 (select a2 #f2))
|
||||
(x3 (select a2 #f3))
|
||||
(a3 (store a2 i #f1))
|
||||
(x4 (select a3 #f0))
|
||||
(x5 (select a3 #f1))
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linear_stores_branching_selects_partial_2() {
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs (i (mod 11))) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#a (mod 11) #f0 4 ()))
|
||||
(a1 (store a0 #f0 #f1))
|
||||
(x0 (select a1 #f0))
|
||||
(x1 (select a1 #f1))
|
||||
(a2 (store a1 i #f1))
|
||||
(x2 (select a2 #f2))
|
||||
(x3 (select a2 #f3))
|
||||
(a3 (store a2 #f0 #f1))
|
||||
(x4 (select a3 #f0))
|
||||
(x5 (select a3 #f1))
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nest_obliv() {
|
||||
env_logger::try_init().ok();
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs (i (mod 11))) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1)))))
|
||||
(a1 (store a0 #f0 (store (select a0 #f0) #f1 #f1)))
|
||||
(x0 (select (select a1 #f0) #f0))
|
||||
(x1 (select (select a1 #f1) #f0))
|
||||
(a2 (store a1 #f1 (store (select a1 #f1) #f1 #f1)))
|
||||
(x2 (select (select a2 #f0) #f1))
|
||||
(x3 (select (select a2 #f1) #f1))
|
||||
(a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1)))
|
||||
(x4 (select (select a3 #f1) #f0))
|
||||
(x5 (select (select a3 #f0) #f1))
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nest_obliv_partial() {
|
||||
env_logger::try_init().ok();
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs (i (mod 11))) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1)))))
|
||||
(a1 (store a0 #f0 (store (select a0 #f0) #f1 #f1)))
|
||||
(x0 (select (select a1 #f0) #f0))
|
||||
(x1 (select (select a1 #f1) #f0))
|
||||
(a2 (store a1 i (store (select a1 i) #f1 #f1))) ; not elim
|
||||
(x2 (select (select a2 #f0) #f1)) ; not elim (2)
|
||||
(x3 (select (select a2 #f1) #f1)) ; not elim (2)
|
||||
(a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1))) ; not elim (dup)
|
||||
(x4 (select (select a3 #f1) #f0)) ; not elim (2)
|
||||
(x5 (select (select a3 #f0) #f1)) ; not elim (2)
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
let before = count_selects(&c.outputs[0]);
|
||||
elim_obliv(&mut c);
|
||||
assert!(count_selects(&c.outputs[0]) < before);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nest_no_obliv() {
|
||||
env_logger::try_init().ok();
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs (i (mod 11))) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#l (mod 11) ((#l (mod 11) (#f1 #f0)) (#l (mod 11) (#f0 #f1)))))
|
||||
(a1 (store a0 i (store (select a0 i) #f1 #f1)))
|
||||
(x0 (select (select a1 #f0) #f0))
|
||||
(x1 (select (select a1 #f1) #f0))
|
||||
(a2 (store a1 #f0 (store (select a1 #f0) #f1 #f1))) ; not elim
|
||||
(x2 (select (select a2 #f0) #f1)) ; not elim (2)
|
||||
(x3 (select (select a2 #f1) #f1)) ; not elim (2)
|
||||
(a3 (store a2 #f1 (store (select a2 #f1) #f0 #f1))) ; not elim (dup)
|
||||
(x4 (select (select a3 #f1) #f0)) ; not elim (2)
|
||||
(x5 (select (select a3 #f0) #f1)) ; not elim (2)
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
let before = count_selects(&c.outputs[0]);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), before);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_array_ptr_chase_eq_size() {
|
||||
env_logger::try_init().ok();
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties )
|
||||
(inputs (x0 (mod 11))
|
||||
(x1 (mod 11))
|
||||
(x2 (mod 11))
|
||||
(x3 (mod 11))
|
||||
(x4 (mod 11))
|
||||
(i0 (mod 11))
|
||||
(i1 (mod 11))
|
||||
(i2 (mod 11))
|
||||
(i3 (mod 11))
|
||||
)
|
||||
(commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(ax (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 x0) #f1 x1) #f2 x2) #f3 x3))
|
||||
(ai (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 i0) #f1 i1) #f2 i2) #f3 i3))
|
||||
(xi0 (select ax (select ai #f0)))
|
||||
(xi1 (select ax (select ai #f1)))
|
||||
(xi2 (select ax (select ai #f2)))
|
||||
(xi3 (select ax (select ai #f3)))
|
||||
)
|
||||
(+ xi0 xi1 xi2 xi3)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_array_ptr_chase_ne_size() {
|
||||
env_logger::try_init().ok();
|
||||
let mut c = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties )
|
||||
(inputs (x0 (mod 11))
|
||||
(x1 (mod 11))
|
||||
(x2 (mod 11))
|
||||
(x3 (mod 11))
|
||||
(x4 (mod 11))
|
||||
(i0 (mod 11))
|
||||
(i1 (mod 11))
|
||||
(i2 (mod 11))
|
||||
)
|
||||
(commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(ax (store (store (store (store (#a (mod 11) #f0 4 ()) #f0 x0) #f1 x1) #f2 x2) #f3 x3))
|
||||
(ai (store (store (store (#a (mod 11) #f0 4 ()) #f0 i0) #f1 i1) #f2 i2))
|
||||
(xi0 (select ax (select ai #f0)))
|
||||
(xi1 (select ax (select ai #f1)))
|
||||
(xi2 (select ax (select ai #f2)))
|
||||
)
|
||||
(+ xi0 xi1 xi2)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
elim_obliv(&mut c);
|
||||
assert_eq!(count_selects(&c.outputs[0]), 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
//! A general-purpose RAM extractor
|
||||
use super::*;
|
||||
|
||||
use fxhash::FxHashMap as HashMap;
|
||||
use fxhash::FxHashSet as HashSet;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
use log::trace;
|
||||
|
||||
/// Graph of the *arrays* in the computation.
|
||||
@@ -198,6 +202,46 @@ impl Extactor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a set of terms, return an ordering of them in post-order, but also with array selects on
|
||||
/// A before stores to A.
|
||||
fn array_order<'a>(terms: HashSet<&'a Term>) -> Vec<&'a Term> {
|
||||
let mut parents: HashMap<&'a Term, HashSet<&'a Term>> = Default::default();
|
||||
let mut children: HashMap<&'a Term, HashSet<&'a Term>> = Default::default();
|
||||
for t in &terms {
|
||||
parents.entry(t).or_default();
|
||||
children.entry(t).or_default();
|
||||
for c in t.cs() {
|
||||
debug_assert!(terms.contains(c));
|
||||
parents.entry(c).or_default().insert(t);
|
||||
children.entry(t).or_default().insert(c);
|
||||
}
|
||||
}
|
||||
let mut output: Vec<&'a Term> = Default::default();
|
||||
// max-heap contains (is_select, term) pairs; so, selects go first.
|
||||
let mut to_output: BinaryHeap<(bool, &'a Term)> = terms
|
||||
.iter()
|
||||
.filter(|t| t.cs().is_empty())
|
||||
.map(|t| (false, *t))
|
||||
.collect();
|
||||
let mut children_not_outputted: HashMap<&'a Term, usize> = children
|
||||
.iter()
|
||||
.map(|(term, children)| (*term, children.len()))
|
||||
.collect();
|
||||
while let Some((_, output_me)) = to_output.pop() {
|
||||
output.push(output_me);
|
||||
for p in parents.get(&output_me).unwrap() {
|
||||
let count = children_not_outputted.get_mut(p).unwrap();
|
||||
assert!(*count > 0);
|
||||
*count -= 1;
|
||||
if *count == 0 {
|
||||
to_output.push((matches!(p.op(), Op::Select), *p));
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_eq!(output.len(), terms.len());
|
||||
output
|
||||
}
|
||||
|
||||
impl RewritePass for Extactor {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
@@ -242,9 +286,25 @@ impl RewritePass for Extactor {
|
||||
match &t.op() {
|
||||
// Rewrite select's whose array is a RAM term
|
||||
Op::Select if self.graph.ram_terms.contains(&t.cs()[0]) => {
|
||||
let ram_id = self.get_or_start(&t.cs()[0]);
|
||||
let array = &t.cs()[0];
|
||||
let idx = &t.cs()[1];
|
||||
// If we're based on a leaf
|
||||
let ram_id = if array_leaf(array) {
|
||||
// check if that leaf has a RAM already
|
||||
if let Some(id) = self.term_ram.get(array) {
|
||||
*id
|
||||
} else {
|
||||
let id = self.start_ram_for_leaf(array);
|
||||
|
||||
self.term_ram.insert(array.clone(), id);
|
||||
id
|
||||
}
|
||||
} else {
|
||||
// otherwise, assume that our parent has a RAM already
|
||||
*self.term_ram.get(array).unwrap()
|
||||
};
|
||||
let ram = &mut self.rams[ram_id];
|
||||
let read_value = ram.new_read(t.cs()[1].clone(), computation, t.clone());
|
||||
let read_value = ram.new_read(idx.clone(), computation, t.clone());
|
||||
self.read_terms.insert(t.clone(), read_value.clone());
|
||||
Some(read_value)
|
||||
}
|
||||
@@ -252,6 +312,32 @@ impl RewritePass for Extactor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn traverse(&mut self, computation: &mut Computation) {
|
||||
let terms: Vec<Term> = computation.terms_postorder().collect();
|
||||
let term_refs: HashSet<&Term> = terms.iter().collect();
|
||||
let mut cache = TermMap::<Term>::default();
|
||||
for top in array_order(term_refs) {
|
||||
debug_assert!(!cache.contains_key(top));
|
||||
let new_t_opt = self.visit_cache(computation, top, &cache);
|
||||
let new_t = new_t_opt.unwrap_or_else(|| {
|
||||
term(
|
||||
top.op().clone(),
|
||||
top.cs()
|
||||
.iter()
|
||||
.map(|c| cache.get(c).unwrap())
|
||||
.cloned()
|
||||
.collect(),
|
||||
)
|
||||
});
|
||||
cache.insert(top.clone(), new_t);
|
||||
}
|
||||
computation.outputs = computation
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|o| cache.get(o).unwrap().clone())
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
/// Find arrays which are RAMs (i.e., accessed with a linear sequences of
|
||||
@@ -517,6 +603,83 @@ mod test {
|
||||
assert_eq!(cs, cs2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rom() {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs ) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(c_array (#a (mod 11) #f0 4 ()))
|
||||
(x0 (select c_array #f0))
|
||||
(x1 (select c_array #f1))
|
||||
(x2 (select c_array #f2))
|
||||
(x3 (select c_array #f3))
|
||||
)
|
||||
(+ x0 x1 x2 x3)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
let mut cs2 = cs.clone();
|
||||
cstore::parse(&mut cs2);
|
||||
let field = FieldT::from(rug::Integer::from(11));
|
||||
let rams = extract(&mut cs2, AccessCfg::default_from_field(field.clone()));
|
||||
extras::assert_all_vars_declared(&cs2);
|
||||
assert_ne!(cs, cs2);
|
||||
assert_eq!(1, rams.len());
|
||||
assert_eq!(4, rams[0].accesses.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_arm_tree() {
|
||||
let cs = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata (parties ) (inputs ) (commitments))
|
||||
(precompute () () (#t ))
|
||||
(set_default_modulus 11
|
||||
(let
|
||||
(
|
||||
(a0 (#a (mod 11) #f0 4 ()))
|
||||
(a1 (store a0 #f0 #f1))
|
||||
(x0 (select a1 #f0))
|
||||
(x1 (select a1 #f1))
|
||||
(a2 (store a1 #f0 #f1))
|
||||
(x2 (select a2 #f2))
|
||||
(x3 (select a2 #f3))
|
||||
(a3 (store a2 #f1 #f1))
|
||||
(x4 (select a3 #f0))
|
||||
(x5 (select a3 #f1))
|
||||
)
|
||||
(+ x0 x1 x2 x3 x4 x5)
|
||||
))
|
||||
)
|
||||
",
|
||||
);
|
||||
let mut cs2 = cs.clone();
|
||||
cstore::parse(&mut cs2);
|
||||
let field = FieldT::from(rug::Integer::from(11));
|
||||
let rams = extract(&mut cs2, AccessCfg::default_from_field(field.clone()));
|
||||
extras::assert_all_vars_declared(&cs2);
|
||||
assert_ne!(cs, cs2);
|
||||
assert_eq!(1, rams.len());
|
||||
assert_eq!(9, rams[0].accesses.len());
|
||||
println!("{:?}", rams[0].accesses);
|
||||
assert_eq!(bool_lit(true), rams[0].accesses[0].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[1].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[2].write.b);
|
||||
assert_eq!(bool_lit(true), rams[0].accesses[3].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[4].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[5].write.b);
|
||||
assert_eq!(bool_lit(true), rams[0].accesses[6].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[7].write.b);
|
||||
assert_eq!(bool_lit(false), rams[0].accesses[8].write.b);
|
||||
}
|
||||
|
||||
#[cfg(feature = "poly")]
|
||||
#[test]
|
||||
fn length_4() {
|
||||
|
||||
@@ -110,6 +110,7 @@ pub fn as_uint_constant(t: &Term) -> Option<Integer> {
|
||||
match &t.op() {
|
||||
Op::Const(Value::BitVector(bv)) => Some(bv.uint().clone()),
|
||||
Op::Const(Value::Field(f)) => Some(f.i()),
|
||||
Op::Const(Value::Bool(b)) => Some((*b).into()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user