From 014604fd6fe8634749e94cb3ac71506ee0e53d45 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Tue, 14 Jun 2022 20:42:04 -0700 Subject: [PATCH 1/3] fix linking pass --- src/front/c/mod.rs | 8 +- src/ir/opt/link.rs | 402 +++++++++++++++++++++++++--------------- src/ir/opt/mem/lin.rs | 4 +- src/ir/opt/mod.rs | 13 +- src/ir/term/mod.rs | 18 +- src/ir/term/text/mod.rs | 12 +- src/ir/term/ty.rs | 6 +- src/target/aby/trans.rs | 4 +- 8 files changed, 284 insertions(+), 183 deletions(-) diff --git a/src/front/c/mod.rs b/src/front/c/mod.rs index ea3bf8a0..08eea009 100644 --- a/src/front/c/mod.rs +++ b/src/front/c/mod.rs @@ -64,8 +64,8 @@ impl FrontEnd for C { // generate new context g.circ = Circify::new(Ct::new()); let call = g.function_queue.pop().unwrap(); - if let Op::Call(name, arg_names, arg_sorts, rets) = &call.op { - g.fn_call(name, arg_names, arg_sorts, rets); + if let Op::Call(name, arg_names, arg_sorts, ret_sorts) = &call.op { + g.fn_call(name, arg_names, arg_sorts, ret_sorts); let comp = g.circ.consume().borrow().clone(); // println!("fn: {}", name); @@ -842,7 +842,7 @@ impl CGen { name.clone(), arg_names.clone(), arg_sorts.clone(), - Sort::Tuple(ret_sorts.clone().into_boxed_slice()), + ret_sorts.clone(), ), arg_terms.clone().into_iter().flatten().collect::>(), ); @@ -1201,7 +1201,7 @@ impl CGen { name: &String, arg_names: &Vec, arg_sorts: &Vec, - rets: &Sort, + ret_sorts: &Vec, ) { debug!("Call: {}", name); println!("Call: {}", name); diff --git a/src/ir/opt/link.rs b/src/ir/opt/link.rs index 30d61d21..ff45333a 100644 --- a/src/ir/opt/link.rs +++ b/src/ir/opt/link.rs @@ -1,167 +1,263 @@ //! Inline function call terms -use std::collections::{BTreeMap, HashMap}; +use fxhash::FxHashMap as HashMap; use crate::ir::term::*; +use crate::ir::opt::visit::RewritePass; -/// Inline cache. -#[derive(Default)] -pub struct Cache(TermMap); - -impl Cache { - /// Empty cache. - pub fn new() -> Self { - Cache(TermMap::new()) - } +/// A recursive inliner. +struct Inliner<'f> { + /// Original source for functions + fs: &'f Functions, + /// Map from names to call-free computations + cache: HashMap, } -// TODO: this can fail if the function name contains '_' -fn get_var_name(name: &String) -> String { - let new_name = name.to_string().replace('.', "_"); - let n = new_name.split('_').collect::>(); - match n.len() { - 5 => n[3].to_string(), - 6.. => { - let l = n.len() - 1; - format!("{}_{}", n[l - 2], n[l]) - } - _ => { - panic!("Invalid variable name: {}", name); +/// Compute the term that corresponds to a function call. +/// +/// ## Arguments +/// +/// * `arg_names`: the argument names, in order +/// * `arg_values`: the argument values, in the same order +/// * `callee`: the called function +/// +/// ## Returns +/// +/// A (tuple) term that corresponds to the output on those inputs +/// +/// ## Note +/// +/// This function **does not** recursively inline. +fn inline_one(arg_names: &Vec, arg_values: Vec, callee: &Computation) -> Term { + let mut sub_map: TermMap = arg_names + .into_iter() + .zip(arg_values) + .map(|(n, v)| { + let s = callee.metadata.input_sort(n).clone(); + (leaf_term(Op::Var(n.clone(), s)), v) + }) + .collect(); + term( + Op::Tuple, + callee + .outputs() + .iter() + .map(|o| extras::substitute_cache(o, &mut sub_map)) + .collect(), + ) +} + +impl<'f> Inliner<'f> { + /// Ensure that a totally inlined version of `name` is in the cache. + fn inline_all(&mut self, name: &str) { + if !self.cache.contains_key(name) { + let mut c = self.fs.get_comp(name).unwrap().clone(); + for t in c.terms_postorder() { + if let Op::Call(callee_name, ..) = &t.op { + self.inline_all(callee_name); + } + } + self.traverse(&mut c); + let present = self.cache.insert(name.into(), c); + assert!(present.is_none()); } } } -fn match_arg(name: &String, params: &BTreeMap) -> Term { - let new_name = get_var_name(name); - params.get(&new_name).unwrap().clone() -} - -fn link(name: &str, params: BTreeMap, fs: &Functions) -> Vec { - let mut res: Vec = Vec::new(); - let comp = fs.computations.get(name).unwrap(); - for o in comp.outputs.iter() { - let mut cache = TermMap::new(); - for t in PostOrderIter::new(o.clone()) { - match &t.op { - Op::Var(name, _) => { - let ret = match_arg(name, ¶ms); - cache.insert(t.clone(), ret.clone()); - } - _ => { - let mut children = Vec::new(); - for c in &t.cs { - if let Some(rewritten_c) = cache.get(c) { - children.push(rewritten_c.clone()); - } else { - children.push(c.clone()); - } - } - cache.insert(t.clone(), term(t.op.clone(), children)); - } - } +/// Rewrites a term, inlining function calls along the way. +/// +/// Assumes that the callees are already inlined. Panics otherwise. +impl<'f> RewritePass for Inliner<'f> { + fn visit Vec>( + &mut self, + _computation: &mut Computation, + orig: &Term, + rewritten_children: F, + ) -> Option { + if let Op::Call(fn_name, arg_names, _, _) = &orig.op { + let callee = self.cache.get(fn_name).expect("missing inlined callee"); + let term = inline_one(arg_names, rewritten_children(), callee); + Some(term) + } else { + None } - res.push(cache.get(o).unwrap().clone()); - } - res -} - -/// Traverse terms and link function calls -pub fn link_function_calls( - term_: Term, - Cache(ref mut rewritten): &mut Cache, - fs: &Functions, -) -> Term { - let mut call_cache: HashMap> = HashMap::new(); - for t in PostOrderIter::new(term_.clone()) { - let mut children = Vec::new(); - for c in &t.cs { - if let Some(rewritten_c) = rewritten.get(c) { - children.push(rewritten_c.clone()); - } else { - children.push(c.clone()); - } - } - let entry = match &t.op { - Op::Field(index) => { - assert!(t.cs.len() > 0); - if let Op::Call(..) = &t.cs[0].op { - if call_cache.contains_key(&t.cs[0]) { - call_cache.get(&t.cs[0]).unwrap()[*index].clone() - } else { - panic!("Fields on a Call term should return"); - } - } else { - term(t.op.clone(), children) - } - } - Op::Call(name, arg_names, arg_sorts, _) => { - println!("Inlining: {}", name); - - // Check number of args - let num_args = arg_sorts.iter().fold(0, |sum, x| { - sum + match x { - Sort::Array(_, _, l) => *l, - _ => 1, - } - }); - assert!( - num_args == t.cs.len(), - "Number of arguments mismatch. {}, {}", - num_args, - t.cs.len() - ); - - // Check arg types - let arg_types = arg_sorts - .iter() - .map(|x| match &x { - Sort::Array(_, val_sort, l) => { - let mut res: Vec = Vec::new(); - for _ in 0..*l { - res.push(*val_sort.clone()); - } - res - } - _ => vec![x.clone()], - }) - .flatten() - .collect::>(); - - assert!( - arg_types == t.cs.iter().map(|c| check(c)).collect::>(), - "Argument type mismatch" - ); - - let mut params: BTreeMap = BTreeMap::new(); - let arg_keys = arg_names - .iter() - .zip(arg_sorts.iter()) - .map(|(n, x)| match &x { - Sort::Array(_, _, l) => { - let mut res: Vec = Vec::new(); - for i in 0..*l { - res.push(format!("{}_{}", n.clone(), i)); - } - res - } - _ => vec![n.clone()], - }) - .flatten(); - for (n, c) in arg_keys.zip(t.cs.clone()) { - params.insert(n.clone(), c.clone()); - } - let res = link(name, params, fs); - call_cache.insert(t.clone(), res.clone()); - res[0].clone() - } - _ => term(t.op.clone(), children), - }; - rewritten.insert(t.clone(), entry); - } - - if let Some(t) = rewritten.get(&term_) { - t.clone() - } else { - panic!("Couldn't find rewritten binarized term: {}", term_); + } +} + +/// Inline all calls within a function set. +pub fn link_all_function_calls(fs: &mut Functions) { + let mut inliner = Inliner { + fs, + cache: Default::default(), + }; + for name in fs.computations.keys() { + inliner.inline_all(name); + } + *fs = Functions { + computations: inliner.cache.into_iter().collect(), + }; +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn bool_arg_nonrec() { + let mut fs = text::parse_functions( + b" + (functions + ( + (myxor + (computation + (metadata () ((a bool) (b bool)) ()) + (xor a b false false) + ) + ) + (main + (computation + (metadata () ((a bool) (b bool)) ()) + (and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b ))) + ) + ) + ) + )", + ); + let expected = text::parse_computation( + b" + (computation + (metadata () ((a bool) (b bool)) ()) + (and false ((field 0) (tuple (xor a b false false)))) + ) + ", + ); + link_all_function_calls(&mut fs); + let c = fs.get_comp("main").unwrap().clone(); + assert_eq!(c, expected); + } + + #[test] + fn scalar_arg_nonrec() { + let mut fs = text::parse_functions( + b" + (functions + ( + (myxor + (computation + (metadata () ((a bool) (b (bv 4))) ()) + (bvxor (ite a #x0 #x1) b) + ) + ) + (main + (computation + (metadata () ((c bool)) ()) + (bvand #xf ((field 0) ( (call myxor (a b) (bool (bv 4)) ((bv 4))) c #x4 ))) + ) + ) + ) + )", + ); + + let expected = text::parse_computation( + b" + (computation + (metadata () ((c bool)) ()) + (bvand #xf ((field 0) (tuple (bvxor (ite c #x0 #x1) #x4)))) + ) + ", + ); + link_all_function_calls(&mut fs); + let c = fs.get_comp("main").unwrap().clone(); + assert_eq!(c, expected); + } + + #[test] + fn nested_calls() { + let mut fs = text::parse_functions( + b" + (functions + ( + (foo + (computation + (metadata () ((a bool)) ()) + (not a) + ) + ) + (bar + (computation + (metadata () ((a bool)) ()) + (xor ((field 0) ((call foo (a) (bool) (bool)) a)) true) + ) + ) + (main + (computation + (metadata () ((a bool) (b bool)) ()) + ((field 0) ((call bar (a) (bool) (bool)) a)) + ) + ) + ) + )", + ); + + let expected = text::parse_computation( + b" + (computation + (metadata () ((a bool) (b bool)) ()) + ((field 0) (tuple (xor ((field 0) (tuple (not a))) true))) + ) + ", + ); + link_all_function_calls(&mut fs); + let c = fs.get_comp("main").unwrap().clone(); + assert_eq!(c, expected); + } + + #[test] + fn multiple_calls() { + let mut fs = text::parse_functions( + b" + (functions + ( + (foo + (computation + (metadata () ((a bool)) ()) + (not a) + ) + ) + (bar + (computation + (metadata () ((a bool)) ()) + (xor ((field 0) ((call foo (a) (bool) (bool)) a)) true) + ) + ) + (main + (computation + (metadata () ((a bool) (b bool)) ()) + (and + ((field 0) ((call foo (a) (bool) (bool)) a)) + ((field 0) ((call foo (a) (bool) (bool)) b)) + ((field 0) ((call bar (a) (bool) (bool)) a)) + ) + ) + ) + ) + )", + ); + + let expected = text::parse_computation( + b" + (computation + (metadata () ((a bool) (b bool)) ()) + (and + ((field 0) (tuple (not a))) + ((field 0) (tuple (not b))) + ((field 0) (tuple (xor ((field 0) (tuple (not a))) true))) + ) + ) + ", + ); + link_all_function_calls(&mut fs); + let c = fs.get_comp("main").unwrap().clone(); + assert_eq!(c, expected); } } diff --git a/src/ir/opt/mem/lin.rs b/src/ir/opt/mem/lin.rs index 395732d9..e5724cb8 100644 --- a/src/ir/opt/mem/lin.rs +++ b/src/ir/opt/mem/lin.rs @@ -33,7 +33,7 @@ impl RewritePass for Linearizer { match &orig.op { Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(super::arr_val_to_tup(v)))), Op::Var(_name, s) if super::sort_contains_array(s) => Some(super::array_to_tuple(orig)), - Op::Call(_name, _arg_names, arg_sorts, ret_sort) => { + Op::Call(_name, _arg_names, arg_sorts, ret_sorts) => { let mut args = rewritten_children(); for (a, s) in args.iter_mut().zip(arg_sorts) { if super::sort_contains_array(s) { @@ -41,7 +41,7 @@ impl RewritePass for Linearizer { } } let out = term(orig.op.clone(), args); - Some(if super::sort_contains_array(ret_sort) { + Some(if ret_sorts.iter().any(super::sort_contains_array) { super::array_to_tuple(&out) } else { out diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index 0f7057d2..d358eabc 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -10,7 +10,7 @@ pub mod sha; pub mod tuple; mod visit; -use std::{collections::HashMap, time::Instant}; +use std::time::Instant; use super::term::*; @@ -48,6 +48,10 @@ pub enum Opt { pub fn opt>(mut fs: Functions, optimizations: I) -> Functions { for i in optimizations { let mut opt_fs: Functions = fs.clone(); + if let Opt::InlineCalls = i { + link::link_all_function_calls(&mut opt_fs); + continue + } for (name, comp) in fs.computations.iter_mut() { debug!("Applying: {:?} to {}", i, name); let now = Instant::now(); @@ -112,12 +116,7 @@ pub fn opt>(mut fs: Functions, optimizations: I) -> .collect(); inline::inline(&mut comp.outputs, &public_inputs); } - Opt::InlineCalls => { - let mut cache = link::Cache::new(); - for a in &mut comp.outputs { - *a = link::link_function_calls(a.clone(), &mut cache, &opt_fs); - } - } + Opt::InlineCalls => unreachable!(), Opt::Tuple => { tuple::eliminate_tuples(comp); } diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index a637b716..5854f2fa 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -141,8 +141,10 @@ pub enum Op { /// Map (operation) Map(Box), - /// Call a function (name, argument names, argument sorts, return sort) - Call(String, Vec, Vec, Sort), + /// Call a function (name, argument names, argument sorts, return sorts) + /// + /// Note that the type of this term is always a tuple. + Call(String, Vec, Vec, Vec), } /// Boolean AND @@ -298,7 +300,7 @@ impl Display for Op { Op::Field(i) => write!(f, "(field {})", i), Op::Update(i) => write!(f, "(update {})", i), Op::Map(op) => write!(f, "(map({}))", op), - Op::Call(name, arg_names, arg_sorts, sort) => { + Op::Call(name, arg_names, arg_sorts, ret_sorts) => { write!(f, "(call {} (", name)?; for arg_name in arg_names { write!(f, " {}", arg_name)?; @@ -307,7 +309,11 @@ impl Display for Op { for arg_sort in arg_sorts { write!(f, " {}", arg_sort)?; } - write!(f, ") {})", sort) + write!(f, ") (")?; + for ret_sort in ret_sorts { + write!(f, " {}", ret_sort)?; + } + write!(f, "))") } } } @@ -2020,8 +2026,8 @@ impl Functions { } /// Get the first computation by function name - pub fn get_comp(&self, name: String) -> Option<&Computation> { - self.computations.get(&name) + pub fn get_comp(&self, name: &str) -> Option<&Computation> { + self.computations.get(name) } /// Create a computation with a single entry function diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 362fc0c8..42fc7ff1 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -46,7 +46,7 @@ //! * Operator `O`: //! * Plain operators: (`bvmul`, `and`, ...) //! * Composite operators: `(field N)`, `(update N)`, `(sext N)`, `(uext N)`, `(bit N)`, ... -//! * call operator: `(call X (X1 ... XN) (S1 ... SN) S)` +//! * call operator: `(call X (X1 ... XN) (S1 ... SN) (RS1 ... RSN))` use circ_fields::{FieldT, FieldV}; @@ -275,12 +275,12 @@ impl<'src> IrInterp<'src> { [Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))), [Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))), [Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))), - [Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, sort] => { + [Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, ret_sorts] => { let name = from_utf8(name).unwrap().to_owned(); let arg_names = self.string_list(arg_names); let arg_sorts = self.sort_list(arg_sorts); - let sort = self.sort(sort); - Ok(Op::Call(name, arg_names, arg_sorts, sort)) + let ret_sorts = self.sort_list(ret_sorts); + Ok(Op::Call(name, arg_names, arg_sorts, ret_sorts)) } _ => todo!("Unparsed op: {}", tt), }, @@ -777,7 +777,7 @@ mod test { let t = parse_term( b" (declare ((a bool)) - ( (call myxor (a b) (bool bool) bool) a a ) + ((field 0) ( (call myxor (a b) (bool bool) (bool)) a a )) )", ); assert_eq!(check(&t), Sort::Bool); @@ -918,7 +918,7 @@ mod test { (main (computation (metadata () ((a bool) (b bool)) ()) - (and false ( (call myxor (a b) (bool bool) bool) a b )) + (and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b ))) ) ) ) diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index a45523b2..3d901917 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -175,7 +175,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { } } } - Op::Call(_, _, _, ret) => Ok(ret.clone()), + Op::Call(_, _, _, rets) => Ok(Sort::Tuple(rets.clone().into())), o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))), } } @@ -392,14 +392,14 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result { + (Op::Call(_, _, ex_args, rets), act_args) => { if ex_args.len() != act_args.len() { Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len())) } else { for (e, a) in ex_args.iter().zip(act_args) { eq_or(e, a, "in function call")?; } - Ok(ret.clone()) + Ok(Sort::Tuple(rets.clone().into())) } } (_, _) => Err(TypeErrorReason::Custom("other".to_string())), diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index d2c6120a..5f845153 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -553,11 +553,11 @@ impl<'a> ToABY<'a> { self.term_to_shares.insert(t.clone(), vec![shares[*i]]); self.cache.insert(t.clone(), EmbeddedTerm::Array); } - Op::Call(name, arg_names, _arg_sorts, ret) => { + Op::Call(name, arg_names, _arg_sorts, ret_sorts) => { let shares = self.get_shares(&t); let op = format!("CALL({})", name); let num_args = arg_names.len(); - let num_rets = self.get_sort_len(ret); + let num_rets: usize = ret_sorts.iter().map(|ret| self.get_sort_len(ret)).sum(); // map argument shares let mut arg_shares: Vec = Vec::new(); From 3fbf5fa21167f20f9b8578e060cadfd11dbb3e3d Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Wed, 15 Jun 2022 22:20:41 -0700 Subject: [PATCH 2/3] clean up oblivious pass --- src/ir/opt/mem/mod.rs | 3 +- src/ir/opt/mem/obliv.rs | 347 +++++++++++++++++++--------------------- 2 files changed, 165 insertions(+), 185 deletions(-) diff --git a/src/ir/opt/mem/mod.rs b/src/ir/opt/mem/mod.rs index c35521fd..803bee01 100644 --- a/src/ir/opt/mem/mod.rs +++ b/src/ir/opt/mem/mod.rs @@ -59,6 +59,7 @@ fn sort_contains_array(s: &Sort) -> bool { /// Given a sort `s` which may contain array constructors, construct a new sort in which the arrays /// have been flattened to tuples. +#[allow(dead_code)] fn array_to_tuple_sort(s: &Sort) -> Sort { match s { Sort::Tuple(ss) => Sort::Tuple(ss.iter().map(array_to_tuple_sort).collect()), @@ -70,7 +71,7 @@ fn array_to_tuple_sort(s: &Sort) -> Sort { /// Given a term of tuples, re-shape into sort `s`. fn resort(t: &Term, s: &Sort) -> Term { match s { - Sort::Array(k, v, sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold( + Sort::Array(k, v, _sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold( s.default_term(), |acc, (t, i)| term![Op::Store; acc, i, resort(&t, v)], ), diff --git a/src/ir/opt/mem/obliv.rs b/src/ir/opt/mem/obliv.rs index 26a79931..d955c31b 100644 --- a/src/ir/opt/mem/obliv.rs +++ b/src/ir/opt/mem/obliv.rs @@ -11,24 +11,22 @@ //! //! ## Pass 1: Identifying oblivious arrays //! +//! //! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation -//! system, performing the following inferences: +//! system, marking som arrays as non-oblivious. Only non-constant terms with a sort that contains +//! an array can be non-oblivious. //! -//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious; -//! * If `a[i]`, `a` and `a[i]` are equi-oblivious -//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious -//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious -//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious -//! * If `a=b`, then `a` and `b` are equi-oblivious -//! * If any other term is array-valued, that array is non-oblivious. Including: -//! * variables -//! * function call outputs -//! * tuple fields //! -//! This procedure is iterated to fixpoint. +//! The following are *directly marked* as non-oblivious: //! -//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select -//! terms as non-oblivious. This makes it easy to see which selects should be replaced later. +//! * Selects with non-constant indices +//! * Stores with non-constant indices +//! * Any term containing an array sort that is not a array of primitives +//! * Variables, call arguments, call outputs, and the computation outputs +//! +//! All terms propagate non-obliviousness to their children and recieve it from their children. +//! +//! Propagation is repeated to fixpoint. //! //! ### Sharing & Constant Arrays //! @@ -69,130 +67,89 @@ //! * map array terms to tuple terms //! * map array selections to tuple field gets //! -//! In both cases we look at the non-oblivious array/select set to see whether to do the -//! replacement. -//! +//! In both cases we look at the non-oblivious array set to see whether to do the replacement. use super::super::visit::*; use crate::ir::term::extras::as_uint_constant; use crate::ir::term::*; -use log::{debug, trace}; +use log::trace; +fn is_prim_array(t: &Sort) -> bool { + if let Sort::Array(k, v, _) = t { + k.is_scalar() && v.is_scalar() + } else { + false + } +} + +fn is_markable(t: &Term, s: &Sort) -> bool { + !t.is_const() && !s.is_scalar() && super::sort_contains_array(s) +} + +#[derive(Default)] struct NonOblivComputer { not_obliv: TermSet, } impl NonOblivComputer { fn mark(&mut self, a: &Term) -> bool { - if !a.is_const() && self.not_obliv.insert(a.clone()) { - trace!("Not obliv: {}", a); + let s = check(a); + if is_markable(a, &s) && self.not_obliv.insert(a.clone()) { + trace!("Not obliv: {}", extras::Letified(a.clone())); true } else { false } } - fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool { - if !a.is_const() && !b.is_const() { - match (self.not_obliv.contains(a), self.not_obliv.contains(b)) { - (false, true) => { - self.not_obliv.insert(a.clone()); - true - } - (true, false) => { - self.not_obliv.insert(b.clone()); - true - } - _ => false, + fn propagate(&mut self, a: &Term) -> bool { + if self.not_obliv.contains(a) || a.cs.iter().any(|c| self.not_obliv.contains(c)) { + let mut progress = false; + progress |= self.mark(a); + for c in &a.cs { + progress |= self.mark(c); } + progress } else { false } } - - fn new() -> Self { - Self { - not_obliv: TermSet::new(), - } - } } impl ProgressAnalysisPass for NonOblivComputer { fn visit(&mut self, term: &Term) -> bool { - match &term.op { - Op::Store => { - let a = &term.cs[0]; - let i = &term.cs[1]; - let v = &term.cs[2]; - let mut progress = false; - if let Sort::Array(..) = check(v) { - // Imprecisely, mark v as non-obliv iff the array is. - progress = self.bi_implicate(term, v) || progress; - } - if i.is_const() { - progress = self.bi_implicate(term, a) || progress; - } else { - progress = self.mark(a) || progress; - progress = self.mark(term) || progress; - } - if let Sort::Array(..) = check(v) { - // Imprecisely, mark v as non-obliv iff the array is. - progress = self.bi_implicate(term, v) || progress; - } - progress - } - Op::Select => { - // Even though the selected value may not have array sort, we still flag it as - // non-oblivious so we know whether to replace it or not. - let a = &term.cs[0]; - let i = &term.cs[1]; - let mut progress = false; - if i.is_const() { - // pass - } else { - progress = self.mark(a) || progress; - progress = self.mark(term) || progress; - } - progress = self.bi_implicate(term, a) || progress; - progress - } - Op::Ite => { - let t = &term.cs[1]; - let f = &term.cs[2]; - if let Sort::Array(..) = check(t) { - let mut progress = self.bi_implicate(term, t); - progress = self.bi_implicate(t, f) || progress; - progress = self.bi_implicate(term, f) || progress; - progress - } else { - false - } - } - Op::Eq => { - let a = &term.cs[0]; - let b = &term.cs[1]; - if let Sort::Array(..) = check(a) { - self.bi_implicate(a, b) - } else { - false - } - } - // constants are oblivious - Op::Const(..) => false, - _ => match check(term) { - Sort::Array(..) => { - // variables, fields, and function outputs are non-oblivious. - debug_assert!( - matches!(term.op, Op::Call(..) | Op::Var(..) | Op::Field(..)), - "Unexpected array term {}", - term - ); - self.mark(term) - } + let sort = check(term); + let mut progress = false; + + // First, we directly mark this term if + // (a) its sort contains an array and is not an array of primitives. + let do_mark = if !sort.is_scalar() && !is_prim_array(&sort) { + true + } else { + // (b) it is a select or store with a non-constant index OR a call + match &term.op { + Op::Store if !term.cs[1].is_const() => true, + Op::Select if !term.cs[1].is_const() => true, + Op::Var(..) => true, + Op::Call(..) => true, _ => false, - }, + } + }; + if do_mark { + progress |= self.mark(term); } + + // Now, we mark the children of calls + if let Op::Call(..) = &term.op { + for c in &term.cs { + progress |= self.mark(c); + } + } + + // Finally, we propagate marks + progress |= self.propagate(term); + progress } } @@ -202,7 +159,7 @@ struct Replacer { } impl Replacer { - fn should_replace(&self, a: &Term) -> bool { + fn is_obliv(&self, a: &Term) -> bool { !self.not_obliv.contains(a) } } @@ -229,42 +186,29 @@ impl RewritePass for Replacer { .map(super::term_arr_val_to_tup) .collect() }; - debug!("visiting {}", orig); + trace!("rewriting {}", extras::Letified(orig.clone())); match &orig.op { - Op::Select => { - // we mark the selected term as non-obliv... - if self.should_replace(orig) { - let mut cs = get_cs(); - debug_assert_eq!(cs.len(), 2); - let k_const = get_const(&cs.pop().unwrap()); - Some(term(Op::Field(k_const), cs)) - } else { - None - } + Op::Select if self.is_obliv(&orig.cs[0]) => { + trace!(" is oblivious"); + let mut cs = get_cs(); + debug_assert_eq!(cs.len(), 2); + let k_const = get_const(&cs.pop().unwrap()); + Some(term(Op::Field(k_const), cs)) } - Op::Store => { - if self.should_replace(orig) { - let mut cs = get_cs(); - debug_assert_eq!(cs.len(), 3); - let k_const = get_const(&cs.remove(1)); - Some(term(Op::Update(k_const), cs)) - } else { - None - } + Op::Store if self.is_obliv(orig) => { + trace!(" is oblivious"); + let mut cs = get_cs(); + debug_assert_eq!(cs.len(), 3); + let k_const = get_const(&cs.remove(1)); + Some(term(Op::Update(k_const), cs)) } - Op::Ite => { - if self.should_replace(orig) { - Some(term(Op::Ite, get_cs())) - } else { - None - } + Op::Ite if self.is_obliv(orig) => { + trace!(" is oblivious"); + Some(term(Op::Ite, get_cs())) } - Op::Eq => { - if self.should_replace(&orig.cs[0]) { - Some(term(Op::Eq, get_cs())) - } else { - None - } + Op::Eq if self.is_obliv(&orig.cs[0]) => { + trace!(" is oblivious"); + Some(term(Op::Eq, get_cs())) } _ => None, } @@ -273,7 +217,10 @@ impl RewritePass for Replacer { /// Eliminate oblivious arrays. See module documentation. pub fn elim_obliv(t: &mut Computation) { - let mut prop_pass = NonOblivComputer::new(); + let mut prop_pass = NonOblivComputer::default(); + for o in t.outputs() { + prop_pass.mark(o); + } prop_pass.traverse(t); let mut replace_pass = Replacer { not_obliv: prop_pass.not_obliv, @@ -281,7 +228,9 @@ pub fn elim_obliv(t: &mut Computation) { let initial_output_sorts: Vec = t.outputs.iter().map(check).collect(); ::traverse(&mut replace_pass, t); for (o, s) in t.outputs.iter_mut().zip(initial_output_sorts) { - *o = super::resort(o, &s) + if check(o) != s { + *o = super::resort(o, &s) + } } } @@ -402,7 +351,9 @@ mod test { #[test] fn ignore_vars() { - let before = text::parse_computation(b" + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" (computation (metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ()) (bvadd @@ -410,8 +361,10 @@ mod test { (select (store (#a (bv 8) #x00 4 ()) #x00 a) #x00) (select (store (#a (bv 8) #x01 5 ()) #x00 a) #x01) ) - )"); - let expected = text::parse_computation(b" + )", + ); + let expected = text::parse_computation( + b" (computation (metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ()) (bvadd @@ -419,7 +372,8 @@ mod test { ((field 0) ((update 0) (#t #x00 #x00 #x00 #x00) a)) ((field 1) ((update 0) (#t #x01 #x01 #x01 #x01 #x01) a)) ) - )"); + )", + ); let mut after = before.clone(); elim_obliv(&mut after); assert_eq!(after, expected); @@ -427,35 +381,18 @@ mod test { #[test] fn preserve_output_type() { - let before = text::parse_computation(b" + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" (computation (metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ()) (store (store (#a (bv 8) #x00 3 ()) #x01 (select A #x00)) #x00 #x05 ) - )"); - let expected = text::parse_computation(b" - (computation - (metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ()) - (let ((tupout - ((update 0) - ((update 1) - (#t #x00 #x00 #x00) - (select A #x00) - ) - #x05 - ) - )) - (store - (store - (store (#a (bv 8) #x00 3 ()) #x00 ((field 0) tupout)) - #x01 ((field 1) tupout) - ) - #x02 ((field 2) tupout) - ) - ) - )"); + )", + ); + let expected = before.clone(); let mut after = before.clone(); elim_obliv(&mut after); assert_eq!(after, expected); @@ -463,25 +400,67 @@ mod test { #[test] fn identity_fn() { - let before = text::parse_computation(b" + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" (computation (metadata () ((A (array (bv 8) (bv 8) 3))) ()) A - )"); - let expected = text::parse_computation(b" - (computation - (metadata () ((A (array (bv 8) (bv 8) 3))) ()) - (store - (store - (store - (#a (bv 8) #x00 3 ( )) - #x00 (select A #x00)) - #x01 (select A #x01)) - #x02 (select A #x02)) - )"); + )", + ); + let expected = before.clone(); let mut after = before.clone(); elim_obliv(&mut after); assert_eq!(after, expected); } + #[test] + fn ignore_call_inputs() { + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" + (computation + (metadata () () ()) + ((call foo (a) ((array (bv 8) (bv 8) 4)) (bool)) (store (#a (bv 8) #x00 4 ()) #x00 #x01)) + )", + ); + let expected = before.clone(); + let mut after = before.clone(); + elim_obliv(&mut after); + assert_eq!(after, expected); + } + + #[test] + fn ignore_call_outputs() { + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" + (computation + (metadata () () ()) + (select + ((field 0) ((call foo () () ((array (bv 8) (bv 8) 4))) )) + #x00) + )", + ); + let expected = before.clone(); + let mut after = before.clone(); + elim_obliv(&mut after); + assert_eq!(after, expected); + } + + #[test] + fn ignore_outputs() { + let _ = env_logger::builder().is_test(true).try_init(); + let before = text::parse_computation( + b" + (computation + (metadata () () ()) + (store (#a (bv 8) #x00 4 ()) #x00 #x01) + )", + ); + let expected = before.clone(); + let mut after = before.clone(); + elim_obliv(&mut after); + assert_eq!(after, expected); + } } From 10ec8e942945881ff41a3a8a7b656caf561ec069 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Wed, 15 Jun 2022 22:22:18 -0700 Subject: [PATCH 3/3] fix warning --- src/ir/opt/mem/lin.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/opt/mem/lin.rs b/src/ir/opt/mem/lin.rs index e5724cb8..a31da629 100644 --- a/src/ir/opt/mem/lin.rs +++ b/src/ir/opt/mem/lin.rs @@ -26,7 +26,7 @@ struct Linearizer; impl RewritePass for Linearizer { fn visit Vec>( &mut self, - computation: &mut Computation, + _computation: &mut Computation, orig: &Term, rewritten_children: F, ) -> Option {