mirror of
https://github.com/circify/circ.git
synced 2026-05-14 03:00:33 -04:00
Merge branch 'function_calls' of github.com:circify/circ into function_calls
This commit is contained in:
@@ -64,8 +64,8 @@ impl FrontEnd for C {
|
||||
// generate new context
|
||||
g.circ = Circify::new(Ct::new());
|
||||
let call = g.function_queue.pop().unwrap();
|
||||
if let Op::Call(name, arg_names, arg_sorts, rets) = &call.op {
|
||||
g.fn_call(name, arg_names, arg_sorts, rets);
|
||||
if let Op::Call(name, arg_names, arg_sorts, ret_sorts) = &call.op {
|
||||
g.fn_call(name, arg_names, arg_sorts, ret_sorts);
|
||||
let comp = g.circ.consume().borrow().clone();
|
||||
|
||||
// println!("fn: {}", name);
|
||||
@@ -839,7 +839,7 @@ impl CGen {
|
||||
name.clone(),
|
||||
arg_names.clone(),
|
||||
arg_sorts.clone(),
|
||||
Sort::Tuple(ret_sorts.clone().into_boxed_slice()),
|
||||
ret_sorts.clone(),
|
||||
),
|
||||
arg_terms.clone().into_iter().flatten().collect::<Vec<_>>(),
|
||||
);
|
||||
@@ -1198,7 +1198,7 @@ impl CGen {
|
||||
name: &String,
|
||||
arg_names: &Vec<String>,
|
||||
arg_sorts: &Vec<Sort>,
|
||||
rets: &Sort,
|
||||
ret_sorts: &Vec<Sort>,
|
||||
) {
|
||||
debug!("Call: {}", name);
|
||||
println!("Call: {}", name);
|
||||
|
||||
@@ -1,167 +1,263 @@
|
||||
//! Inline function call terms
|
||||
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use fxhash::FxHashMap as HashMap;
|
||||
|
||||
use crate::ir::term::*;
|
||||
use crate::ir::opt::visit::RewritePass;
|
||||
|
||||
/// Inline cache.
|
||||
#[derive(Default)]
|
||||
pub struct Cache(TermMap<Term>);
|
||||
|
||||
impl Cache {
|
||||
/// Empty cache.
|
||||
pub fn new() -> Self {
|
||||
Cache(TermMap::new())
|
||||
}
|
||||
/// A recursive inliner.
|
||||
struct Inliner<'f> {
|
||||
/// Original source for functions
|
||||
fs: &'f Functions,
|
||||
/// Map from names to call-free computations
|
||||
cache: HashMap<String, Computation>,
|
||||
}
|
||||
|
||||
// TODO: this can fail if the function name contains '_'
|
||||
fn get_var_name(name: &String) -> String {
|
||||
let new_name = name.to_string().replace('.', "_");
|
||||
let n = new_name.split('_').collect::<Vec<&str>>();
|
||||
match n.len() {
|
||||
5 => n[3].to_string(),
|
||||
6.. => {
|
||||
let l = n.len() - 1;
|
||||
format!("{}_{}", n[l - 2], n[l])
|
||||
}
|
||||
_ => {
|
||||
panic!("Invalid variable name: {}", name);
|
||||
/// Compute the term that corresponds to a function call.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `arg_names`: the argument names, in order
|
||||
/// * `arg_values`: the argument values, in the same order
|
||||
/// * `callee`: the called function
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// A (tuple) term that corresponds to the output on those inputs
|
||||
///
|
||||
/// ## Note
|
||||
///
|
||||
/// This function **does not** recursively inline.
|
||||
fn inline_one(arg_names: &Vec<String>, arg_values: Vec<Term>, callee: &Computation) -> Term {
|
||||
let mut sub_map: TermMap<Term> = arg_names
|
||||
.into_iter()
|
||||
.zip(arg_values)
|
||||
.map(|(n, v)| {
|
||||
let s = callee.metadata.input_sort(n).clone();
|
||||
(leaf_term(Op::Var(n.clone(), s)), v)
|
||||
})
|
||||
.collect();
|
||||
term(
|
||||
Op::Tuple,
|
||||
callee
|
||||
.outputs()
|
||||
.iter()
|
||||
.map(|o| extras::substitute_cache(o, &mut sub_map))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
impl<'f> Inliner<'f> {
|
||||
/// Ensure that a totally inlined version of `name` is in the cache.
|
||||
fn inline_all(&mut self, name: &str) {
|
||||
if !self.cache.contains_key(name) {
|
||||
let mut c = self.fs.get_comp(name).unwrap().clone();
|
||||
for t in c.terms_postorder() {
|
||||
if let Op::Call(callee_name, ..) = &t.op {
|
||||
self.inline_all(callee_name);
|
||||
}
|
||||
}
|
||||
self.traverse(&mut c);
|
||||
let present = self.cache.insert(name.into(), c);
|
||||
assert!(present.is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn match_arg(name: &String, params: &BTreeMap<String, Term>) -> Term {
|
||||
let new_name = get_var_name(name);
|
||||
params.get(&new_name).unwrap().clone()
|
||||
}
|
||||
|
||||
fn link(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Term> {
|
||||
let mut res: Vec<Term> = Vec::new();
|
||||
let comp = fs.computations.get(name).unwrap();
|
||||
for o in comp.outputs.iter() {
|
||||
let mut cache = TermMap::new();
|
||||
for t in PostOrderIter::new(o.clone()) {
|
||||
match &t.op {
|
||||
Op::Var(name, _) => {
|
||||
let ret = match_arg(name, ¶ms);
|
||||
cache.insert(t.clone(), ret.clone());
|
||||
}
|
||||
_ => {
|
||||
let mut children = Vec::new();
|
||||
for c in &t.cs {
|
||||
if let Some(rewritten_c) = cache.get(c) {
|
||||
children.push(rewritten_c.clone());
|
||||
} else {
|
||||
children.push(c.clone());
|
||||
}
|
||||
}
|
||||
cache.insert(t.clone(), term(t.op.clone(), children));
|
||||
}
|
||||
}
|
||||
/// Rewrites a term, inlining function calls along the way.
|
||||
///
|
||||
/// Assumes that the callees are already inlined. Panics otherwise.
|
||||
impl<'f> RewritePass for Inliner<'f> {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
_computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
if let Op::Call(fn_name, arg_names, _, _) = &orig.op {
|
||||
let callee = self.cache.get(fn_name).expect("missing inlined callee");
|
||||
let term = inline_one(arg_names, rewritten_children(), callee);
|
||||
Some(term)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
res.push(cache.get(o).unwrap().clone());
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
/// Traverse terms and link function calls
|
||||
pub fn link_function_calls(
|
||||
term_: Term,
|
||||
Cache(ref mut rewritten): &mut Cache,
|
||||
fs: &Functions,
|
||||
) -> Term {
|
||||
let mut call_cache: HashMap<Term, Vec<Term>> = HashMap::new();
|
||||
for t in PostOrderIter::new(term_.clone()) {
|
||||
let mut children = Vec::new();
|
||||
for c in &t.cs {
|
||||
if let Some(rewritten_c) = rewritten.get(c) {
|
||||
children.push(rewritten_c.clone());
|
||||
} else {
|
||||
children.push(c.clone());
|
||||
}
|
||||
}
|
||||
let entry = match &t.op {
|
||||
Op::Field(index) => {
|
||||
assert!(t.cs.len() > 0);
|
||||
if let Op::Call(..) = &t.cs[0].op {
|
||||
if call_cache.contains_key(&t.cs[0]) {
|
||||
call_cache.get(&t.cs[0]).unwrap()[*index].clone()
|
||||
} else {
|
||||
panic!("Fields on a Call term should return");
|
||||
}
|
||||
} else {
|
||||
term(t.op.clone(), children)
|
||||
}
|
||||
}
|
||||
Op::Call(name, arg_names, arg_sorts, _) => {
|
||||
println!("Inlining: {}", name);
|
||||
|
||||
// Check number of args
|
||||
let num_args = arg_sorts.iter().fold(0, |sum, x| {
|
||||
sum + match x {
|
||||
Sort::Array(_, _, l) => *l,
|
||||
_ => 1,
|
||||
}
|
||||
});
|
||||
assert!(
|
||||
num_args == t.cs.len(),
|
||||
"Number of arguments mismatch. {}, {}",
|
||||
num_args,
|
||||
t.cs.len()
|
||||
);
|
||||
|
||||
// Check arg types
|
||||
let arg_types = arg_sorts
|
||||
.iter()
|
||||
.map(|x| match &x {
|
||||
Sort::Array(_, val_sort, l) => {
|
||||
let mut res: Vec<Sort> = Vec::new();
|
||||
for _ in 0..*l {
|
||||
res.push(*val_sort.clone());
|
||||
}
|
||||
res
|
||||
}
|
||||
_ => vec![x.clone()],
|
||||
})
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert!(
|
||||
arg_types == t.cs.iter().map(|c| check(c)).collect::<Vec<Sort>>(),
|
||||
"Argument type mismatch"
|
||||
);
|
||||
|
||||
let mut params: BTreeMap<String, Term> = BTreeMap::new();
|
||||
let arg_keys = arg_names
|
||||
.iter()
|
||||
.zip(arg_sorts.iter())
|
||||
.map(|(n, x)| match &x {
|
||||
Sort::Array(_, _, l) => {
|
||||
let mut res: Vec<String> = Vec::new();
|
||||
for i in 0..*l {
|
||||
res.push(format!("{}_{}", n.clone(), i));
|
||||
}
|
||||
res
|
||||
}
|
||||
_ => vec![n.clone()],
|
||||
})
|
||||
.flatten();
|
||||
for (n, c) in arg_keys.zip(t.cs.clone()) {
|
||||
params.insert(n.clone(), c.clone());
|
||||
}
|
||||
let res = link(name, params, fs);
|
||||
call_cache.insert(t.clone(), res.clone());
|
||||
res[0].clone()
|
||||
}
|
||||
_ => term(t.op.clone(), children),
|
||||
};
|
||||
rewritten.insert(t.clone(), entry);
|
||||
}
|
||||
|
||||
if let Some(t) = rewritten.get(&term_) {
|
||||
t.clone()
|
||||
} else {
|
||||
panic!("Couldn't find rewritten binarized term: {}", term_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Inline all calls within a function set.
|
||||
pub fn link_all_function_calls(fs: &mut Functions) {
|
||||
let mut inliner = Inliner {
|
||||
fs,
|
||||
cache: Default::default(),
|
||||
};
|
||||
for name in fs.computations.keys() {
|
||||
inliner.inline_all(name);
|
||||
}
|
||||
*fs = Functions {
|
||||
computations: inliner.cache.into_iter().collect(),
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bool_arg_nonrec() {
|
||||
let mut fs = text::parse_functions(
|
||||
b"
|
||||
(functions
|
||||
(
|
||||
(myxor
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(xor a b false false)
|
||||
)
|
||||
)
|
||||
(main
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b )))
|
||||
)
|
||||
)
|
||||
)
|
||||
)",
|
||||
);
|
||||
let expected = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(and false ((field 0) (tuple (xor a b false false))))
|
||||
)
|
||||
",
|
||||
);
|
||||
link_all_function_calls(&mut fs);
|
||||
let c = fs.get_comp("main").unwrap().clone();
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scalar_arg_nonrec() {
|
||||
let mut fs = text::parse_functions(
|
||||
b"
|
||||
(functions
|
||||
(
|
||||
(myxor
|
||||
(computation
|
||||
(metadata () ((a bool) (b (bv 4))) ())
|
||||
(bvxor (ite a #x0 #x1) b)
|
||||
)
|
||||
)
|
||||
(main
|
||||
(computation
|
||||
(metadata () ((c bool)) ())
|
||||
(bvand #xf ((field 0) ( (call myxor (a b) (bool (bv 4)) ((bv 4))) c #x4 )))
|
||||
)
|
||||
)
|
||||
)
|
||||
)",
|
||||
);
|
||||
|
||||
let expected = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((c bool)) ())
|
||||
(bvand #xf ((field 0) (tuple (bvxor (ite c #x0 #x1) #x4))))
|
||||
)
|
||||
",
|
||||
);
|
||||
link_all_function_calls(&mut fs);
|
||||
let c = fs.get_comp("main").unwrap().clone();
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_calls() {
|
||||
let mut fs = text::parse_functions(
|
||||
b"
|
||||
(functions
|
||||
(
|
||||
(foo
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(not a)
|
||||
)
|
||||
)
|
||||
(bar
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(xor ((field 0) ((call foo (a) (bool) (bool)) a)) true)
|
||||
)
|
||||
)
|
||||
(main
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
((field 0) ((call bar (a) (bool) (bool)) a))
|
||||
)
|
||||
)
|
||||
)
|
||||
)",
|
||||
);
|
||||
|
||||
let expected = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
((field 0) (tuple (xor ((field 0) (tuple (not a))) true)))
|
||||
)
|
||||
",
|
||||
);
|
||||
link_all_function_calls(&mut fs);
|
||||
let c = fs.get_comp("main").unwrap().clone();
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_calls() {
|
||||
let mut fs = text::parse_functions(
|
||||
b"
|
||||
(functions
|
||||
(
|
||||
(foo
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(not a)
|
||||
)
|
||||
)
|
||||
(bar
|
||||
(computation
|
||||
(metadata () ((a bool)) ())
|
||||
(xor ((field 0) ((call foo (a) (bool) (bool)) a)) true)
|
||||
)
|
||||
)
|
||||
(main
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(and
|
||||
((field 0) ((call foo (a) (bool) (bool)) a))
|
||||
((field 0) ((call foo (a) (bool) (bool)) b))
|
||||
((field 0) ((call bar (a) (bool) (bool)) a))
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)",
|
||||
);
|
||||
|
||||
let expected = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(and
|
||||
((field 0) (tuple (not a)))
|
||||
((field 0) (tuple (not b)))
|
||||
((field 0) (tuple (xor ((field 0) (tuple (not a))) true)))
|
||||
)
|
||||
)
|
||||
",
|
||||
);
|
||||
link_all_function_calls(&mut fs);
|
||||
let c = fs.get_comp("main").unwrap().clone();
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ struct Linearizer;
|
||||
impl RewritePass for Linearizer {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
_computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
match &orig.op {
|
||||
Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(super::arr_val_to_tup(v)))),
|
||||
Op::Var(_name, s) if super::sort_contains_array(s) => Some(super::array_to_tuple(orig)),
|
||||
Op::Call(_name, _arg_names, arg_sorts, ret_sort) => {
|
||||
Op::Call(_name, _arg_names, arg_sorts, ret_sorts) => {
|
||||
let mut args = rewritten_children();
|
||||
for (a, s) in args.iter_mut().zip(arg_sorts) {
|
||||
if super::sort_contains_array(s) {
|
||||
@@ -41,7 +41,7 @@ impl RewritePass for Linearizer {
|
||||
}
|
||||
}
|
||||
let out = term(orig.op.clone(), args);
|
||||
Some(if super::sort_contains_array(ret_sort) {
|
||||
Some(if ret_sorts.iter().any(super::sort_contains_array) {
|
||||
super::array_to_tuple(&out)
|
||||
} else {
|
||||
out
|
||||
|
||||
@@ -59,6 +59,7 @@ fn sort_contains_array(s: &Sort) -> bool {
|
||||
|
||||
/// Given a sort `s` which may contain array constructors, construct a new sort in which the arrays
|
||||
/// have been flattened to tuples.
|
||||
#[allow(dead_code)]
|
||||
fn array_to_tuple_sort(s: &Sort) -> Sort {
|
||||
match s {
|
||||
Sort::Tuple(ss) => Sort::Tuple(ss.iter().map(array_to_tuple_sort).collect()),
|
||||
@@ -70,7 +71,7 @@ fn array_to_tuple_sort(s: &Sort) -> Sort {
|
||||
/// Given a term of tuples, re-shape into sort `s`.
|
||||
fn resort(t: &Term, s: &Sort) -> Term {
|
||||
match s {
|
||||
Sort::Array(k, v, sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold(
|
||||
Sort::Array(k, v, _sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold(
|
||||
s.default_term(),
|
||||
|acc, (t, i)| term![Op::Store; acc, i, resort(&t, v)],
|
||||
),
|
||||
|
||||
@@ -11,24 +11,22 @@
|
||||
//!
|
||||
//! ## Pass 1: Identifying oblivious arrays
|
||||
//!
|
||||
//!
|
||||
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation
|
||||
//! system, performing the following inferences:
|
||||
//! system, marking som arrays as non-oblivious. Only non-constant terms with a sort that contains
|
||||
//! an array can be non-oblivious.
|
||||
//!
|
||||
//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious;
|
||||
//! * If `a[i]`, `a` and `a[i]` are equi-oblivious
|
||||
//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious
|
||||
//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious
|
||||
//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious
|
||||
//! * If `a=b`, then `a` and `b` are equi-oblivious
|
||||
//! * If any other term is array-valued, that array is non-oblivious. Including:
|
||||
//! * variables
|
||||
//! * function call outputs
|
||||
//! * tuple fields
|
||||
//!
|
||||
//! This procedure is iterated to fixpoint.
|
||||
//! The following are *directly marked* as non-oblivious:
|
||||
//!
|
||||
//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select
|
||||
//! terms as non-oblivious. This makes it easy to see which selects should be replaced later.
|
||||
//! * Selects with non-constant indices
|
||||
//! * Stores with non-constant indices
|
||||
//! * Any term containing an array sort that is not a array of primitives
|
||||
//! * Variables, call arguments, call outputs, and the computation outputs
|
||||
//!
|
||||
//! All terms propagate non-obliviousness to their children and recieve it from their children.
|
||||
//!
|
||||
//! Propagation is repeated to fixpoint.
|
||||
//!
|
||||
//! ### Sharing & Constant Arrays
|
||||
//!
|
||||
@@ -69,130 +67,89 @@
|
||||
//! * map array terms to tuple terms
|
||||
//! * map array selections to tuple field gets
|
||||
//!
|
||||
//! In both cases we look at the non-oblivious array/select set to see whether to do the
|
||||
//! replacement.
|
||||
//!
|
||||
//! In both cases we look at the non-oblivious array set to see whether to do the replacement.
|
||||
|
||||
use super::super::visit::*;
|
||||
use crate::ir::term::extras::as_uint_constant;
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::{debug, trace};
|
||||
use log::trace;
|
||||
|
||||
fn is_prim_array(t: &Sort) -> bool {
|
||||
if let Sort::Array(k, v, _) = t {
|
||||
k.is_scalar() && v.is_scalar()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_markable(t: &Term, s: &Sort) -> bool {
|
||||
!t.is_const() && !s.is_scalar() && super::sort_contains_array(s)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct NonOblivComputer {
|
||||
not_obliv: TermSet,
|
||||
}
|
||||
|
||||
impl NonOblivComputer {
|
||||
fn mark(&mut self, a: &Term) -> bool {
|
||||
if !a.is_const() && self.not_obliv.insert(a.clone()) {
|
||||
trace!("Not obliv: {}", a);
|
||||
let s = check(a);
|
||||
if is_markable(a, &s) && self.not_obliv.insert(a.clone()) {
|
||||
trace!("Not obliv: {}", extras::Letified(a.clone()));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool {
|
||||
if !a.is_const() && !b.is_const() {
|
||||
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
|
||||
(false, true) => {
|
||||
self.not_obliv.insert(a.clone());
|
||||
true
|
||||
}
|
||||
(true, false) => {
|
||||
self.not_obliv.insert(b.clone());
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
fn propagate(&mut self, a: &Term) -> bool {
|
||||
if self.not_obliv.contains(a) || a.cs.iter().any(|c| self.not_obliv.contains(c)) {
|
||||
let mut progress = false;
|
||||
progress |= self.mark(a);
|
||||
for c in &a.cs {
|
||||
progress |= self.mark(c);
|
||||
}
|
||||
progress
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
not_obliv: TermSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProgressAnalysisPass for NonOblivComputer {
|
||||
fn visit(&mut self, term: &Term) -> bool {
|
||||
match &term.op {
|
||||
Op::Store => {
|
||||
let a = &term.cs[0];
|
||||
let i = &term.cs[1];
|
||||
let v = &term.cs[2];
|
||||
let mut progress = false;
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
if i.is_const() {
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
} else {
|
||||
progress = self.mark(a) || progress;
|
||||
progress = self.mark(term) || progress;
|
||||
}
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Select => {
|
||||
// Even though the selected value may not have array sort, we still flag it as
|
||||
// non-oblivious so we know whether to replace it or not.
|
||||
let a = &term.cs[0];
|
||||
let i = &term.cs[1];
|
||||
let mut progress = false;
|
||||
if i.is_const() {
|
||||
// pass
|
||||
} else {
|
||||
progress = self.mark(a) || progress;
|
||||
progress = self.mark(term) || progress;
|
||||
}
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
progress
|
||||
}
|
||||
Op::Ite => {
|
||||
let t = &term.cs[1];
|
||||
let f = &term.cs[2];
|
||||
if let Sort::Array(..) = check(t) {
|
||||
let mut progress = self.bi_implicate(term, t);
|
||||
progress = self.bi_implicate(t, f) || progress;
|
||||
progress = self.bi_implicate(term, f) || progress;
|
||||
progress
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Op::Eq => {
|
||||
let a = &term.cs[0];
|
||||
let b = &term.cs[1];
|
||||
if let Sort::Array(..) = check(a) {
|
||||
self.bi_implicate(a, b)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
// constants are oblivious
|
||||
Op::Const(..) => false,
|
||||
_ => match check(term) {
|
||||
Sort::Array(..) => {
|
||||
// variables, fields, and function outputs are non-oblivious.
|
||||
debug_assert!(
|
||||
matches!(term.op, Op::Call(..) | Op::Var(..) | Op::Field(..)),
|
||||
"Unexpected array term {}",
|
||||
term
|
||||
);
|
||||
self.mark(term)
|
||||
}
|
||||
let sort = check(term);
|
||||
let mut progress = false;
|
||||
|
||||
// First, we directly mark this term if
|
||||
// (a) its sort contains an array and is not an array of primitives.
|
||||
let do_mark = if !sort.is_scalar() && !is_prim_array(&sort) {
|
||||
true
|
||||
} else {
|
||||
// (b) it is a select or store with a non-constant index OR a call
|
||||
match &term.op {
|
||||
Op::Store if !term.cs[1].is_const() => true,
|
||||
Op::Select if !term.cs[1].is_const() => true,
|
||||
Op::Var(..) => true,
|
||||
Op::Call(..) => true,
|
||||
_ => false,
|
||||
},
|
||||
}
|
||||
};
|
||||
if do_mark {
|
||||
progress |= self.mark(term);
|
||||
}
|
||||
|
||||
// Now, we mark the children of calls
|
||||
if let Op::Call(..) = &term.op {
|
||||
for c in &term.cs {
|
||||
progress |= self.mark(c);
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we propagate marks
|
||||
progress |= self.propagate(term);
|
||||
progress
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,7 +159,7 @@ struct Replacer {
|
||||
}
|
||||
|
||||
impl Replacer {
|
||||
fn should_replace(&self, a: &Term) -> bool {
|
||||
fn is_obliv(&self, a: &Term) -> bool {
|
||||
!self.not_obliv.contains(a)
|
||||
}
|
||||
}
|
||||
@@ -229,42 +186,29 @@ impl RewritePass for Replacer {
|
||||
.map(super::term_arr_val_to_tup)
|
||||
.collect()
|
||||
};
|
||||
debug!("visiting {}", orig);
|
||||
trace!("rewriting {}", extras::Letified(orig.clone()));
|
||||
match &orig.op {
|
||||
Op::Select => {
|
||||
// we mark the selected term as non-obliv...
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 2);
|
||||
let k_const = get_const(&cs.pop().unwrap());
|
||||
Some(term(Op::Field(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Op::Select if self.is_obliv(&orig.cs[0]) => {
|
||||
trace!(" is oblivious");
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 2);
|
||||
let k_const = get_const(&cs.pop().unwrap());
|
||||
Some(term(Op::Field(k_const), cs))
|
||||
}
|
||||
Op::Store => {
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
Some(term(Op::Update(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Op::Store if self.is_obliv(orig) => {
|
||||
trace!(" is oblivious");
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
Some(term(Op::Update(k_const), cs))
|
||||
}
|
||||
Op::Ite => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Op::Ite if self.is_obliv(orig) => {
|
||||
trace!(" is oblivious");
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
}
|
||||
Op::Eq => {
|
||||
if self.should_replace(&orig.cs[0]) {
|
||||
Some(term(Op::Eq, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Op::Eq if self.is_obliv(&orig.cs[0]) => {
|
||||
trace!(" is oblivious");
|
||||
Some(term(Op::Eq, get_cs()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
@@ -273,7 +217,10 @@ impl RewritePass for Replacer {
|
||||
|
||||
/// Eliminate oblivious arrays. See module documentation.
|
||||
pub fn elim_obliv(t: &mut Computation) {
|
||||
let mut prop_pass = NonOblivComputer::new();
|
||||
let mut prop_pass = NonOblivComputer::default();
|
||||
for o in t.outputs() {
|
||||
prop_pass.mark(o);
|
||||
}
|
||||
prop_pass.traverse(t);
|
||||
let mut replace_pass = Replacer {
|
||||
not_obliv: prop_pass.not_obliv,
|
||||
@@ -281,7 +228,9 @@ pub fn elim_obliv(t: &mut Computation) {
|
||||
let initial_output_sorts: Vec<Sort> = t.outputs.iter().map(check).collect();
|
||||
<Replacer as RewritePass>::traverse(&mut replace_pass, t);
|
||||
for (o, s) in t.outputs.iter_mut().zip(initial_output_sorts) {
|
||||
*o = super::resort(o, &s)
|
||||
if check(o) != s {
|
||||
*o = super::resort(o, &s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,7 +351,9 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn ignore_vars() {
|
||||
let before = text::parse_computation(b"
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
|
||||
(bvadd
|
||||
@@ -410,8 +361,10 @@ mod test {
|
||||
(select (store (#a (bv 8) #x00 4 ()) #x00 a) #x00)
|
||||
(select (store (#a (bv 8) #x01 5 ()) #x00 a) #x01)
|
||||
)
|
||||
)");
|
||||
let expected = text::parse_computation(b"
|
||||
)",
|
||||
);
|
||||
let expected = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
|
||||
(bvadd
|
||||
@@ -419,7 +372,8 @@ mod test {
|
||||
((field 0) ((update 0) (#t #x00 #x00 #x00 #x00) a))
|
||||
((field 1) ((update 0) (#t #x01 #x01 #x01 #x01 #x01) a))
|
||||
)
|
||||
)");
|
||||
)",
|
||||
);
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
@@ -427,35 +381,18 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn preserve_output_type() {
|
||||
let before = text::parse_computation(b"
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
|
||||
(store
|
||||
(store (#a (bv 8) #x00 3 ()) #x01 (select A #x00))
|
||||
#x00 #x05
|
||||
)
|
||||
)");
|
||||
let expected = text::parse_computation(b"
|
||||
(computation
|
||||
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
|
||||
(let ((tupout
|
||||
((update 0)
|
||||
((update 1)
|
||||
(#t #x00 #x00 #x00)
|
||||
(select A #x00)
|
||||
)
|
||||
#x05
|
||||
)
|
||||
))
|
||||
(store
|
||||
(store
|
||||
(store (#a (bv 8) #x00 3 ()) #x00 ((field 0) tupout))
|
||||
#x01 ((field 1) tupout)
|
||||
)
|
||||
#x02 ((field 2) tupout)
|
||||
)
|
||||
)
|
||||
)");
|
||||
)",
|
||||
);
|
||||
let expected = before.clone();
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
@@ -463,25 +400,67 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn identity_fn() {
|
||||
let before = text::parse_computation(b"
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () ((A (array (bv 8) (bv 8) 3))) ())
|
||||
A
|
||||
)");
|
||||
let expected = text::parse_computation(b"
|
||||
(computation
|
||||
(metadata () ((A (array (bv 8) (bv 8) 3))) ())
|
||||
(store
|
||||
(store
|
||||
(store
|
||||
(#a (bv 8) #x00 3 ( ))
|
||||
#x00 (select A #x00))
|
||||
#x01 (select A #x01))
|
||||
#x02 (select A #x02))
|
||||
)");
|
||||
)",
|
||||
);
|
||||
let expected = before.clone();
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignore_call_inputs() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () () ())
|
||||
((call foo (a) ((array (bv 8) (bv 8) 4)) (bool)) (store (#a (bv 8) #x00 4 ()) #x00 #x01))
|
||||
)",
|
||||
);
|
||||
let expected = before.clone();
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignore_call_outputs() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () () ())
|
||||
(select
|
||||
((field 0) ((call foo () () ((array (bv 8) (bv 8) 4))) ))
|
||||
#x00)
|
||||
)",
|
||||
);
|
||||
let expected = before.clone();
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignore_outputs() {
|
||||
let _ = env_logger::builder().is_test(true).try_init();
|
||||
let before = text::parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata () () ())
|
||||
(store (#a (bv 8) #x00 4 ()) #x00 #x01)
|
||||
)",
|
||||
);
|
||||
let expected = before.clone();
|
||||
let mut after = before.clone();
|
||||
elim_obliv(&mut after);
|
||||
assert_eq!(after, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ pub mod sha;
|
||||
pub mod tuple;
|
||||
mod visit;
|
||||
|
||||
use std::{collections::HashMap, time::Instant};
|
||||
use std::time::Instant;
|
||||
|
||||
use super::term::*;
|
||||
|
||||
@@ -48,6 +48,10 @@ pub enum Opt {
|
||||
pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) -> Functions {
|
||||
for i in optimizations {
|
||||
let mut opt_fs: Functions = fs.clone();
|
||||
if let Opt::InlineCalls = i {
|
||||
link::link_all_function_calls(&mut opt_fs);
|
||||
continue
|
||||
}
|
||||
for (name, comp) in fs.computations.iter_mut() {
|
||||
debug!("Applying: {:?} to {}", i, name);
|
||||
let now = Instant::now();
|
||||
@@ -112,12 +116,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) ->
|
||||
.collect();
|
||||
inline::inline(&mut comp.outputs, &public_inputs);
|
||||
}
|
||||
Opt::InlineCalls => {
|
||||
let mut cache = link::Cache::new();
|
||||
for a in &mut comp.outputs {
|
||||
*a = link::link_function_calls(a.clone(), &mut cache, &opt_fs);
|
||||
}
|
||||
}
|
||||
Opt::InlineCalls => unreachable!(),
|
||||
Opt::Tuple => {
|
||||
tuple::eliminate_tuples(comp);
|
||||
}
|
||||
|
||||
@@ -141,8 +141,10 @@ pub enum Op {
|
||||
/// Map (operation)
|
||||
Map(Box<Op>),
|
||||
|
||||
/// Call a function (name, argument names, argument sorts, return sort)
|
||||
Call(String, Vec<String>, Vec<Sort>, Sort),
|
||||
/// Call a function (name, argument names, argument sorts, return sorts)
|
||||
///
|
||||
/// Note that the type of this term is always a tuple.
|
||||
Call(String, Vec<String>, Vec<Sort>, Vec<Sort>),
|
||||
}
|
||||
|
||||
/// Boolean AND
|
||||
@@ -298,7 +300,7 @@ impl Display for Op {
|
||||
Op::Field(i) => write!(f, "(field {})", i),
|
||||
Op::Update(i) => write!(f, "(update {})", i),
|
||||
Op::Map(op) => write!(f, "(map({}))", op),
|
||||
Op::Call(name, arg_names, arg_sorts, sort) => {
|
||||
Op::Call(name, arg_names, arg_sorts, ret_sorts) => {
|
||||
write!(f, "(call {} (", name)?;
|
||||
for arg_name in arg_names {
|
||||
write!(f, " {}", arg_name)?;
|
||||
@@ -307,7 +309,11 @@ impl Display for Op {
|
||||
for arg_sort in arg_sorts {
|
||||
write!(f, " {}", arg_sort)?;
|
||||
}
|
||||
write!(f, ") {})", sort)
|
||||
write!(f, ") (")?;
|
||||
for ret_sort in ret_sorts {
|
||||
write!(f, " {}", ret_sort)?;
|
||||
}
|
||||
write!(f, "))")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2020,8 +2026,8 @@ impl Functions {
|
||||
}
|
||||
|
||||
/// Get the first computation by function name
|
||||
pub fn get_comp(&self, name: String) -> Option<&Computation> {
|
||||
self.computations.get(&name)
|
||||
pub fn get_comp(&self, name: &str) -> Option<&Computation> {
|
||||
self.computations.get(name)
|
||||
}
|
||||
|
||||
/// Create a computation with a single entry function
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
//! * Operator `O`:
|
||||
//! * Plain operators: (`bvmul`, `and`, ...)
|
||||
//! * Composite operators: `(field N)`, `(update N)`, `(sext N)`, `(uext N)`, `(bit N)`, ...
|
||||
//! * call operator: `(call X (X1 ... XN) (S1 ... SN) S)`
|
||||
//! * call operator: `(call X (X1 ... XN) (S1 ... SN) (RS1 ... RSN))`
|
||||
|
||||
use circ_fields::{FieldT, FieldV};
|
||||
|
||||
@@ -275,12 +275,12 @@ impl<'src> IrInterp<'src> {
|
||||
[Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))),
|
||||
[Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))),
|
||||
[Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))),
|
||||
[Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, sort] => {
|
||||
[Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, ret_sorts] => {
|
||||
let name = from_utf8(name).unwrap().to_owned();
|
||||
let arg_names = self.string_list(arg_names);
|
||||
let arg_sorts = self.sort_list(arg_sorts);
|
||||
let sort = self.sort(sort);
|
||||
Ok(Op::Call(name, arg_names, arg_sorts, sort))
|
||||
let ret_sorts = self.sort_list(ret_sorts);
|
||||
Ok(Op::Call(name, arg_names, arg_sorts, ret_sorts))
|
||||
}
|
||||
_ => todo!("Unparsed op: {}", tt),
|
||||
},
|
||||
@@ -777,7 +777,7 @@ mod test {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare ((a bool))
|
||||
( (call myxor (a b) (bool bool) bool) a a )
|
||||
((field 0) ( (call myxor (a b) (bool bool) (bool)) a a ))
|
||||
)",
|
||||
);
|
||||
assert_eq!(check(&t), Sort::Bool);
|
||||
@@ -918,7 +918,7 @@ mod test {
|
||||
(main
|
||||
(computation
|
||||
(metadata () ((a bool) (b bool)) ())
|
||||
(and false ( (call myxor (a b) (bool bool) bool) a b ))
|
||||
(and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b )))
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::Call(_, _, _, ret) => Ok(ret.clone()),
|
||||
Op::Call(_, _, _, rets) => Ok(Sort::Tuple(rets.clone().into())),
|
||||
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
|
||||
}
|
||||
}
|
||||
@@ -392,14 +392,14 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
|
||||
rec_check_raw_helper(&(*op.clone()), &new_a[..])
|
||||
.map(|val_sort| Sort::Array(Box::new(key_sort), Box::new(val_sort), size))
|
||||
}
|
||||
(Op::Call(_, _, ex_args, ret), act_args) => {
|
||||
(Op::Call(_, _, ex_args, rets), act_args) => {
|
||||
if ex_args.len() != act_args.len() {
|
||||
Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len()))
|
||||
} else {
|
||||
for (e, a) in ex_args.iter().zip(act_args) {
|
||||
eq_or(e, a, "in function call")?;
|
||||
}
|
||||
Ok(ret.clone())
|
||||
Ok(Sort::Tuple(rets.clone().into()))
|
||||
}
|
||||
}
|
||||
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
|
||||
|
||||
@@ -553,11 +553,11 @@ impl<'a> ToABY<'a> {
|
||||
self.term_to_shares.insert(t.clone(), vec![shares[*i]]);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Array);
|
||||
}
|
||||
Op::Call(name, arg_names, _arg_sorts, ret) => {
|
||||
Op::Call(name, arg_names, _arg_sorts, ret_sorts) => {
|
||||
let shares = self.get_shares(&t);
|
||||
let op = format!("CALL({})", name);
|
||||
let num_args = arg_names.len();
|
||||
let num_rets = self.get_sort_len(ret);
|
||||
let num_rets: usize = ret_sorts.iter().map(|ret| self.get_sort_len(ret)).sum();
|
||||
|
||||
// map argument shares
|
||||
let mut arg_shares: Vec<i32> = Vec::new();
|
||||
|
||||
Reference in New Issue
Block a user