mirror of
https://github.com/circify/circ.git
synced 2026-04-21 03:00:54 -04:00
fixed return type for calls
This commit is contained in:
@@ -6,7 +6,7 @@ int fa(int * c, int a) {
|
||||
}
|
||||
|
||||
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) {
|
||||
int c[5];
|
||||
int c[5] = {0,1,2,3,4};
|
||||
int ret = fa(c, a);
|
||||
int sum = ret;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
|
||||
@@ -247,6 +247,8 @@ fn main() {
|
||||
Opt::Sha,
|
||||
Opt::ConstantFold(Box::new(ignore.clone())),
|
||||
Opt::Flatten,
|
||||
// The function call abstraction creates tuples
|
||||
Opt::Tuple,
|
||||
Opt::Obliv,
|
||||
// The obliv elim pass produces more tuples, that must be eliminated
|
||||
Opt::Tuple,
|
||||
@@ -254,10 +256,10 @@ fn main() {
|
||||
// The linear scan pass produces more tuples, that must be eliminated
|
||||
Opt::Tuple,
|
||||
Opt::ConstantFold(Box::new(ignore.clone())),
|
||||
// Inline Function Calls
|
||||
Opt::InlineCalls,
|
||||
// Binarize nary terms
|
||||
Opt::Binarize,
|
||||
// // Inline Function Calls
|
||||
// Opt::InlineCalls,
|
||||
],
|
||||
)
|
||||
}
|
||||
@@ -286,21 +288,21 @@ fn main() {
|
||||
),
|
||||
};
|
||||
|
||||
// for (name, comp) in cs.computations.iter() {
|
||||
// println!("functions: {}", name);
|
||||
// for t in &comp.outputs {
|
||||
// println!("function term: {}, {}", t, t.uid());
|
||||
// for t1 in PostOrderIter::new(t.clone()) {
|
||||
// println!("term: {}, {}", t1, t1.uid());
|
||||
// for c in t1.cs.iter() {
|
||||
// println!("children: {}, {}", c, c.uid());
|
||||
// }
|
||||
// println!();
|
||||
// }
|
||||
// println!();
|
||||
// }
|
||||
// println!("\n");
|
||||
// }
|
||||
for (name, comp) in cs.computations.iter() {
|
||||
println!("functions: {}", name);
|
||||
for t in &comp.outputs {
|
||||
println!("function term: {}, {}", t, t.uid());
|
||||
// for t1 in PostOrderIter::new(t.clone()) {
|
||||
// println!("term: {}, {}", t1, t1.uid());
|
||||
// for c in t1.cs.iter() {
|
||||
// println!("children: {}, {}", c, c.uid());
|
||||
// }
|
||||
// println!();
|
||||
// }
|
||||
println!();
|
||||
}
|
||||
println!("\n");
|
||||
}
|
||||
|
||||
println!("Done with IR optimization");
|
||||
|
||||
|
||||
@@ -30,10 +30,10 @@ function mpc_test_2 {
|
||||
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+b"
|
||||
}
|
||||
|
||||
# mpc_test_2 2 ./examples/C/mpc/playground.c
|
||||
mpc_test_2 2 ./examples/C/mpc/playground.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
|
||||
# mpc_test_2 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
|
||||
mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
|
||||
# mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
|
||||
|
||||
# # build mpc arithmetic tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c
|
||||
|
||||
@@ -775,27 +775,21 @@ impl<E: Embeddable> Circify<E> {
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the return value of the function, if any.
|
||||
pub fn exit_fn_call(&mut self, ret_names: &Vec<&String>) -> HashMap<String, Val<E::T>> {
|
||||
pub fn exit_fn_call(&mut self, ret_names: &Vec<String>) -> Option<Vec<Val<E::T>>> {
|
||||
if let Some(fn_) = self.fn_stack.last() {
|
||||
let mut rets: HashMap<String, Val<E::T>> = HashMap::new();
|
||||
let mut rets: Vec<Val<E::T>> = Vec::new();
|
||||
// Get return value if possible
|
||||
if fn_.has_return {
|
||||
rets.insert(
|
||||
RET_NAME.to_string(),
|
||||
self.get_value(Loc::local(RET_NAME.to_owned())).unwrap(),
|
||||
);
|
||||
rets.push(self.get_value(Loc::local(RET_NAME.to_owned())).unwrap());
|
||||
}
|
||||
|
||||
// Get references if possible
|
||||
for name in ret_names {
|
||||
rets.insert(
|
||||
name.to_string(),
|
||||
self.get_value(Loc::local(name.to_string())).unwrap(),
|
||||
);
|
||||
rets.push(self.get_value(Loc::local(name.to_string())).unwrap());
|
||||
}
|
||||
|
||||
self.fn_stack.pop().unwrap();
|
||||
rets
|
||||
Some(rets)
|
||||
} else {
|
||||
panic!("No fn to exit")
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ pub fn body_from_func(fn_def: &FunctionDefinition) -> Statement {
|
||||
pub fn fn_info_to_defs(
|
||||
fn_info: &FnInfo,
|
||||
arg_terms: &Vec<Vec<Term>>, // arguments taken at call site
|
||||
) -> (String, BTreeMap<String, Sort>, BTreeMap<String, Sort>) {
|
||||
) -> (String, Vec<String>, Vec<Sort>, BTreeMap<String, Sort>) {
|
||||
let mut rets: BTreeMap<String, Sort> = BTreeMap::new();
|
||||
match &fn_info.ret_ty {
|
||||
Ty::Void => {}
|
||||
@@ -127,7 +127,8 @@ pub fn fn_info_to_defs(
|
||||
}
|
||||
};
|
||||
assert!(fn_info.params.len() == arg_terms.len());
|
||||
let mut params = BTreeMap::new();
|
||||
let mut param_names: Vec<String> = Vec::new();
|
||||
let mut param_sorts: Vec<Sort> = Vec::new();
|
||||
for (param, arg) in fn_info.params.iter().zip(arg_terms.iter()) {
|
||||
let name = param.name.clone();
|
||||
let ty = match ¶m.ty {
|
||||
@@ -140,10 +141,10 @@ pub fn fn_info_to_defs(
|
||||
}
|
||||
_ => param.ty.sort(),
|
||||
};
|
||||
params.insert(name, ty.clone());
|
||||
param_names.push(name);
|
||||
param_sorts.push(ty.clone());
|
||||
}
|
||||
|
||||
(fn_info.name.clone(), params, rets)
|
||||
(fn_info.name.clone(), param_names, param_sorts, rets)
|
||||
}
|
||||
|
||||
pub fn flatten_inits(init: Initializer) -> Vec<Initializer> {
|
||||
|
||||
@@ -78,8 +78,8 @@ impl FrontEnd for C {
|
||||
// generate new context
|
||||
g.circ = Circify::new(Ct::new(i.inputs.clone().map(parser::parse_inputs)));
|
||||
let call = g.function_queue.pop().unwrap();
|
||||
if let Op::Call(name, args, rets, ret_name) = &call.op {
|
||||
g.fn_call(name, args, rets, ret_name);
|
||||
if let Op::Call(name, arg_names, arg_sorts, rets) = &call.op {
|
||||
g.fn_call(name, arg_names, arg_sorts, rets);
|
||||
let comp = g.circ.consume().borrow().clone();
|
||||
|
||||
// println!("fn: {}", name);
|
||||
@@ -826,6 +826,7 @@ impl CGen {
|
||||
|
||||
let ret_ty = f.ret_ty.clone();
|
||||
|
||||
let mut arg_names: Vec<String> = Vec::new();
|
||||
let cargs = arguments
|
||||
.iter()
|
||||
.map(|e| self.gen_expr(e.node.clone()))
|
||||
@@ -833,20 +834,26 @@ impl CGen {
|
||||
let mut cargs_map: HashMap<String, CTerm> = HashMap::new();
|
||||
for (p, c) in f.params.iter().zip(cargs.iter()) {
|
||||
cargs_map.insert(p.name.clone(), c.clone());
|
||||
arg_names.push(p.name.clone());
|
||||
}
|
||||
|
||||
let arg_terms = cargs
|
||||
.iter()
|
||||
.map(|e| e.term.terms(self.circ.cir_ctx()))
|
||||
.collect::<Vec<_>>();
|
||||
let flatten_args = arg_terms.clone().into_iter().flatten().collect::<Vec<_>>();
|
||||
let (name, args, rets) = fn_info_to_defs(&f, &arg_terms);
|
||||
let (name, arg_names, arg_sorts, rets) = fn_info_to_defs(&f, &arg_terms);
|
||||
|
||||
let call_term = term(
|
||||
Op::Call(
|
||||
name.clone(),
|
||||
args.clone(),
|
||||
rets.clone(),
|
||||
"return".to_string(),
|
||||
arg_names.clone(),
|
||||
arg_sorts.clone(),
|
||||
Sort::Tuple(
|
||||
rets.values()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(),
|
||||
),
|
||||
),
|
||||
flatten_args.clone(),
|
||||
);
|
||||
@@ -859,58 +866,71 @@ impl CGen {
|
||||
|
||||
// Rewiring
|
||||
for (ret_name, sort) in rets.iter() {
|
||||
if ret_name != "return" {
|
||||
let call = term(
|
||||
Op::Call(
|
||||
name.clone(),
|
||||
args.clone(),
|
||||
rets.clone(),
|
||||
ret_name.to_string(),
|
||||
),
|
||||
flatten_args.clone(),
|
||||
);
|
||||
|
||||
if let Sort::Array(_, _, l) = sort {
|
||||
let ct = cargs_map.get(ret_name).unwrap();
|
||||
if let CTermData::Array(_, id) = ct.term {
|
||||
self.circ.replace(id.unwrap(), call.clone());
|
||||
// self.circ.assign(l, Val::Term(val));
|
||||
// for i in 0..*l {
|
||||
// let updated_idx = bv_lit(i as i32, 32);
|
||||
// // TODO: index calculation
|
||||
// self.circ.store(id.unwrap(), updated_idx, call.clone());
|
||||
// }
|
||||
} else {
|
||||
unimplemented!("This should only be handling ptrs to arrays");
|
||||
}
|
||||
if let Sort::Array(_, _, l) = sort {
|
||||
let ct = cargs_map.get(ret_name).unwrap();
|
||||
if let CTermData::Array(_, id) = ct.term {
|
||||
self.circ.replace(id.unwrap(), call_term.clone());
|
||||
} else {
|
||||
unimplemented!("This should only be handling ptrs to arrays");
|
||||
}
|
||||
|
||||
// println!("CT: {}", ct.term.term());
|
||||
// // self
|
||||
// // .circ
|
||||
// // .assign(l, Val::Term(val))
|
||||
// // .map_err(|e| format!("{}", e))?
|
||||
// // .unwrap_term()
|
||||
// unimplemented!();
|
||||
|
||||
// // if let CTermData::Array(_, id) = ct.term {
|
||||
|
||||
// // }
|
||||
// // for i in 0..*l {
|
||||
// // let updated_idx = bv_lit(i as i32, 32);
|
||||
// // self.circ.store(id.unwrap(), updated_idx, call.clone());
|
||||
// // }
|
||||
// // } else {
|
||||
// // unimplemented!("This should only be handling ptrs to arrays");
|
||||
// // }
|
||||
// } else {
|
||||
// unimplemented!("This should only be handling ptrs to arrays");
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
// // Rewiring
|
||||
// for (ret_name, sort) in rets.iter() {
|
||||
// println!("retname: {}", ret_name);
|
||||
// if ret_name != "return" {
|
||||
// let call = term(
|
||||
// Op::Call(
|
||||
// name.clone(),
|
||||
// args.clone(),
|
||||
// rets.clone(),
|
||||
// ret_name.to_string(),
|
||||
// ),
|
||||
// flatten_args.clone(),
|
||||
// );
|
||||
|
||||
// if let Sort::Array(_, _, l) = sort {
|
||||
// let ct = cargs_map.get(ret_name).unwrap();
|
||||
// if let CTermData::Array(_, id) = ct.term {
|
||||
// self.circ.replace(id.unwrap(), call.clone());
|
||||
// // self.circ.assign(l, Val::Term(val));
|
||||
// // for i in 0..*l {
|
||||
// // let updated_idx = bv_lit(i as i32, 32);
|
||||
// // // TODO: index calculation
|
||||
// // self.circ.store(id.unwrap(), updated_idx, call.clone());
|
||||
// // }
|
||||
// } else {
|
||||
// unimplemented!("This should only be handling ptrs to arrays");
|
||||
// }
|
||||
// } else {
|
||||
// unimplemented!("This should only be handling ptrs to arrays");
|
||||
// }
|
||||
|
||||
// // println!("CT: {}", ct.term.term());
|
||||
// // // self
|
||||
// // // .circ
|
||||
// // // .assign(l, Val::Term(val))
|
||||
// // // .map_err(|e| format!("{}", e))?
|
||||
// // // .unwrap_term()
|
||||
// // unimplemented!();
|
||||
|
||||
// // // if let CTermData::Array(_, id) = ct.term {
|
||||
|
||||
// // // }
|
||||
// // // for i in 0..*l {
|
||||
// // // let updated_idx = bv_lit(i as i32, 32);
|
||||
// // // self.circ.store(id.unwrap(), updated_idx, call.clone());
|
||||
// // // }
|
||||
// // // } else {
|
||||
// // // unimplemented!("This should only be handling ptrs to arrays");
|
||||
// // // }
|
||||
// // } else {
|
||||
// // unimplemented!("This should only be handling ptrs to arrays");
|
||||
// // }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Return value
|
||||
let ret = match ret_ty {
|
||||
Ty::Void | Ty::Bool => cterm(CTermData::Bool(call_term)),
|
||||
@@ -1286,18 +1306,17 @@ impl CGen {
|
||||
fn fn_call(
|
||||
&mut self,
|
||||
name: &String,
|
||||
args: &BTreeMap<String, Sort>,
|
||||
rets: &BTreeMap<String, Sort>,
|
||||
ret_name: &String,
|
||||
arg_names: &Vec<String>,
|
||||
arg_sorts: &Vec<Sort>,
|
||||
rets: &Sort,
|
||||
) {
|
||||
debug!("Call: {}", name);
|
||||
println!("Call: {}", name);
|
||||
// for (n, a) in args {
|
||||
// println!("args: {}, {}", n, a);
|
||||
// }
|
||||
// for (r, s) in rets.iter() {
|
||||
// println!("ret: {}, {}", r, s);
|
||||
// }
|
||||
|
||||
let mut arg_map: BTreeMap<String, Sort> = BTreeMap::new();
|
||||
for (n, s) in arg_names.iter().zip(arg_sorts.iter()) {
|
||||
arg_map.insert(n.to_string(), s.clone());
|
||||
}
|
||||
|
||||
// Get function types
|
||||
let f = self
|
||||
@@ -1313,16 +1332,23 @@ impl CGen {
|
||||
};
|
||||
self.circ.enter_fn(name.to_owned(), ret_ty);
|
||||
|
||||
// Keep track of the names of arguments that are references
|
||||
let mut ret_names: Vec<String> = Vec::new();
|
||||
|
||||
// define input parameters
|
||||
assert!(args.len() == f.params.len());
|
||||
assert!(arg_map.len() == f.params.len());
|
||||
for param in f.params {
|
||||
let p_name = param.name;
|
||||
assert!(args.contains_key(&p_name));
|
||||
let s = args.get(&p_name).unwrap();
|
||||
assert!(arg_map.contains_key(&p_name));
|
||||
let s = arg_map.get(&p_name).unwrap();
|
||||
let p_ty = match param.ty {
|
||||
Ty::Ptr(_, t) => {
|
||||
if let Sort::Array(_, _, len) = s {
|
||||
let dims = vec![*len];
|
||||
|
||||
// Add reference
|
||||
ret_names.push(p_name.clone());
|
||||
|
||||
Ty::Array(*len, dims, t)
|
||||
} else {
|
||||
panic!("Ptr type does not match with Array sort: {}", s)
|
||||
@@ -1336,10 +1362,34 @@ impl CGen {
|
||||
|
||||
self.gen_stmt(f.body.clone());
|
||||
|
||||
// let ret_names = &rets.keys().collect::<Vec<&String>>();
|
||||
// let rets = self.circ.exit_fn_call(ret_names);
|
||||
// for (name, val) in rets {
|
||||
// let ret_terms = val.unwrap_term().term.terms(self.circ.cir_ctx());
|
||||
if let Some(returns) = self.circ.exit_fn_call(&ret_names) {
|
||||
let ret_terms = returns
|
||||
.into_iter()
|
||||
.map(|x| x.unwrap_term().term.terms(self.circ.cir_ctx()))
|
||||
.flatten()
|
||||
.collect::<Vec<Term>>();
|
||||
self.circ
|
||||
.cir_ctx()
|
||||
.cs
|
||||
.borrow_mut()
|
||||
.outputs
|
||||
.push(term(Op::Tuple, ret_terms));
|
||||
}
|
||||
|
||||
// for (name, val) in returns {
|
||||
// println!("name: {}", name);
|
||||
// // let ret_terms = val.unwrap_term().term.terms(self.circ.cir_ctx());
|
||||
// // self.circ
|
||||
// // .cir_ctx()
|
||||
// // .cs
|
||||
// // .borrow_mut()
|
||||
// // .outputs
|
||||
// // .extend(ret_terms);
|
||||
// }
|
||||
|
||||
// if let Some(r) = self.circ.exit_fn() {
|
||||
// let ret_term = r.unwrap_term();
|
||||
// let ret_terms = ret_term.term.terms(self.circ.cir_ctx());
|
||||
// self.circ
|
||||
// .cir_ctx()
|
||||
// .cs
|
||||
@@ -1348,17 +1398,6 @@ impl CGen {
|
||||
// .extend(ret_terms);
|
||||
// }
|
||||
|
||||
if let Some(r) = self.circ.exit_fn() {
|
||||
let ret_term = r.unwrap_term();
|
||||
let ret_terms = ret_term.term.terms(self.circ.cir_ctx());
|
||||
self.circ
|
||||
.cir_ctx()
|
||||
.cs
|
||||
.borrow_mut()
|
||||
.outputs
|
||||
.extend(ret_terms);
|
||||
}
|
||||
|
||||
// match self.mode {
|
||||
// Mode::Mpc(_) => {
|
||||
// let ret_term = r.unwrap_term();
|
||||
|
||||
@@ -39,12 +39,16 @@ fn match_arg(name: &String, params: &BTreeMap<String, Term>) -> Term {
|
||||
fn inline(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Term> {
|
||||
let mut res: Vec<Term> = Vec::new();
|
||||
let comp = fs.computations.get(name).unwrap();
|
||||
for o in comp.outputs.iter().rev() {
|
||||
println!("Comp: {}", name);
|
||||
println!("params: {:#?}", params);
|
||||
for o in comp.outputs.iter() {
|
||||
println!("o: {}", o);
|
||||
let mut cache = TermMap::new();
|
||||
for t in PostOrderIter::new(o.clone()) {
|
||||
match &t.op {
|
||||
Op::Var(name, _sort) => {
|
||||
Op::Var(name, _) => {
|
||||
let ret = match_arg(name, ¶ms);
|
||||
println!("ret: {}", ret);
|
||||
cache.insert(t.clone(), ret.clone());
|
||||
}
|
||||
_ => {
|
||||
@@ -62,6 +66,7 @@ fn inline(name: &str, params: BTreeMap<String, Term>, fs: &Functions) -> Vec<Ter
|
||||
}
|
||||
res.push(cache.get(o).unwrap().clone());
|
||||
}
|
||||
println!("res: {:#?}", res);
|
||||
res
|
||||
}
|
||||
|
||||
@@ -73,12 +78,11 @@ pub fn inline_function_calls(
|
||||
) -> Term {
|
||||
let mut call_cache: HashMap<Term, Vec<Term>> = HashMap::new();
|
||||
for t in PostOrderIter::new(term_.clone()) {
|
||||
println!("inline t: {}", t);
|
||||
let mut children = Vec::new();
|
||||
for c in &t.cs {
|
||||
if let Some(rewritten_c) = rewritten.get(c) {
|
||||
if call_cache.contains_key(c) {
|
||||
children.push(call_cache.get_mut(c).unwrap().pop().unwrap().clone());
|
||||
} else {
|
||||
if !call_cache.contains_key(c) {
|
||||
children.push(rewritten_c.clone());
|
||||
}
|
||||
} else {
|
||||
@@ -86,11 +90,23 @@ pub fn inline_function_calls(
|
||||
}
|
||||
}
|
||||
let entry = match &t.op {
|
||||
Op::Call(name, args, _rets, _) => {
|
||||
Op::Field(index) => {
|
||||
assert!(t.cs.len() > 0);
|
||||
if let Op::Call(..) = &t.cs[0].op {
|
||||
if call_cache.contains_key((&t.cs[0])) {
|
||||
call_cache.get(&t.cs[0]).unwrap()[index + 1].clone()
|
||||
} else {
|
||||
panic!("Fields on a Call term should return");
|
||||
}
|
||||
} else {
|
||||
term(t.op.clone(), children)
|
||||
}
|
||||
}
|
||||
Op::Call(name, arg_names, arg_sorts, _) => {
|
||||
println!("Inlining: {}", name);
|
||||
|
||||
// Check number of args
|
||||
let num_args = args.values().fold(0, |sum, x| {
|
||||
let num_args = arg_sorts.iter().fold(0, |sum, x| {
|
||||
sum + match x {
|
||||
Sort::Array(_, _, l) => *l,
|
||||
_ => 1,
|
||||
@@ -104,8 +120,8 @@ pub fn inline_function_calls(
|
||||
);
|
||||
|
||||
// Check arg types
|
||||
let arg_types = args
|
||||
.values()
|
||||
let arg_types = arg_sorts
|
||||
.iter()
|
||||
.map(|x| match &x {
|
||||
Sort::Array(_, val_sort, l) => {
|
||||
let mut res: Vec<Sort> = Vec::new();
|
||||
@@ -125,8 +141,9 @@ pub fn inline_function_calls(
|
||||
);
|
||||
|
||||
let mut params: BTreeMap<String, Term> = BTreeMap::new();
|
||||
let arg_keys = args
|
||||
let arg_keys = arg_names
|
||||
.iter()
|
||||
.zip(arg_sorts.iter())
|
||||
.map(|(n, x)| match &x {
|
||||
Sort::Array(_, _, l) => {
|
||||
let mut res: Vec<String> = Vec::new();
|
||||
@@ -147,6 +164,7 @@ pub fn inline_function_calls(
|
||||
}
|
||||
_ => term(t.op.clone(), children),
|
||||
};
|
||||
println!("rewritten: {}\n", entry);
|
||||
rewritten.insert(t.clone(), entry);
|
||||
}
|
||||
|
||||
|
||||
@@ -58,8 +58,6 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut fs: Functions, optimizations: I) ->
|
||||
let _lock = super::term::COLLECT.read().unwrap();
|
||||
let mut cache = TermCache::new(TERM_CACHE_LIMIT);
|
||||
for a in &mut comp.outputs {
|
||||
// println!("cfold: {}", a);
|
||||
// println!();
|
||||
// allow unbounded size during a single fold_cache call
|
||||
cache.resize(std::usize::MAX);
|
||||
*a = cfold::fold_cache(a, &mut cache, &*ignore.clone());
|
||||
|
||||
@@ -70,7 +70,6 @@ use itertools::zip_eq;
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
enum TupleTree {
|
||||
NonTuple(Term),
|
||||
CallTuple(Term),
|
||||
Tuple(im::Vector<TupleTree>),
|
||||
}
|
||||
|
||||
@@ -79,7 +78,6 @@ impl TupleTree {
|
||||
let mut out = Vec::new();
|
||||
fn rec_unroll_into(t: &TupleTree, out: &mut Vec<Term>) {
|
||||
match t {
|
||||
TupleTree::CallTuple(t) => out.push(t.clone()),
|
||||
TupleTree::NonTuple(t) => out.push(t.clone()),
|
||||
TupleTree::Tuple(t) => {
|
||||
for c in t {
|
||||
@@ -98,9 +96,6 @@ impl TupleTree {
|
||||
TupleTree::Tuple(tt.iter().map(|c| term_structure(c, iter)).collect())
|
||||
}
|
||||
TupleTree::NonTuple(_) => TupleTree::NonTuple(iter.next().expect("bad structure")),
|
||||
TupleTree::CallTuple(_) => {
|
||||
TupleTree::CallTuple(iter.next().expect("bad structure"))
|
||||
}
|
||||
}
|
||||
}
|
||||
term_structure(self, &mut flattened.into_iter())
|
||||
@@ -113,12 +108,10 @@ impl TupleTree {
|
||||
}
|
||||
fn get(&self, i: usize) -> Self {
|
||||
match self {
|
||||
TupleTree::CallTuple(cs) => {
|
||||
TupleTree::CallTuple(term![Op::Select; cs.clone(), bv_lit(i, 32)])
|
||||
}
|
||||
TupleTree::NonTuple(cs) => {
|
||||
panic!("Get ({}) on non-tuple {:?}", i, self)
|
||||
}
|
||||
TupleTree::NonTuple(cs) => match cs.op {
|
||||
Op::Call(..) => TupleTree::NonTuple(term![Op::Field(i); cs.clone()]),
|
||||
_ => panic!("Get ({}) on non-tuple {:?}", i, self),
|
||||
},
|
||||
TupleTree::Tuple(t) => {
|
||||
assert!(i < t.len());
|
||||
t.get(i).unwrap().clone()
|
||||
@@ -127,10 +120,6 @@ impl TupleTree {
|
||||
}
|
||||
fn update(&self, i: usize, v: &TupleTree) -> Self {
|
||||
match self {
|
||||
TupleTree::CallTuple(cs) => {
|
||||
let val = v.clone().unwrap_non_tuple();
|
||||
TupleTree::CallTuple(term![Op::Store; cs.clone(), bv_lit(i, 32), val.clone()])
|
||||
}
|
||||
TupleTree::NonTuple(cs) => panic!("Update ({}) on non-tuple {:?}", i, self),
|
||||
TupleTree::Tuple(t) => {
|
||||
assert!(i < t.len());
|
||||
@@ -141,7 +130,6 @@ impl TupleTree {
|
||||
fn unwrap_non_tuple(self) -> Term {
|
||||
match self {
|
||||
TupleTree::NonTuple(t) => t,
|
||||
TupleTree::CallTuple(t) => t,
|
||||
_ => panic!("{:?} is tuple!", self),
|
||||
}
|
||||
}
|
||||
@@ -274,10 +262,6 @@ pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
t.update(*i, &v)
|
||||
}
|
||||
Op::Tuple => TupleTree::Tuple(cs.into()),
|
||||
Op::Call(..) => TupleTree::CallTuple(term(
|
||||
t.op.clone(),
|
||||
cs.into_iter().map(|c| c.unwrap_non_tuple()).collect(),
|
||||
)),
|
||||
_ => TupleTree::NonTuple(term(
|
||||
t.op.clone(),
|
||||
cs.into_iter().map(|c| c.unwrap_non_tuple()).collect(),
|
||||
|
||||
@@ -139,13 +139,8 @@ pub enum Op {
|
||||
/// Map (operation)
|
||||
Map(Box<Op>),
|
||||
|
||||
/// Call a function (name, argument sorts, return sorts, return_name)
|
||||
Call(
|
||||
String,
|
||||
BTreeMap<String, Sort>,
|
||||
BTreeMap<String, Sort>,
|
||||
String,
|
||||
),
|
||||
/// Call a function (name, argument names, argument sorts, return sorts)
|
||||
Call(String, Vec<String>, Vec<Sort>, Sort),
|
||||
}
|
||||
|
||||
/// Boolean AND
|
||||
@@ -257,7 +252,7 @@ impl Op {
|
||||
Op::Field(_) => Some(1),
|
||||
Op::Update(_) => Some(2),
|
||||
Op::Map(op) => op.arity(),
|
||||
Op::Call(_, args, _, _) => Some(args.len()),
|
||||
Op::Call(_, _, args, _) => Some(args.len()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,14 +175,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::Call(_, _, ret, ret_name) => {
|
||||
// let s = ret[ret_name].clone();
|
||||
// match s {
|
||||
// Sort::Array(_, val_sort, _) => Ok(*val_sort),
|
||||
// _ => Ok(s),
|
||||
// }
|
||||
Ok(ret[ret_name].clone())
|
||||
}
|
||||
Op::Call(_, _, _, ret) => Ok(ret.clone()),
|
||||
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
|
||||
}
|
||||
}
|
||||
@@ -399,14 +392,14 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
|
||||
rec_check_raw_helper(&(*op.clone()), &new_a[..])
|
||||
.map(|val_sort| Sort::Array(Box::new(key_sort), Box::new(val_sort), size))
|
||||
}
|
||||
(Op::Call(_, ex_args, ret, ret_name), act_args) => {
|
||||
(Op::Call(_, _, ex_args, ret), act_args) => {
|
||||
if ex_args.len() != act_args.len() {
|
||||
Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len()))
|
||||
} else {
|
||||
for ((_, e), a) in ex_args.iter().zip(act_args) {
|
||||
for (e, a) in ex_args.iter().zip(act_args) {
|
||||
eq_or(e, a, "in function call")?;
|
||||
}
|
||||
Ok(ret[ret_name].clone())
|
||||
Ok(ret.clone())
|
||||
}
|
||||
}
|
||||
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
|
||||
|
||||
Reference in New Issue
Block a user