fixed return type for calls

This commit is contained in:
Edward Chen
2022-05-24 21:32:42 -04:00
parent 094b45e76b
commit f29eec8d7b
11 changed files with 191 additions and 167 deletions

View File

@@ -6,7 +6,7 @@ int fa(int * c, int a) {
}
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) {
int c[5];
int c[5] = {0,1,2,3,4};
int ret = fa(c, a);
int sum = ret;
for (int i = 0; i < 5; i++) {

View File

@@ -247,6 +247,8 @@ fn main() {
Opt::Sha,
Opt::ConstantFold(Box::new(ignore.clone())),
Opt::Flatten,
// The function call abstraction creates tuples
Opt::Tuple,
Opt::Obliv,
// The obliv elim pass produces more tuples, that must be eliminated
Opt::Tuple,
@@ -254,10 +256,10 @@ fn main() {
// The linear scan pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::ConstantFold(Box::new(ignore.clone())),
// Inline Function Calls
Opt::InlineCalls,
// Binarize nary terms
Opt::Binarize,
// // Inline Function Calls
// Opt::InlineCalls,
],
)
}
@@ -286,21 +288,21 @@ fn main() {
),
};
// for (name, comp) in cs.computations.iter() {
// println!("functions: {}", name);
// for t in &comp.outputs {
// println!("function term: {}, {}", t, t.uid());
// for t1 in PostOrderIter::new(t.clone()) {
// println!("term: {}, {}", t1, t1.uid());
// for c in t1.cs.iter() {
// println!("children: {}, {}", c, c.uid());
// }
// println!();
// }
// println!();
// }
// println!("\n");
// }
for (name, comp) in cs.computations.iter() {
println!("functions: {}", name);
for t in &comp.outputs {
println!("function term: {}, {}", t, t.uid());
// for t1 in PostOrderIter::new(t.clone()) {
// println!("term: {}, {}", t1, t1.uid());
// for c in t1.cs.iter() {
// println!("children: {}, {}", c, c.uid());
// }
// println!();
// }
println!();
}
println!("\n");
}
println!("Done with IR optimization");

View File

@@ -30,10 +30,10 @@ function mpc_test_2 {
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+b"
}
# mpc_test_2 2 ./examples/C/mpc/playground.c
mpc_test_2 2 ./examples/C/mpc/playground.c
# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
# # build mpc arithmetic tests
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c

View File

@@ -775,27 +775,21 @@ impl<E: Embeddable> Circify<E> {
/// ## Returns
///
/// Returns the return value of the function, if any.
pub fn exit_fn_call(&mut self, ret_names: &Vec<&String>) -> HashMap<String, Val<E::T>> {
pub fn exit_fn_call(&mut self, ret_names: &Vec<String>) -> Option<Vec<Val<E::T>>> {
if let Some(fn_) = self.fn_stack.last() {
let mut rets: HashMap<String, Val<E::T>> = HashMap::new();
let mut rets: Vec<Val<E::T>> = Vec::new();
// Get return value if possible
if fn_.has_return {
rets.insert(
RET_NAME.to_string(),
self.get_value(Loc::local(RET_NAME.to_owned())).unwrap(),
);
rets.push(self.get_value(Loc::local(RET_NAME.to_owned())).unwrap());
}
// Get references if possible
for name in ret_names {
rets.insert(
name.to_string(),
self.get_value(Loc::local(name.to_string())).unwrap(),
);
rets.push(self.get_value(Loc::local(name.to_string())).unwrap());
}
self.fn_stack.pop().unwrap();
rets
Some(rets)
} else {
panic!("No fn to exit")
}

View File

@@ -118,7 +118,7 @@ pub fn body_from_func(fn_def: &FunctionDefinition) -> Statement {
pub fn fn_info_to_defs(
fn_info: &FnInfo,
arg_terms: &Vec<Vec<Term>>, // arguments taken at call site
) -> (String, BTreeMap<String, Sort>, BTreeMap<String, Sort>) {
) -> (String, Vec<String>, Vec<Sort>, BTreeMap<String, Sort>) {
let mut rets: BTreeMap<String, Sort> = BTreeMap::new();
match &fn_info.ret_ty {
Ty::Void => {}
@@ -127,7 +127,8 @@ pub fn fn_info_to_defs(
}
};
assert!(fn_info.params.len() == arg_terms.len());
let mut params = BTreeMap::new();
let mut param_names: Vec<String> = Vec::new();
let mut param_sorts: Vec<Sort> = Vec::new();
for (param, arg) in fn_info.params.iter().zip(arg_terms.iter()) {
let name = param.name.clone();
let ty = match &param.ty {
@@ -140,10 +141,10 @@ pub fn fn_info_to_defs(
}
_ => param.ty.sort(),
};
params.insert(name, ty.clone());
param_names.push(name);
param_sorts.push(ty.clone());
}
(fn_info.name.clone(), params, rets)
(fn_info.name.clone(), param_names, param_sorts, rets)
}
pub fn flatten_inits(init: Initializer) -> Vec<Initializer> {

View File

@@ -78,8 +78,8 @@ impl FrontEnd for C {
// generate new context
g.circ = Circify::new(Ct::new(i.inputs.clone().map(parser::parse_inputs)));
let call = g.function_queue.pop().unwrap();
if let Op::Call(name, args, rets, ret_name) = &call.op {
g.fn_call(name, args, rets, ret_name);
if let Op::Call(name, arg_names, arg_sorts, rets) = &call.op {
g.fn_call(name, arg_names, arg_sorts, rets);
let comp = g.circ.consume().borrow().clone();
// println!("fn: {}", name);
@@ -826,6 +826,7 @@ impl CGen {
let ret_ty = f.ret_ty.clone();
let mut arg_names: Vec<String> = Vec::new();
let cargs = arguments
.iter()
.map(|e| self.gen_expr(e.node.clone()))
@@ -833,20 +834,26 @@ impl CGen {
let mut cargs_map: HashMap<String, CTerm> = HashMap::new();
for (p, c) in f.params.iter().zip(cargs.iter()) {
cargs_map.insert(p.name.clone(), c.clone());
arg_names.push(p.name.clone());
}
let arg_terms = cargs
.iter()
.map(|e| e.term.terms(self.circ.cir_ctx()))
.collect::<Vec<_>>();
let flatten_args = arg_terms.clone().into_iter().flatten().collect::<Vec<_>>();
let (name, args, rets) = fn_info_to_defs(&f, &arg_terms);
let (name, arg_names, arg_sorts, rets) = fn_info_to_defs(&f, &arg_terms);
let call_term = term(
Op::Call(
name.clone(),
args.clone(),
rets.clone(),
"return".to_string(),
arg_names.clone(),
arg_sorts.clone(),
Sort::Tuple(
rets.values()
.cloned()
.collect::<Vec<_>>()
.into_boxed_slice(),
),
),
flatten_args.clone(),
);
@@ -859,58 +866,71 @@ impl CGen {
// Rewiring
for (ret_name, sort) in rets.iter() {
if ret_name != "return" {
let call = term(
Op::Call(
name.clone(),
args.clone(),
rets.clone(),
ret_name.to_string(),
),
flatten_args.clone(),
);
if let Sort::Array(_, _, l) = sort {
let ct = cargs_map.get(ret_name).unwrap();
if let CTermData::Array(_, id) = ct.term {
self.circ.replace(id.unwrap(), call.clone());
// self.circ.assign(l, Val::Term(val));
// for i in 0..*l {
// let updated_idx = bv_lit(i as i32, 32);
// // TODO: index calculation
// self.circ.store(id.unwrap(), updated_idx, call.clone());
// }
} else {
unimplemented!("This should only be handling ptrs to arrays");
}
if let Sort::Array(_, _, l) = sort {
let ct = cargs_map.get(ret_name).unwrap();
if let CTermData::Array(_, id) = ct.term {
self.circ.replace(id.unwrap(), call_term.clone());
} else {
unimplemented!("This should only be handling ptrs to arrays");
}
// println!("CT: {}", ct.term.term());
// // self
// // .circ
// // .assign(l, Val::Term(val))
// // .map_err(|e| format!("{}", e))?
// // .unwrap_term()
// unimplemented!();
// // if let CTermData::Array(_, id) = ct.term {
// // }
// // for i in 0..*l {
// // let updated_idx = bv_lit(i as i32, 32);
// // self.circ.store(id.unwrap(), updated_idx, call.clone());
// // }
// // } else {
// // unimplemented!("This should only be handling ptrs to arrays");
// // }
// } else {
// unimplemented!("This should only be handling ptrs to arrays");
// }
}
}
// // Rewiring
// for (ret_name, sort) in rets.iter() {
// println!("retname: {}", ret_name);
// if ret_name != "return" {
// let call = term(
// Op::Call(
// name.clone(),
// args.clone(),
// rets.clone(),
// ret_name.to_string(),
// ),
// flatten_args.clone(),
// );
// if let Sort::Array(_, _, l) = sort {
// let ct = cargs_map.get(ret_name).unwrap();
// if let CTermData::Array(_, id) = ct.term {
// self.circ.replace(id.unwrap(), call.clone());
// // self.circ.assign(l, Val::Term(val));
// // for i in 0..*l {
// // let updated_idx = bv_lit(i as i32, 32);
// // // TODO: index calculation
// // self.circ.store(id.unwrap(), updated_idx, call.clone());
// // }
// } else {
// unimplemented!("This should only be handling ptrs to arrays");
// }
// } else {
// unimplemented!("This should only be handling ptrs to arrays");
// }
// // println!("CT: {}", ct.term.term());
// // // self
// // // .circ
// // // .assign(l, Val::Term(val))
// // // .map_err(|e| format!("{}", e))?
// // // .unwrap_term()
// // unimplemented!();
// // // if let CTermData::Array(_, id) = ct.term {
// // // }
// // // for i in 0..*l {
// // // let updated_idx = bv_lit(i as i32, 32);
// // // self.circ.store(id.unwrap(), updated_idx, call.clone());
// // // }
// // // } else {
// // // unimplemented!("This should only be handling ptrs to arrays");
// // // }
// // } else {
// // unimplemented!("This should only be handling ptrs to arrays");
// // }
// }
// }
// Return value
let ret = match ret_ty {
Ty::Void | Ty::Bool => cterm(CTermData::Bool(call_term)),
@@ -1286,18 +1306,17 @@ impl CGen {
fn fn_call(
&mut self,
name: &String,
args: &BTreeMap<String, Sort>,
rets: &BTreeMap<String, Sort>,
ret_name: &String,
arg_names: &Vec<String>,
arg_sorts: &Vec<Sort>,
rets: &Sort,
) {
debug!("Call: {}", name);
println!("Call: {}", name);
// for (n, a) in args {
// println!("args: {}, {}", n, a);
// }
// for (r, s) in rets.iter() {
// println!("ret: {}, {}", r, s);
// }
let mut arg_map: BTreeMap<String, Sort> = BTreeMap::new();
for (n, s) in arg_names.iter().zip(arg_sorts.iter()) {
arg_map.insert(n.to_string(), s.clone());
}
// Get function types
let f = self
@@ -1313,16 +1332,23 @@ impl CGen {
};
self.circ.enter_fn(name.to_owned(), ret_ty);
// Keep track of the names of arguments that are references
let mut ret_names: Vec<String> = Vec::new();
// define input parameters
assert!(args.len() == f.params.len());
assert!(arg_map.len() == f.params.len());
for param in f.params {
let p_name = param.name;
assert!(args.contains_key(&p_name));
let s = args.get(&p_name).unwrap();
assert!(arg_map.contains_key(&p_name));
let s = arg_map.get(&p_name).unwrap();
let p_ty = match param.ty {
Ty::Ptr(_, t) => {
if let Sort::Array(_, _, len) = s {
let dims = vec![*len];
// Add reference
ret_names.push(p_name.clone());
Ty::Array(*len, dims, t)
} else {
panic!("Ptr type does not match with Array sort: {}", s)
@@ -1336,10 +1362,34 @@ impl CGen {
self.gen_stmt(f.body.clone());
// let ret_names = &rets.keys().collect::<Vec<&String>>();
// let rets = self.circ.exit_fn_call(ret_names);
// for (name, val) in rets {
// let ret_terms = val.unwrap_term().term.terms(self.circ.cir_ctx());
if let Some(returns) = self.circ.exit_fn_call(&ret_names) {
let ret_terms = returns
.into_iter()
.map(|x| x.unwrap_term().term.terms(self.circ.cir_ctx()))
.flatten()
.collect::<Vec<Term>>();
self.circ
.cir_ctx()
.cs
.borrow_mut()
.outputs
.push(term(Op::Tuple, ret_terms));
}
// for (name, val) in returns {
// println!("name: {}", name);
// // let ret_terms = val.unwrap_term().term.terms(self.circ.cir_ctx());
// // self.circ
// // .cir_ctx()
// // .cs
// // .borrow_mut()
// // .outputs
// // .extend(ret_terms);
// }
// if let Some(r) = self.circ.exit_fn() {
// let ret_term = r.unwrap_term();
// let ret_terms = ret_term.term.terms(self.circ.cir_ctx());
// self.circ
// .cir_ctx()
// .cs
@@ -1348,17 +1398,6 @@ impl CGen {
// .extend(ret_terms);
// }
if let Some(r) = self.circ.exit_fn() {
let ret_term = r.unwrap_term();
let ret_terms = ret_term.term.terms(self.circ.cir_ctx());
self.circ
.cir_ctx()
.cs
.borrow_mut()
.outputs
.extend(ret_terms);
}
// match self.mode {
// Mode::Mpc(_) => {
// let ret_term = r.unwrap_term();

View File

@@ -39,12 +39,16 @@ fn match_arg(name: &String, params: &BTreeMap<String, Term>) -> Term {
fn inline(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Term> {
let mut res: Vec<Term> = Vec::new();
let comp = fs.computations.get(name).unwrap();
for o in comp.outputs.iter().rev() {
println!("Comp: {}", name);
println!("params: {:#?}", params);
for o in comp.outputs.iter() {
println!("o: {}", o);
let mut cache = TermMap::new();
for t in PostOrderIter::new(o.clone()) {
match &t.op {
Op::Var(name, _sort) => {
Op::Var(name, _) => {
let ret = match_arg(name, &params);
println!("ret: {}", ret);
cache.insert(t.clone(), ret.clone());
}
_ => {
@@ -62,6 +66,7 @@ fn inline(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Ter
}
res.push(cache.get(o).unwrap().clone());
}
println!("res: {:#?}", res);
res
}
@@ -73,12 +78,11 @@ pub fn inline_function_calls(
) -> Term {
let mut call_cache: HashMap<Term, Vec<Term>> = HashMap::new();
for t in PostOrderIter::new(term_.clone()) {
println!("inline t: {}", t);
let mut children = Vec::new();
for c in &t.cs {
if let Some(rewritten_c) = rewritten.get(c) {
if call_cache.contains_key(c) {
children.push(call_cache.get_mut(c).unwrap().pop().unwrap().clone());
} else {
if !call_cache.contains_key(c) {
children.push(rewritten_c.clone());
}
} else {
@@ -86,11 +90,23 @@ pub fn inline_function_calls(
}
}
let entry = match &t.op {
Op::Call(name, args, _rets, _) => {
Op::Field(index) => {
assert!(t.cs.len() > 0);
if let Op::Call(..) = &t.cs[0].op {
if call_cache.contains_key((&t.cs[0])) {
call_cache.get(&t.cs[0]).unwrap()[index + 1].clone()
} else {
panic!("Fields on a Call term should return");
}
} else {
term(t.op.clone(), children)
}
}
Op::Call(name, arg_names, arg_sorts, _) => {
println!("Inlining: {}", name);
// Check number of args
let num_args = args.values().fold(0, |sum, x| {
let num_args = arg_sorts.iter().fold(0, |sum, x| {
sum + match x {
Sort::Array(_, _, l) => *l,
_ => 1,
@@ -104,8 +120,8 @@ pub fn inline_function_calls(
);
// Check arg types
let arg_types = args
.values()
let arg_types = arg_sorts
.iter()
.map(|x| match &x {
Sort::Array(_, val_sort, l) => {
let mut res: Vec<Sort> = Vec::new();
@@ -125,8 +141,9 @@ pub fn inline_function_calls(
);
let mut params: BTreeMap<String, Term> = BTreeMap::new();
let arg_keys = args
let arg_keys = arg_names
.iter()
.zip(arg_sorts.iter())
.map(|(n, x)| match &x {
Sort::Array(_, _, l) => {
let mut res: Vec<String> = Vec::new();
@@ -147,6 +164,7 @@ pub fn inline_function_calls(
}
_ => term(t.op.clone(), children),
};
println!("rewritten: {}\n", entry);
rewritten.insert(t.clone(), entry);
}

View File

@@ -58,8 +58,6 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) ->
let _lock = super::term::COLLECT.read().unwrap();
let mut cache = TermCache::new(TERM_CACHE_LIMIT);
for a in &mut comp.outputs {
// println!("cfold: {}", a);
// println!();
// allow unbounded size during a single fold_cache call
cache.resize(std::usize::MAX);
*a = cfold::fold_cache(a, &mut cache, &*ignore.clone());

View File

@@ -70,7 +70,6 @@ use itertools::zip_eq;
#[derive(Clone, PartialEq, Eq, Debug)]
enum TupleTree {
NonTuple(Term),
CallTuple(Term),
Tuple(im::Vector<TupleTree>),
}
@@ -79,7 +78,6 @@ impl TupleTree {
let mut out = Vec::new();
fn rec_unroll_into(t: &TupleTree, out: &mut Vec<Term>) {
match t {
TupleTree::CallTuple(t) => out.push(t.clone()),
TupleTree::NonTuple(t) => out.push(t.clone()),
TupleTree::Tuple(t) => {
for c in t {
@@ -98,9 +96,6 @@ impl TupleTree {
TupleTree::Tuple(tt.iter().map(|c| term_structure(c, iter)).collect())
}
TupleTree::NonTuple(_) => TupleTree::NonTuple(iter.next().expect("bad structure")),
TupleTree::CallTuple(_) => {
TupleTree::CallTuple(iter.next().expect("bad structure"))
}
}
}
term_structure(self, &mut flattened.into_iter())
@@ -113,12 +108,10 @@ impl TupleTree {
}
fn get(&self, i: usize) -> Self {
match self {
TupleTree::CallTuple(cs) => {
TupleTree::CallTuple(term![Op::Select; cs.clone(), bv_lit(i, 32)])
}
TupleTree::NonTuple(cs) => {
panic!("Get ({}) on non-tuple {:?}", i, self)
}
TupleTree::NonTuple(cs) => match cs.op {
Op::Call(..) => TupleTree::NonTuple(term![Op::Field(i); cs.clone()]),
_ => panic!("Get ({}) on non-tuple {:?}", i, self),
},
TupleTree::Tuple(t) => {
assert!(i < t.len());
t.get(i).unwrap().clone()
@@ -127,10 +120,6 @@ impl TupleTree {
}
fn update(&self, i: usize, v: &TupleTree) -> Self {
match self {
TupleTree::CallTuple(cs) => {
let val = v.clone().unwrap_non_tuple();
TupleTree::CallTuple(term![Op::Store; cs.clone(), bv_lit(i, 32), val.clone()])
}
TupleTree::NonTuple(cs) => panic!("Update ({}) on non-tuple {:?}", i, self),
TupleTree::Tuple(t) => {
assert!(i < t.len());
@@ -141,7 +130,6 @@ impl TupleTree {
fn unwrap_non_tuple(self) -> Term {
match self {
TupleTree::NonTuple(t) => t,
TupleTree::CallTuple(t) => t,
_ => panic!("{:?} is tuple!", self),
}
}
@@ -274,10 +262,6 @@ pub fn eliminate_tuples(cs: &mut Computation) {
t.update(*i, &v)
}
Op::Tuple => TupleTree::Tuple(cs.into()),
Op::Call(..) => TupleTree::CallTuple(term(
t.op.clone(),
cs.into_iter().map(|c| c.unwrap_non_tuple()).collect(),
)),
_ => TupleTree::NonTuple(term(
t.op.clone(),
cs.into_iter().map(|c| c.unwrap_non_tuple()).collect(),

View File

@@ -139,13 +139,8 @@ pub enum Op {
/// Map (operation)
Map(Box<Op>),
/// Call a function (name, argument sorts, return sorts, return_name)
Call(
String,
BTreeMap<String, Sort>,
BTreeMap<String, Sort>,
String,
),
/// Call a function (name, argument names, argument sorts, return sorts)
Call(String, Vec<String>, Vec<Sort>, Sort),
}
/// Boolean AND
@@ -257,7 +252,7 @@ impl Op {
Op::Field(_) => Some(1),
Op::Update(_) => Some(2),
Op::Map(op) => op.arity(),
Op::Call(_, args, _, _) => Some(args.len()),
Op::Call(_, _, args, _) => Some(args.len()),
}
}
}

View File

@@ -175,14 +175,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
}
}
}
Op::Call(_, _, ret, ret_name) => {
// let s = ret[ret_name].clone();
// match s {
// Sort::Array(_, val_sort, _) => Ok(*val_sort),
// _ => Ok(s),
// }
Ok(ret[ret_name].clone())
}
Op::Call(_, _, _, ret) => Ok(ret.clone()),
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
}
}
@@ -399,14 +392,14 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
rec_check_raw_helper(&(*op.clone()), &new_a[..])
.map(|val_sort| Sort::Array(Box::new(key_sort), Box::new(val_sort), size))
}
(Op::Call(_, ex_args, ret, ret_name), act_args) => {
(Op::Call(_, _, ex_args, ret), act_args) => {
if ex_args.len() != act_args.len() {
Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len()))
} else {
for ((_, e), a) in ex_args.iter().zip(act_args) {
for (e, a) in ex_args.iter().zip(act_args) {
eq_or(e, a, "in function call")?;
}
Ok(ret[ret_name].clone())
Ok(ret.clone())
}
}
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),