diff --git a/scripts/aby_tests/c_test_aby.py b/scripts/aby_tests/c_test_aby.py index 7f8f04b3..e2c6c6fb 100755 --- a/scripts/aby_tests/c_test_aby.py +++ b/scripts/aby_tests/c_test_aby.py @@ -31,8 +31,6 @@ if __name__ == "__main__": cryptonets_tests + \ histogram_tests - tests = ts - # TODO: add support unsigned + int promotion # unsigned_arithmetic_tests + \ diff --git a/scripts/build_mpc_c_test.zsh b/scripts/build_mpc_c_test.zsh index d1e85e35..e3dfabcd 100755 --- a/scripts/build_mpc_c_test.zsh +++ b/scripts/build_mpc_c_test.zsh @@ -36,103 +36,99 @@ function mpc_test_bool { RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "b" } -# mpc_test 2 ./examples/C/mpc/playground.c -# mpc_test 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets.c +# build mpc arithmetic tests +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c + +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c +mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c + +# build nary arithmetic tests +mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c + +# build bitwise tests +mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c +mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c +mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c + +# build boolean tests +mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c +mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c +mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c + +# build nary boolean tests +mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c + +# build const tests +mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c +mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c + +# build if statement tests +mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c +mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c +mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c + +# build shift tests +mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c +mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c + +# build div tests +mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c + +# build array tests +mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c +mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c +mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c + +# build circ/compiler array tests +mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c +mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c +mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c +mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c +mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c + +# build function tests +mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c + +# build struct tests +mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c + +# build matrix tests +mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c + +# build ptr tests +mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c +mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c + +# build misc tests +mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c +mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c + +# build hycc benchmarks +mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c +mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c +mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c +mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c +mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c +mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join2.c +mpc_test 2 ./examples/C/mpc/benchmarks/db/db_merge.c +mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist.c +mpc_test_2 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets.c + +# build OPA benchmarks mpc_test_2 2 ./examples/C/mpc/benchmarks/histogram/histogram.c -# # build mpc arithmetic tests -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c -# # mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c - -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c -# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c - -# # build nary arithmetic tests -# mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c - -# # build bitwise tests -# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c -# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c -# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c - -# # build boolean tests -# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c -# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c -# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c - -# # build nary boolean tests -# mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c - -# # build const tests -# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c -# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c - -# # build if statement tests -# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c -# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c -# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c - -# # build shift tests -# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c -# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c - -# # build div tests -# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c - -# # build array tests -# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c -# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c -# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c - -# # build circ/compiler array tests -# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c -# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c -# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c -# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c -# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c - -# # build function tests -# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c - -# # build struct tests -# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c - -# # build matrix tests -# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c - -# # build ptr tests -# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c -# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c - -# # build misc tests -# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c -# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c - -# # build hycc benchmarks -# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c -# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c -# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c -# mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c -# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c -# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join2.c -# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_merge.c -# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist.c -# mpc_test 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets.c - -# # build OPA benchmarks -# mpc_test_2 2 ./examples/C/mpc/benchmarks/histogram/histogram.c - # # build hycc benchmarks bool-only diff --git a/src/target/aby/call_site_similarity.rs b/src/target/aby/call_site_similarity.rs index 652a7c61..7ee1961d 100644 --- a/src/target/aby/call_site_similarity.rs +++ b/src/target/aby/call_site_similarity.rs @@ -61,12 +61,9 @@ pub fn call_site_similarity(fs: &Functions) -> Vec> { // Clean input and output terms let mut call_sites: HashMap<(Vec, Vec), Vec> = HashMap::new(); - for (c, (i, o)) in call_term_map { let input_ops = i.iter().map(|x| x.op.clone()).collect::>(); - let mut output_ops = o.iter().map(|x| x.op.clone()).collect::>(); - output_ops.sort(); - + let output_ops = o.iter().map(|x| x.op.clone()).collect::>(); let key = (input_ops, output_ops); // longest prefix matching? diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index 8379ccc5..3d89cfc2 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -13,6 +13,7 @@ use crate::target::aby::assignment::ilp::assign; use crate::target::aby::assignment::SharingMap; use crate::target::aby::utils::*; use std::collections::HashMap; +use std::collections::HashSet; use std::fs; use std::io; use std::path::Path; @@ -22,6 +23,7 @@ use super::assignment::assign_all_yao; use super::assignment::assign_arithmetic_and_boolean; use super::assignment::assign_arithmetic_and_yao; use super::assignment::assign_greedy; +use super::assignment::ShareType; use super::call_site_similarity::call_site_similarity; @@ -41,11 +43,13 @@ struct ToABY<'a> { share_cnt: i32, // Cache cache: HashMap<(Op, Vec), Vec>, + // Const Cache + const_cache: HashMap>, // Outputs bytecode_input: Vec, bytecode_output: Vec, - const_output: Vec, - share_output: Vec, + const_output: HashSet, + share_output: HashSet, } impl Drop for ToABY<'_> { @@ -73,17 +77,24 @@ impl<'a> ToABY<'a> { term_to_shares: TermMap::new(), share_cnt: 0, cache: HashMap::new(), + const_cache: HashMap::new(), bytecode_input: Vec::new(), bytecode_output: Vec::new(), - const_output: Vec::new(), - share_output: Vec::new(), + const_output: HashSet::new(), + share_output: HashSet::new(), } } fn write_const_output(&mut self, flush: bool) { if flush || self.const_output.len() >= WRITE_SIZE { let const_output_path = get_path(self.path, &self.lang, "const", false); - write_lines(&const_output_path, &self.const_output); + let mut lines = self + .const_output + .clone() + .into_iter() + .collect::>(); + lines.sort(); + write_lines(&const_output_path, &lines); self.const_output.clear(); } } @@ -104,7 +115,13 @@ impl<'a> ToABY<'a> { fn write_share_output(&mut self, flush: bool) { if flush || self.share_output.len() >= WRITE_SIZE { let share_output_path = get_path(self.path, &self.lang, "share_map", false); - write_lines(&share_output_path, &self.share_output); + let mut lines = self + .share_output + .clone() + .into_iter() + .collect::>(); + lines.sort(); + write_lines(&share_output_path, &lines); self.share_output.clear(); } } @@ -144,83 +161,166 @@ impl<'a> ToABY<'a> { } } - fn write_share(&mut self, t: &Term, s: i32) { + fn get_term_share_type(&self, t: &Term) -> ShareType { let s_map = self.s_map.get(&self.curr_comp).unwrap(); - let share_type = s_map.get(&t).unwrap().char(); + *s_map.get(&t).unwrap() + } + + fn insert_const(&mut self, t: &Term) { + if !self.const_cache.contains_key(&t) { + let mut const_map: HashMap = HashMap::new(); + + // a type + let s_a = self.share_cnt; + const_map.insert(ShareType::Arithmetic, s_a); + self.share_cnt += 1; + + // b type + let s_b = self.share_cnt; + const_map.insert(ShareType::Boolean, s_b); + self.share_cnt += 1; + + // y type + let s_y = self.share_cnt; + const_map.insert(ShareType::Yao, s_y); + self.share_cnt += 1; + + self.const_cache.insert(t.clone(), const_map); + } + } + + fn output_const_share(&mut self, t: &Term, to_share_type: ShareType) -> i32 { + if self.const_cache.contains_key(&t) { + let output_share = *self + .const_cache + .get(&t) + .unwrap() + .get(&to_share_type) + .unwrap(); + let op = "CONS"; + + match &t.op { + Op::Const(Value::BitVector(b)) => { + let value = b.as_sint(); + let bitlen = 32; + let line = format!("2 1 {} {} {} {}\n", value, bitlen, output_share, op); + self.const_output.insert(line); + } + Op::Const(Value::Bool(b)) => { + let value = *b as i32; + let bitlen = 1; + let line = format!("2 1 {} {} {} {}\n", value, bitlen, output_share, op); + self.const_output.insert(line); + } + _ => todo!(), + }; + + // Add to share map + let line = format!("{} {}\n", output_share, to_share_type.char()); + self.share_output.insert(line); + + output_share + } else { + panic!("const cache does not contain term: {}", t); + } + } + + // fn output_const_shares(&mut self, t: &Term, to_share_type: ShareType) -> i32 { + // if self.const_cache.contains_key(&t) { + // let output_share = *self + // .const_cache + // .get(&t) + // .unwrap() + // .get(&to_share_type) + // .unwrap(); + // let op = "CONS"; + + // match &t.op { + // Op::Const(Value::BitVector(b)) => { + // let value = b.as_sint(); + // let bitlen = 32; + // let line = format!("2 1 {} {} {} {}\n", value, bitlen, output_share, op); + // self.const_output.insert(line); + // } + // Op::Const(Value::Bool(b)) => { + // let value = *b as i32; + // let bitlen = 1; + // let line = format!("2 1 {} {} {} {}\n", value, bitlen, output_share, op); + // self.const_output.insert(line); + // } + // _ => todo!(), + // }; + + // // Add to share map + // let line = format!("{} {}\n", output_share, to_share_type.char()); + // self.share_output.insert(line); + + // output_share + // } else { + // panic!("const cache does not contain term: {}", t); + // } + // } + + fn write_share(&mut self, t: &Term, s: i32) { + let share_type = self.get_term_share_type(t).char(); let line = format!("{} {}\n", s, share_type); - self.share_output.push(line); + self.share_output.insert(line); } fn write_shares(&mut self, t: &Term, shares: &Vec) { - let s_map = self.s_map.get(&self.curr_comp).unwrap(); - let share_type = s_map.get(&t).unwrap().char(); + let share_type = self.get_term_share_type(t).char(); for s in shares { let line = format!("{} {}\n", s, share_type); - self.share_output.push(line); + self.share_output.insert(line); } } // TODO: Rust ENTRY api on maps - fn get_new_share(&mut self, t: &Term, p: &Term) -> i32 { - match self.term_to_shares.get(t) { - Some(v) => { - assert!(v.len() == 1); - v[0] - } - None => { - let s = self.share_cnt; - self.term_to_shares.insert(t.clone(), [s].to_vec()); - self.share_cnt += 1; + fn get_share(&mut self, t: &Term, to_share_type: ShareType) -> i32 { + if t.is_const() { + self.output_const_share(t, to_share_type) + } else { + match self.term_to_shares.get(t) { + Some(v) => { + assert!(v.len() == 1); + v[0] + } + None => { + let s = self.share_cnt; + self.term_to_shares.insert(t.clone(), [s].to_vec()); + self.share_cnt += 1; - // Write share - let s_map = self.s_map.get(&self.curr_comp).unwrap(); - let share_type = s_map.get(&p).unwrap().char(); - let line = format!("{} {}\n", s, share_type); - self.share_output.push(line); + // Write share + self.write_share(t, s); - s + s + } } } } - // TODO: Rust ENTRY api on maps - fn get_share(&mut self, t: &Term) -> i32 { - match self.term_to_shares.get(t) { - Some(v) => { - assert!(v.len() == 1); - v[0] - } - None => { - let s = self.share_cnt; - self.term_to_shares.insert(t.clone(), [s].to_vec()); - self.share_cnt += 1; + fn get_shares(&mut self, t: &Term, to_share_type: ShareType) -> Vec { + if t.is_const() && check(t).is_scalar() { + vec![self.output_const_share(t, to_share_type)] + } else { + match self.term_to_shares.get(t) { + Some(v) => v.clone(), + None => { + let sort = check(t); + let num_shares = self.get_sort_len(&sort) as i32; - // Write share - self.write_share(t, s); + let shares: Vec = (0..num_shares) + .map(|x| x + self.share_cnt) + .collect::>(); + self.term_to_shares.insert(t.clone(), shares.clone()); - s - } - } - } + // Write shares + self.write_shares(t, &shares); - fn get_shares(&mut self, t: &Term) -> Vec { - match self.term_to_shares.get(t) { - Some(v) => v.clone(), - None => { - let sort = check(t); - let num_shares = self.get_sort_len(&sort) as i32; + self.share_cnt += num_shares; - let shares: Vec = (0..num_shares) - .map(|x| x + self.share_cnt) - .collect::>(); - self.term_to_shares.insert(t.clone(), shares.clone()); - - // Write shares - self.write_shares(t, &shares); - - self.share_cnt += num_shares; - - shares + shares + } } } } @@ -261,30 +361,31 @@ impl<'a> ToABY<'a> { fn embed_eq(&mut self, t: &Term) { let op = "EQ"; - let a = self.get_share(&t.cs[0]); - let b = self.get_share(&t.cs[1]); + let to_share_type = self.get_term_share_type(t); + let a = self.get_share(&t.cs[0], to_share_type); + let b = self.get_share(&t.cs[1], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(t); - self.cache.insert(key, s.clone()); - let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + let s = self.get_share(t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, b, s, op); self.bytecode_output.push(line); }; } fn embed_bool(&mut self, t: Term) { - let s = self.get_share(&t); + let to_share_type = self.get_term_share_type(&t); match &t.op { Op::Var(name, Sort::Bool) => { let md = self.get_md(); if !self.inputs.contains(&t) && md.input_vis.contains_key(name) { let term_name = ToABY::get_var_name(&name); let vis = self.unwrap_vis(name); - let s = self.get_share(&t); + let s = self.get_share(&t, to_share_type); let op = "IN"; if vis == PUBLIC { @@ -298,43 +399,45 @@ impl<'a> ToABY<'a> { self.inputs.push(t.clone()); } } - Op::Const(Value::Bool(b)) => { - let op = "CONS_bool"; - let line = format!("1 1 {} {} {}\n", *b as i32, s, op); - self.const_output.push(line); + Op::Const(_) => { + self.insert_const(&t); + // let op = "CONS_bool"; + // let line = format!("1 1 {} {} {}\n", *b as i32, s, op); + // self.const_output.insert(line); } Op::Eq => { self.embed_eq(&t); } Op::Ite => { let op = "MUX"; - let sel = self.get_share(&t.cs[0]); - let a = self.get_share(&t.cs[1]); - let b = self.get_share(&t.cs[2]); + let to_share_type = self.get_term_share_type(&t); + let sel = self.get_share(&t.cs[0], to_share_type); + let a = self.get_share(&t.cs[1], to_share_type); + let b = self.get_share(&t.cs[2], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op); self.bytecode_output.push(line); }; } Op::Not => { let op = "NOT"; - let a = self.get_share(&t.cs[0]); + let a = self.get_share(&t.cs[0], to_share_type); let key = (t.op.clone(), vec![a]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("1 1 {} {} {}\n", a, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("1 1 {} {} {}\n", a, s, op); self.bytecode_output.push(line); }; } @@ -344,7 +447,7 @@ impl<'a> ToABY<'a> { // If t.cs len is 1, just output that term // This is to bypass adding an AND gate with a single conditional term // Refer to pub fn condition() in src/circify/mod.rs - let a = self.get_share(&t.cs[0]); + let a = self.get_share(&t.cs[0], to_share_type); match o { BoolNaryOp::And => self.term_to_shares.insert(t.clone(), vec![a]), _ => { @@ -358,17 +461,17 @@ impl<'a> ToABY<'a> { BoolNaryOp::Xor => "XOR", }; - let a = self.get_share(&t.cs[0]); - let b = self.get_share(&t.cs[1]); + let a = self.get_share(&t.cs[0], to_share_type); + let b = self.get_share(&t.cs[1], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, b, s, op); self.bytecode_output.push(line); }; } @@ -382,17 +485,17 @@ impl<'a> ToABY<'a> { _ => panic!("Non-field in bool BvBinPred: {}", o), }; - let a = self.get_share(&t.cs[0]); - let b = self.get_share(&t.cs[1]); + let a = self.get_share(&t.cs[0], to_share_type); + let b = self.get_share(&t.cs[1], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, b, s, op); self.bytecode_output.push(line); }; } @@ -401,13 +504,14 @@ impl<'a> ToABY<'a> { } fn embed_bv(&mut self, t: Term) { + let to_share_type = self.get_term_share_type(&t); match &t.op { Op::Var(name, Sort::BitVector(_)) => { let md = self.get_md(); if !self.inputs.contains(&t) && md.input_vis.contains_key(name) { let term_name = ToABY::get_var_name(&name); let vis = self.unwrap_vis(name); - let s = self.get_share(&t); + let s = self.get_share(&t, to_share_type); let op = "IN"; if vis == PUBLIC { @@ -421,26 +525,29 @@ impl<'a> ToABY<'a> { self.inputs.push(t.clone()); } } - Op::Const(Value::BitVector(b)) => { - let s = self.get_share(&t); - let op = "CONS_bv"; - let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); - self.const_output.push(line); + Op::Const(Value::BitVector(_)) => { + // create all three shares + self.insert_const(&t); + + // let s = self.get_share(&t); + // let op = "CONS_bv"; + // let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); + // self.const_output.push(line); } Op::Ite => { let op = "MUX"; - let sel = self.get_share(&t.cs[0]); - let a = self.get_share(&t.cs[1]); - let b = self.get_share(&t.cs[2]); + let sel = self.get_share(&t.cs[0], to_share_type); + let a = self.get_share(&t.cs[1], to_share_type); + let b = self.get_share(&t.cs[2], to_share_type); let key = (t.op.clone(), vec![sel, a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op); self.bytecode_output.push(line); }; } @@ -452,17 +559,17 @@ impl<'a> ToABY<'a> { BvNaryOp::Add => "ADD", BvNaryOp::Mul => "MUL", }; - let a = self.get_share(&t.cs[0]); - let b = self.get_share(&t.cs[1]); + let a = self.get_share(&t.cs[0], to_share_type); + let b = self.get_share(&t.cs[1], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, b, s, op); self.bytecode_output.push(line); }; } @@ -478,22 +585,22 @@ impl<'a> ToABY<'a> { match o { BvBinOp::Sub | BvBinOp::Udiv | BvBinOp::Urem => { - let a = self.get_share(&t.cs[0]); - let b = self.get_share(&t.cs[1]); + let a = self.get_share(&t.cs[0], to_share_type); + let b = self.get_share(&t.cs[1], to_share_type); let key = (t.op.clone(), vec![a, b]); if self.cache.contains_key(&key) { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t, s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, b, s, op); self.bytecode_output.push(line); }; } BvBinOp::Shl | BvBinOp::Lshr => { - let a = self.get_share(&t.cs[0]); + let a = self.get_share(&t.cs[0], to_share_type); let const_shift_amount_term = fold(&t.cs[1], &[]); let const_shift_amount = const_shift_amount_term.as_bv_opt().unwrap().uint(); @@ -503,10 +610,9 @@ impl<'a> ToABY<'a> { let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t, s); } else { - let s = self.get_shares(&t); - self.cache.insert(key, s.clone()); - let line = - format!("2 1 {} {} {} {}\n", a, const_shift_amount, s[0], op); + let s = self.get_share(&t, to_share_type); + self.cache.insert(key, vec![s]); + let line = format!("2 1 {} {} {} {}\n", a, const_shift_amount, s, op); self.bytecode_output.push(line); }; } @@ -515,13 +621,13 @@ impl<'a> ToABY<'a> { } Op::Field(i) => { assert!(t.cs.len() == 1); - let shares = self.get_shares(&t.cs[0]); + let shares = self.get_shares(&t.cs[0], to_share_type); assert!(*i < shares.len()); self.term_to_shares.insert(t.clone(), vec![shares[*i]]); } Op::Select => { assert!(t.cs.len() == 2); - let array_shares = self.get_shares(&t.cs[0]); + let array_shares = self.get_shares(&t.cs[0], to_share_type); if let Op::Const(Value::BitVector(bv)) = &t.cs[1].op { let idx = bv.uint().to_usize().unwrap().clone(); @@ -537,8 +643,8 @@ impl<'a> ToABY<'a> { } else { let op = "SELECT"; let num_inputs = array_shares.len() + 1; - let index_share = self.get_share(&t.cs[1]); - let output = self.get_share(&t); + let index_share = self.get_share(&t.cs[1], to_share_type); + let output = self.get_share(&t, to_share_type); let line = format!( "{} 1 {} {} {} {}\n", num_inputs, @@ -556,6 +662,7 @@ impl<'a> ToABY<'a> { } fn embed_vector(&mut self, t: Term) { + let to_share_type = self.get_term_share_type(&t); match &t.op { Op::Const(Value::Array(arr)) => { let mut shares: Vec = Vec::new(); @@ -570,22 +677,26 @@ impl<'a> ToABY<'a> { // TODO: sort of value might not be a 32-bit bitvector let v_term = leaf_term(Op::Const(v.clone())); - if self.term_to_shares.contains_key(&v_term) { + if self.const_cache.contains_key(&v_term) { // existing const - let s = self.get_share(&v_term); + let s = self.get_share(&v_term, to_share_type); shares.push(s); } else { // new const - let s = self.get_new_share(&v_term, &t); - match v { - Value::BitVector(b) => { - let op = "CONS_bv"; - let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); - self.const_output.push(line); - } - _ => todo!(), - } + self.insert_const(&v_term); + let s = self.get_share(&v_term, to_share_type); shares.push(s); + + // let s = self.get_new_share(&v_term, &t); + // match v { + // Value::BitVector(b) => { + // let op = "CONS_bv"; + // let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); + // self.const_output.push(line); + // } + // _ => todo!(), + // } + // shares.push(s); } } @@ -593,26 +704,27 @@ impl<'a> ToABY<'a> { self.term_to_shares.insert(t.clone(), shares); } Op::Const(Value::Tuple(tup)) => { - let shares = self.get_shares(&t); - assert!(shares.len() == tup.len()); - for (val, s) in tup.iter().zip(shares.iter()) { - match val { - Value::BitVector(b) => { - let op = "CONS_bv"; - let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); - self.const_output.push(line); - } - _ => todo!(), - } - } + // let shares = self.get_shares(&t, to_share_type); + // assert!(shares.len() == tup.len()); + // for (val, s) in tup.iter().zip(shares.iter()) { + // match val { + // // Value::BitVector(b) => { + // // let op = "CONS_bv"; + // // let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op); + // // self.const_output.push(line); + // // } + // _ => todo!(), + // } + // } + todo!(); } Op::Ite => { let op = "MUX"; - let shares = self.get_shares(&t); + let shares = self.get_shares(&t, to_share_type); - let sel = self.get_share(&t.cs[0]); - let a = self.get_shares(&t.cs[1]); - let b = self.get_shares(&t.cs[2]); + let sel = self.get_share(&t.cs[0], to_share_type); + let a = self.get_shares(&t.cs[1], to_share_type); + let b = self.get_shares(&t.cs[2], to_share_type); // assert scalar_term share lens are equivalent assert!(shares.len() == a.len()); @@ -636,8 +748,8 @@ impl<'a> ToABY<'a> { } Op::Store => { assert!(t.cs.len() == 3); - let mut array_shares = self.get_shares(&t.cs[0]).clone(); - let value_share = self.get_share(&t.cs[2]); + let mut array_shares = self.get_shares(&t.cs[0], to_share_type).clone(); + let value_share = self.get_share(&t.cs[2], to_share_type); if let Op::Const(Value::BitVector(bv)) = &t.cs[1].op { // constant indexing @@ -647,9 +759,9 @@ impl<'a> ToABY<'a> { } else { let op = "STORE"; let num_inputs = array_shares.len() + 2; - let outputs = self.get_shares(&t); + let outputs = self.get_shares(&t, to_share_type); let num_outputs = outputs.len(); - let index_share = self.get_share(&t.cs[1]); + let index_share = self.get_share(&t.cs[1], to_share_type); let line = format!( "{} {} {} {} {} {} {}\n", num_inputs, @@ -666,7 +778,7 @@ impl<'a> ToABY<'a> { } Op::Field(i) => { assert!(t.cs.len() == 1); - let shares = self.get_shares(&t.cs[0]); + let shares = self.get_shares(&t.cs[0], to_share_type); let tuple_sort = check(&t.cs[0]); let (offset, len) = match tuple_sort { @@ -694,8 +806,8 @@ impl<'a> ToABY<'a> { } Op::Update(i) => { assert!(t.cs.len() == 2); - let mut tuple_shares = self.get_shares(&t.cs[0]); - let value_share = self.get_share(&t.cs[1]); + let mut tuple_shares = self.get_shares(&t.cs[0], to_share_type); + let value_share = self.get_share(&t.cs[1], to_share_type); // assert the index is in bounds assert!(*i < tuple_shares.len()); @@ -709,12 +821,12 @@ impl<'a> ToABY<'a> { Op::Tuple => { let mut shares: Vec = Vec::new(); for c in t.cs.iter() { - shares.append(&mut self.get_shares(c)); + shares.append(&mut self.get_shares(c, to_share_type)); } self.term_to_shares.insert(t.clone(), shares); } Op::Call(name, _arg_names, arg_sorts, ret_sorts) => { - let shares = self.get_shares(&t); + let shares = self.get_shares(&t, to_share_type); let op = format!("CALL({})", name); let num_args: usize = arg_sorts.iter().map(|ret| self.get_sort_len(ret)).sum(); let num_rets: usize = ret_sorts.iter().map(|ret| self.get_sort_len(ret)).sum(); @@ -722,9 +834,17 @@ impl<'a> ToABY<'a> { for c in t.cs.iter() { let sort = check(c); if self.rewirable(&sort) { - arg_shares.extend(self.get_shares(c).iter().map(|&s| s.to_string())) + arg_shares.extend( + self.get_shares(c, to_share_type) + .iter() + .map(|&s| s.to_string()), + ) } else { - arg_shares.extend(self.get_shares(c).iter().map(|&s| s.to_string())) + arg_shares.extend( + self.get_shares(c, to_share_type) + .iter() + .map(|&s| s.to_string()), + ) } } @@ -809,7 +929,8 @@ impl<'a> ToABY<'a> { self.embed(t.clone()); let op = "OUT"; - let shares = self.get_shares(&t); + let to_share_type = self.get_term_share_type(&t); + let shares = self.get_shares(&t, to_share_type); for s in shares { let line = format!("1 0 {} {}\n", s, op); @@ -895,9 +1016,9 @@ impl<'a> ToABY<'a> { /// Convert this (IR) `ir` to ABY. pub fn to_aby(ir: Functions, path: &Path, lang: &str, cm: &str, ss: &str) { // Call site similarity - // println!("call site"); - // call_site_similarity(&ir); - // println!("end call site"); + println!("call site"); + call_site_similarity(&ir); + println!("end call site"); // Protocal Assignments let mut s_map: HashMap = HashMap::new();