Merge branch 'function_calls' of github.com:circify/circ into function_calls

This commit is contained in:
Edward Chen
2022-06-19 16:40:27 -04:00
10 changed files with 450 additions and 369 deletions

View File

@@ -64,8 +64,8 @@ impl FrontEnd for C {
// generate new context
g.circ = Circify::new(Ct::new());
let call = g.function_queue.pop().unwrap();
if let Op::Call(name, arg_names, arg_sorts, rets) = &call.op {
g.fn_call(name, arg_names, arg_sorts, rets);
if let Op::Call(name, arg_names, arg_sorts, ret_sorts) = &call.op {
g.fn_call(name, arg_names, arg_sorts, ret_sorts);
let comp = g.circ.consume().borrow().clone();
// println!("fn: {}", name);
@@ -839,7 +839,7 @@ impl CGen {
name.clone(),
arg_names.clone(),
arg_sorts.clone(),
Sort::Tuple(ret_sorts.clone().into_boxed_slice()),
ret_sorts.clone(),
),
arg_terms.clone().into_iter().flatten().collect::<Vec<_>>(),
);
@@ -1198,7 +1198,7 @@ impl CGen {
name: &String,
arg_names: &Vec<String>,
arg_sorts: &Vec<Sort>,
rets: &Sort,
ret_sorts: &Vec<Sort>,
) {
debug!("Call: {}", name);
println!("Call: {}", name);

View File

@@ -1,167 +1,263 @@
//! Inline function call terms
use std::collections::{BTreeMap, HashMap};
use fxhash::FxHashMap as HashMap;
use crate::ir::term::*;
use crate::ir::opt::visit::RewritePass;
/// Inline cache.
#[derive(Default)]
pub struct Cache(TermMap<Term>);
impl Cache {
/// Empty cache.
pub fn new() -> Self {
Cache(TermMap::new())
}
/// A recursive inliner.
struct Inliner<'f> {
/// Original source for functions
fs: &'f Functions,
/// Map from names to call-free computations
cache: HashMap<String, Computation>,
}
// TODO: this can fail if the function name contains '_'
fn get_var_name(name: &String) -> String {
let new_name = name.to_string().replace('.', "_");
let n = new_name.split('_').collect::<Vec<&str>>();
match n.len() {
5 => n[3].to_string(),
6.. => {
let l = n.len() - 1;
format!("{}_{}", n[l - 2], n[l])
}
_ => {
panic!("Invalid variable name: {}", name);
/// Compute the term that corresponds to a function call.
///
/// ## Arguments
///
/// * `arg_names`: the argument names, in order
/// * `arg_values`: the argument values, in the same order
/// * `callee`: the called function
///
/// ## Returns
///
/// A (tuple) term that corresponds to the output on those inputs
///
/// ## Note
///
/// This function **does not** recursively inline.
fn inline_one(arg_names: &Vec<String>, arg_values: Vec<Term>, callee: &Computation) -> Term {
let mut sub_map: TermMap<Term> = arg_names
.into_iter()
.zip(arg_values)
.map(|(n, v)| {
let s = callee.metadata.input_sort(n).clone();
(leaf_term(Op::Var(n.clone(), s)), v)
})
.collect();
term(
Op::Tuple,
callee
.outputs()
.iter()
.map(|o| extras::substitute_cache(o, &mut sub_map))
.collect(),
)
}
impl<'f> Inliner<'f> {
/// Ensure that a totally inlined version of `name` is in the cache.
fn inline_all(&mut self, name: &str) {
if !self.cache.contains_key(name) {
let mut c = self.fs.get_comp(name).unwrap().clone();
for t in c.terms_postorder() {
if let Op::Call(callee_name, ..) = &t.op {
self.inline_all(callee_name);
}
}
self.traverse(&mut c);
let present = self.cache.insert(name.into(), c);
assert!(present.is_none());
}
}
}
fn match_arg(name: &String, params: &BTreeMap<String, Term>) -> Term {
let new_name = get_var_name(name);
params.get(&new_name).unwrap().clone()
}
fn link(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Term> {
let mut res: Vec<Term> = Vec::new();
let comp = fs.computations.get(name).unwrap();
for o in comp.outputs.iter() {
let mut cache = TermMap::new();
for t in PostOrderIter::new(o.clone()) {
match &t.op {
Op::Var(name, _) => {
let ret = match_arg(name, &params);
cache.insert(t.clone(), ret.clone());
}
_ => {
let mut children = Vec::new();
for c in &t.cs {
if let Some(rewritten_c) = cache.get(c) {
children.push(rewritten_c.clone());
} else {
children.push(c.clone());
}
}
cache.insert(t.clone(), term(t.op.clone(), children));
}
}
/// Rewrites a term, inlining function calls along the way.
///
/// Assumes that the callees are already inlined. Panics otherwise.
impl<'f> RewritePass for Inliner<'f> {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
_computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term> {
if let Op::Call(fn_name, arg_names, _, _) = &orig.op {
let callee = self.cache.get(fn_name).expect("missing inlined callee");
let term = inline_one(arg_names, rewritten_children(), callee);
Some(term)
} else {
None
}
res.push(cache.get(o).unwrap().clone());
}
res
}
/// Traverse terms and link function calls
pub fn link_function_calls(
term_: Term,
Cache(ref mut rewritten): &mut Cache,
fs: &Functions,
) -> Term {
let mut call_cache: HashMap<Term, Vec<Term>> = HashMap::new();
for t in PostOrderIter::new(term_.clone()) {
let mut children = Vec::new();
for c in &t.cs {
if let Some(rewritten_c) = rewritten.get(c) {
children.push(rewritten_c.clone());
} else {
children.push(c.clone());
}
}
let entry = match &t.op {
Op::Field(index) => {
assert!(t.cs.len() > 0);
if let Op::Call(..) = &t.cs[0].op {
if call_cache.contains_key(&t.cs[0]) {
call_cache.get(&t.cs[0]).unwrap()[*index].clone()
} else {
panic!("Fields on a Call term should return");
}
} else {
term(t.op.clone(), children)
}
}
Op::Call(name, arg_names, arg_sorts, _) => {
println!("Inlining: {}", name);
// Check number of args
let num_args = arg_sorts.iter().fold(0, |sum, x| {
sum + match x {
Sort::Array(_, _, l) => *l,
_ => 1,
}
});
assert!(
num_args == t.cs.len(),
"Number of arguments mismatch. {}, {}",
num_args,
t.cs.len()
);
// Check arg types
let arg_types = arg_sorts
.iter()
.map(|x| match &x {
Sort::Array(_, val_sort, l) => {
let mut res: Vec<Sort> = Vec::new();
for _ in 0..*l {
res.push(*val_sort.clone());
}
res
}
_ => vec![x.clone()],
})
.flatten()
.collect::<Vec<_>>();
assert!(
arg_types == t.cs.iter().map(|c| check(c)).collect::<Vec<Sort>>(),
"Argument type mismatch"
);
let mut params: BTreeMap<String, Term> = BTreeMap::new();
let arg_keys = arg_names
.iter()
.zip(arg_sorts.iter())
.map(|(n, x)| match &x {
Sort::Array(_, _, l) => {
let mut res: Vec<String> = Vec::new();
for i in 0..*l {
res.push(format!("{}_{}", n.clone(), i));
}
res
}
_ => vec![n.clone()],
})
.flatten();
for (n, c) in arg_keys.zip(t.cs.clone()) {
params.insert(n.clone(), c.clone());
}
let res = link(name, params, fs);
call_cache.insert(t.clone(), res.clone());
res[0].clone()
}
_ => term(t.op.clone(), children),
};
rewritten.insert(t.clone(), entry);
}
if let Some(t) = rewritten.get(&term_) {
t.clone()
} else {
panic!("Couldn't find rewritten binarized term: {}", term_);
}
}
/// Inline all calls within a function set.
pub fn link_all_function_calls(fs: &mut Functions) {
let mut inliner = Inliner {
fs,
cache: Default::default(),
};
for name in fs.computations.keys() {
inliner.inline_all(name);
}
*fs = Functions {
computations: inliner.cache.into_iter().collect(),
};
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn bool_arg_nonrec() {
let mut fs = text::parse_functions(
b"
(functions
(
(myxor
(computation
(metadata () ((a bool) (b bool)) ())
(xor a b false false)
)
)
(main
(computation
(metadata () ((a bool) (b bool)) ())
(and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b )))
)
)
)
)",
);
let expected = text::parse_computation(
b"
(computation
(metadata () ((a bool) (b bool)) ())
(and false ((field 0) (tuple (xor a b false false))))
)
",
);
link_all_function_calls(&mut fs);
let c = fs.get_comp("main").unwrap().clone();
assert_eq!(c, expected);
}
#[test]
fn scalar_arg_nonrec() {
let mut fs = text::parse_functions(
b"
(functions
(
(myxor
(computation
(metadata () ((a bool) (b (bv 4))) ())
(bvxor (ite a #x0 #x1) b)
)
)
(main
(computation
(metadata () ((c bool)) ())
(bvand #xf ((field 0) ( (call myxor (a b) (bool (bv 4)) ((bv 4))) c #x4 )))
)
)
)
)",
);
let expected = text::parse_computation(
b"
(computation
(metadata () ((c bool)) ())
(bvand #xf ((field 0) (tuple (bvxor (ite c #x0 #x1) #x4))))
)
",
);
link_all_function_calls(&mut fs);
let c = fs.get_comp("main").unwrap().clone();
assert_eq!(c, expected);
}
#[test]
fn nested_calls() {
let mut fs = text::parse_functions(
b"
(functions
(
(foo
(computation
(metadata () ((a bool)) ())
(not a)
)
)
(bar
(computation
(metadata () ((a bool)) ())
(xor ((field 0) ((call foo (a) (bool) (bool)) a)) true)
)
)
(main
(computation
(metadata () ((a bool) (b bool)) ())
((field 0) ((call bar (a) (bool) (bool)) a))
)
)
)
)",
);
let expected = text::parse_computation(
b"
(computation
(metadata () ((a bool) (b bool)) ())
((field 0) (tuple (xor ((field 0) (tuple (not a))) true)))
)
",
);
link_all_function_calls(&mut fs);
let c = fs.get_comp("main").unwrap().clone();
assert_eq!(c, expected);
}
#[test]
fn multiple_calls() {
let mut fs = text::parse_functions(
b"
(functions
(
(foo
(computation
(metadata () ((a bool)) ())
(not a)
)
)
(bar
(computation
(metadata () ((a bool)) ())
(xor ((field 0) ((call foo (a) (bool) (bool)) a)) true)
)
)
(main
(computation
(metadata () ((a bool) (b bool)) ())
(and
((field 0) ((call foo (a) (bool) (bool)) a))
((field 0) ((call foo (a) (bool) (bool)) b))
((field 0) ((call bar (a) (bool) (bool)) a))
)
)
)
)
)",
);
let expected = text::parse_computation(
b"
(computation
(metadata () ((a bool) (b bool)) ())
(and
((field 0) (tuple (not a)))
((field 0) (tuple (not b)))
((field 0) (tuple (xor ((field 0) (tuple (not a))) true)))
)
)
",
);
link_all_function_calls(&mut fs);
let c = fs.get_comp("main").unwrap().clone();
assert_eq!(c, expected);
}
}

View File

@@ -26,14 +26,14 @@ struct Linearizer;
impl RewritePass for Linearizer {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
_computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term> {
match &orig.op {
Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(super::arr_val_to_tup(v)))),
Op::Var(_name, s) if super::sort_contains_array(s) => Some(super::array_to_tuple(orig)),
Op::Call(_name, _arg_names, arg_sorts, ret_sort) => {
Op::Call(_name, _arg_names, arg_sorts, ret_sorts) => {
let mut args = rewritten_children();
for (a, s) in args.iter_mut().zip(arg_sorts) {
if super::sort_contains_array(s) {
@@ -41,7 +41,7 @@ impl RewritePass for Linearizer {
}
}
let out = term(orig.op.clone(), args);
Some(if super::sort_contains_array(ret_sort) {
Some(if ret_sorts.iter().any(super::sort_contains_array) {
super::array_to_tuple(&out)
} else {
out

View File

@@ -59,6 +59,7 @@ fn sort_contains_array(s: &Sort) -> bool {
/// Given a sort `s` which may contain array constructors, construct a new sort in which the arrays
/// have been flattened to tuples.
#[allow(dead_code)]
fn array_to_tuple_sort(s: &Sort) -> Sort {
match s {
Sort::Tuple(ss) => Sort::Tuple(ss.iter().map(array_to_tuple_sort).collect()),
@@ -70,7 +71,7 @@ fn array_to_tuple_sort(s: &Sort) -> Sort {
/// Given a term of tuples, re-shape into sort `s`.
fn resort(t: &Term, s: &Sort) -> Term {
match s {
Sort::Array(k, v, sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold(
Sort::Array(k, v, _sz) => extras::tuple_or_array_elements(t).zip(k.elems_iter()).fold(
s.default_term(),
|acc, (t, i)| term![Op::Store; acc, i, resort(&t, v)],
),

View File

@@ -11,24 +11,22 @@
//!
//! ## Pass 1: Identifying oblivious arrays
//!
//!
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole computation
//! system, performing the following inferences:
//! system, marking som arrays as non-oblivious. Only non-constant terms with a sort that contains
//! an array can be non-oblivious.
//!
//! * If `a[i]` for non-constant `i`, then `a` and `a[i]` are not oblivious;
//! * If `a[i]`, `a` and `a[i]` are equi-oblivious
//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious
//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious
//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious
//! * If `a=b`, then `a` and `b` are equi-oblivious
//! * If any other term is array-valued, that array is non-oblivious. Including:
//! * variables
//! * function call outputs
//! * tuple fields
//!
//! This procedure is iterated to fixpoint.
//! The following are *directly marked* as non-oblivious:
//!
//! Notice that we flag some *array* terms as non-oblivious, and we also flag their derived select
//! terms as non-oblivious. This makes it easy to see which selects should be replaced later.
//! * Selects with non-constant indices
//! * Stores with non-constant indices
//! * Any term containing an array sort that is not a array of primitives
//! * Variables, call arguments, call outputs, and the computation outputs
//!
//! All terms propagate non-obliviousness to their children and recieve it from their children.
//!
//! Propagation is repeated to fixpoint.
//!
//! ### Sharing & Constant Arrays
//!
@@ -69,130 +67,89 @@
//! * map array terms to tuple terms
//! * map array selections to tuple field gets
//!
//! In both cases we look at the non-oblivious array/select set to see whether to do the
//! replacement.
//!
//! In both cases we look at the non-oblivious array set to see whether to do the replacement.
use super::super::visit::*;
use crate::ir::term::extras::as_uint_constant;
use crate::ir::term::*;
use log::{debug, trace};
use log::trace;
fn is_prim_array(t: &Sort) -> bool {
if let Sort::Array(k, v, _) = t {
k.is_scalar() && v.is_scalar()
} else {
false
}
}
fn is_markable(t: &Term, s: &Sort) -> bool {
!t.is_const() && !s.is_scalar() && super::sort_contains_array(s)
}
#[derive(Default)]
struct NonOblivComputer {
not_obliv: TermSet,
}
impl NonOblivComputer {
fn mark(&mut self, a: &Term) -> bool {
if !a.is_const() && self.not_obliv.insert(a.clone()) {
trace!("Not obliv: {}", a);
let s = check(a);
if is_markable(a, &s) && self.not_obliv.insert(a.clone()) {
trace!("Not obliv: {}", extras::Letified(a.clone()));
true
} else {
false
}
}
fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool {
if !a.is_const() && !b.is_const() {
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
(false, true) => {
self.not_obliv.insert(a.clone());
true
}
(true, false) => {
self.not_obliv.insert(b.clone());
true
}
_ => false,
fn propagate(&mut self, a: &Term) -> bool {
if self.not_obliv.contains(a) || a.cs.iter().any(|c| self.not_obliv.contains(c)) {
let mut progress = false;
progress |= self.mark(a);
for c in &a.cs {
progress |= self.mark(c);
}
progress
} else {
false
}
}
fn new() -> Self {
Self {
not_obliv: TermSet::new(),
}
}
}
impl ProgressAnalysisPass for NonOblivComputer {
fn visit(&mut self, term: &Term) -> bool {
match &term.op {
Op::Store => {
let a = &term.cs[0];
let i = &term.cs[1];
let v = &term.cs[2];
let mut progress = false;
if let Sort::Array(..) = check(v) {
// Imprecisely, mark v as non-obliv iff the array is.
progress = self.bi_implicate(term, v) || progress;
}
if i.is_const() {
progress = self.bi_implicate(term, a) || progress;
} else {
progress = self.mark(a) || progress;
progress = self.mark(term) || progress;
}
if let Sort::Array(..) = check(v) {
// Imprecisely, mark v as non-obliv iff the array is.
progress = self.bi_implicate(term, v) || progress;
}
progress
}
Op::Select => {
// Even though the selected value may not have array sort, we still flag it as
// non-oblivious so we know whether to replace it or not.
let a = &term.cs[0];
let i = &term.cs[1];
let mut progress = false;
if i.is_const() {
// pass
} else {
progress = self.mark(a) || progress;
progress = self.mark(term) || progress;
}
progress = self.bi_implicate(term, a) || progress;
progress
}
Op::Ite => {
let t = &term.cs[1];
let f = &term.cs[2];
if let Sort::Array(..) = check(t) {
let mut progress = self.bi_implicate(term, t);
progress = self.bi_implicate(t, f) || progress;
progress = self.bi_implicate(term, f) || progress;
progress
} else {
false
}
}
Op::Eq => {
let a = &term.cs[0];
let b = &term.cs[1];
if let Sort::Array(..) = check(a) {
self.bi_implicate(a, b)
} else {
false
}
}
// constants are oblivious
Op::Const(..) => false,
_ => match check(term) {
Sort::Array(..) => {
// variables, fields, and function outputs are non-oblivious.
debug_assert!(
matches!(term.op, Op::Call(..) | Op::Var(..) | Op::Field(..)),
"Unexpected array term {}",
term
);
self.mark(term)
}
let sort = check(term);
let mut progress = false;
// First, we directly mark this term if
// (a) its sort contains an array and is not an array of primitives.
let do_mark = if !sort.is_scalar() && !is_prim_array(&sort) {
true
} else {
// (b) it is a select or store with a non-constant index OR a call
match &term.op {
Op::Store if !term.cs[1].is_const() => true,
Op::Select if !term.cs[1].is_const() => true,
Op::Var(..) => true,
Op::Call(..) => true,
_ => false,
},
}
};
if do_mark {
progress |= self.mark(term);
}
// Now, we mark the children of calls
if let Op::Call(..) = &term.op {
for c in &term.cs {
progress |= self.mark(c);
}
}
// Finally, we propagate marks
progress |= self.propagate(term);
progress
}
}
@@ -202,7 +159,7 @@ struct Replacer {
}
impl Replacer {
fn should_replace(&self, a: &Term) -> bool {
fn is_obliv(&self, a: &Term) -> bool {
!self.not_obliv.contains(a)
}
}
@@ -229,42 +186,29 @@ impl RewritePass for Replacer {
.map(super::term_arr_val_to_tup)
.collect()
};
debug!("visiting {}", orig);
trace!("rewriting {}", extras::Letified(orig.clone()));
match &orig.op {
Op::Select => {
// we mark the selected term as non-obliv...
if self.should_replace(orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 2);
let k_const = get_const(&cs.pop().unwrap());
Some(term(Op::Field(k_const), cs))
} else {
None
}
Op::Select if self.is_obliv(&orig.cs[0]) => {
trace!(" is oblivious");
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 2);
let k_const = get_const(&cs.pop().unwrap());
Some(term(Op::Field(k_const), cs))
}
Op::Store => {
if self.should_replace(orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 3);
let k_const = get_const(&cs.remove(1));
Some(term(Op::Update(k_const), cs))
} else {
None
}
Op::Store if self.is_obliv(orig) => {
trace!(" is oblivious");
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 3);
let k_const = get_const(&cs.remove(1));
Some(term(Op::Update(k_const), cs))
}
Op::Ite => {
if self.should_replace(orig) {
Some(term(Op::Ite, get_cs()))
} else {
None
}
Op::Ite if self.is_obliv(orig) => {
trace!(" is oblivious");
Some(term(Op::Ite, get_cs()))
}
Op::Eq => {
if self.should_replace(&orig.cs[0]) {
Some(term(Op::Eq, get_cs()))
} else {
None
}
Op::Eq if self.is_obliv(&orig.cs[0]) => {
trace!(" is oblivious");
Some(term(Op::Eq, get_cs()))
}
_ => None,
}
@@ -273,7 +217,10 @@ impl RewritePass for Replacer {
/// Eliminate oblivious arrays. See module documentation.
pub fn elim_obliv(t: &mut Computation) {
let mut prop_pass = NonOblivComputer::new();
let mut prop_pass = NonOblivComputer::default();
for o in t.outputs() {
prop_pass.mark(o);
}
prop_pass.traverse(t);
let mut replace_pass = Replacer {
not_obliv: prop_pass.not_obliv,
@@ -281,7 +228,9 @@ pub fn elim_obliv(t: &mut Computation) {
let initial_output_sorts: Vec<Sort> = t.outputs.iter().map(check).collect();
<Replacer as RewritePass>::traverse(&mut replace_pass, t);
for (o, s) in t.outputs.iter_mut().zip(initial_output_sorts) {
*o = super::resort(o, &s)
if check(o) != s {
*o = super::resort(o, &s)
}
}
}
@@ -402,7 +351,9 @@ mod test {
#[test]
fn ignore_vars() {
let before = text::parse_computation(b"
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
(bvadd
@@ -410,8 +361,10 @@ mod test {
(select (store (#a (bv 8) #x00 4 ()) #x00 a) #x00)
(select (store (#a (bv 8) #x01 5 ()) #x00 a) #x01)
)
)");
let expected = text::parse_computation(b"
)",
);
let expected = text::parse_computation(
b"
(computation
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
(bvadd
@@ -419,7 +372,8 @@ mod test {
((field 0) ((update 0) (#t #x00 #x00 #x00 #x00) a))
((field 1) ((update 0) (#t #x01 #x01 #x01 #x01 #x01) a))
)
)");
)",
);
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
@@ -427,35 +381,18 @@ mod test {
#[test]
fn preserve_output_type() {
let before = text::parse_computation(b"
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
(store
(store (#a (bv 8) #x00 3 ()) #x01 (select A #x00))
#x00 #x05
)
)");
let expected = text::parse_computation(b"
(computation
(metadata () ((a (bv 8)) (A (array (bv 8) (bv 8) 3))) ())
(let ((tupout
((update 0)
((update 1)
(#t #x00 #x00 #x00)
(select A #x00)
)
#x05
)
))
(store
(store
(store (#a (bv 8) #x00 3 ()) #x00 ((field 0) tupout))
#x01 ((field 1) tupout)
)
#x02 ((field 2) tupout)
)
)
)");
)",
);
let expected = before.clone();
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
@@ -463,25 +400,67 @@ mod test {
#[test]
fn identity_fn() {
let before = text::parse_computation(b"
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () ((A (array (bv 8) (bv 8) 3))) ())
A
)");
let expected = text::parse_computation(b"
(computation
(metadata () ((A (array (bv 8) (bv 8) 3))) ())
(store
(store
(store
(#a (bv 8) #x00 3 ( ))
#x00 (select A #x00))
#x01 (select A #x01))
#x02 (select A #x02))
)");
)",
);
let expected = before.clone();
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
}
#[test]
fn ignore_call_inputs() {
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () () ())
((call foo (a) ((array (bv 8) (bv 8) 4)) (bool)) (store (#a (bv 8) #x00 4 ()) #x00 #x01))
)",
);
let expected = before.clone();
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
}
#[test]
fn ignore_call_outputs() {
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () () ())
(select
((field 0) ((call foo () () ((array (bv 8) (bv 8) 4))) ))
#x00)
)",
);
let expected = before.clone();
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
}
#[test]
fn ignore_outputs() {
let _ = env_logger::builder().is_test(true).try_init();
let before = text::parse_computation(
b"
(computation
(metadata () () ())
(store (#a (bv 8) #x00 4 ()) #x00 #x01)
)",
);
let expected = before.clone();
let mut after = before.clone();
elim_obliv(&mut after);
assert_eq!(after, expected);
}
}

View File

@@ -10,7 +10,7 @@ pub mod sha;
pub mod tuple;
mod visit;
use std::{collections::HashMap, time::Instant};
use std::time::Instant;
use super::term::*;
@@ -48,6 +48,10 @@ pub enum Opt {
pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) -> Functions {
for i in optimizations {
let mut opt_fs: Functions = fs.clone();
if let Opt::InlineCalls = i {
link::link_all_function_calls(&mut opt_fs);
continue
}
for (name, comp) in fs.computations.iter_mut() {
debug!("Applying: {:?} to {}", i, name);
let now = Instant::now();
@@ -112,12 +116,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) ->
.collect();
inline::inline(&mut comp.outputs, &public_inputs);
}
Opt::InlineCalls => {
let mut cache = link::Cache::new();
for a in &mut comp.outputs {
*a = link::link_function_calls(a.clone(), &mut cache, &opt_fs);
}
}
Opt::InlineCalls => unreachable!(),
Opt::Tuple => {
tuple::eliminate_tuples(comp);
}

View File

@@ -141,8 +141,10 @@ pub enum Op {
/// Map (operation)
Map(Box<Op>),
/// Call a function (name, argument names, argument sorts, return sort)
Call(String, Vec<String>, Vec<Sort>, Sort),
/// Call a function (name, argument names, argument sorts, return sorts)
///
/// Note that the type of this term is always a tuple.
Call(String, Vec<String>, Vec<Sort>, Vec<Sort>),
}
/// Boolean AND
@@ -298,7 +300,7 @@ impl Display for Op {
Op::Field(i) => write!(f, "(field {})", i),
Op::Update(i) => write!(f, "(update {})", i),
Op::Map(op) => write!(f, "(map({}))", op),
Op::Call(name, arg_names, arg_sorts, sort) => {
Op::Call(name, arg_names, arg_sorts, ret_sorts) => {
write!(f, "(call {} (", name)?;
for arg_name in arg_names {
write!(f, " {}", arg_name)?;
@@ -307,7 +309,11 @@ impl Display for Op {
for arg_sort in arg_sorts {
write!(f, " {}", arg_sort)?;
}
write!(f, ") {})", sort)
write!(f, ") (")?;
for ret_sort in ret_sorts {
write!(f, " {}", ret_sort)?;
}
write!(f, "))")
}
}
}
@@ -2020,8 +2026,8 @@ impl Functions {
}
/// Get the first computation by function name
pub fn get_comp(&self, name: String) -> Option<&Computation> {
self.computations.get(&name)
pub fn get_comp(&self, name: &str) -> Option<&Computation> {
self.computations.get(name)
}
/// Create a computation with a single entry function

View File

@@ -46,7 +46,7 @@
//! * Operator `O`:
//! * Plain operators: (`bvmul`, `and`, ...)
//! * Composite operators: `(field N)`, `(update N)`, `(sext N)`, `(uext N)`, `(bit N)`, ...
//! * call operator: `(call X (X1 ... XN) (S1 ... SN) S)`
//! * call operator: `(call X (X1 ... XN) (S1 ... SN) (RS1 ... RSN))`
use circ_fields::{FieldT, FieldV};
@@ -275,12 +275,12 @@ impl<'src> IrInterp<'src> {
[Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))),
[Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))),
[Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))),
[Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, sort] => {
[Leaf(Ident, b"call"), Leaf(Ident, name), arg_names, arg_sorts, ret_sorts] => {
let name = from_utf8(name).unwrap().to_owned();
let arg_names = self.string_list(arg_names);
let arg_sorts = self.sort_list(arg_sorts);
let sort = self.sort(sort);
Ok(Op::Call(name, arg_names, arg_sorts, sort))
let ret_sorts = self.sort_list(ret_sorts);
Ok(Op::Call(name, arg_names, arg_sorts, ret_sorts))
}
_ => todo!("Unparsed op: {}", tt),
},
@@ -777,7 +777,7 @@ mod test {
let t = parse_term(
b"
(declare ((a bool))
( (call myxor (a b) (bool bool) bool) a a )
((field 0) ( (call myxor (a b) (bool bool) (bool)) a a ))
)",
);
assert_eq!(check(&t), Sort::Bool);
@@ -918,7 +918,7 @@ mod test {
(main
(computation
(metadata () ((a bool) (b bool)) ())
(and false ( (call myxor (a b) (bool bool) bool) a b ))
(and false ((field 0) ( (call myxor (a b) (bool bool) (bool)) a b )))
)
)
)

View File

@@ -175,7 +175,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
}
}
}
Op::Call(_, _, _, ret) => Ok(ret.clone()),
Op::Call(_, _, _, rets) => Ok(Sort::Tuple(rets.clone().into())),
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
}
}
@@ -392,14 +392,14 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
rec_check_raw_helper(&(*op.clone()), &new_a[..])
.map(|val_sort| Sort::Array(Box::new(key_sort), Box::new(val_sort), size))
}
(Op::Call(_, _, ex_args, ret), act_args) => {
(Op::Call(_, _, ex_args, rets), act_args) => {
if ex_args.len() != act_args.len() {
Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len()))
} else {
for (e, a) in ex_args.iter().zip(act_args) {
eq_or(e, a, "in function call")?;
}
Ok(ret.clone())
Ok(Sort::Tuple(rets.clone().into()))
}
}
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),

View File

@@ -553,11 +553,11 @@ impl<'a> ToABY<'a> {
self.term_to_shares.insert(t.clone(), vec![shares[*i]]);
self.cache.insert(t.clone(), EmbeddedTerm::Array);
}
Op::Call(name, arg_names, _arg_sorts, ret) => {
Op::Call(name, arg_names, _arg_sorts, ret_sorts) => {
let shares = self.get_shares(&t);
let op = format!("CALL({})", name);
let num_args = arg_names.len();
let num_rets = self.get_sort_len(ret);
let num_rets: usize = ret_sorts.iter().map(|ret| self.get_sort_len(ret)).sum();
// map argument shares
let mut arg_shares: Vec<i32> = Vec::new();