merge hack but still slow

This commit is contained in:
Clive2312
2022-07-25 03:02:00 +00:00
18 changed files with 578 additions and 380 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ __pycache__
.mode.txt
scripts/aby_tests/tests
/flamegraph*.svg
.vscode/

View File

@@ -1,11 +1,68 @@
#define LEN 32
#define NUM_REVIEWERS 1
#define NUM_RATINGS 1
#define NUM_REVIEWERS 100
#define NUM_RATINGS 100
#define INTERVALS 2
#define NUM_BUCKETS (INTERVALS * 5) - 1
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
/* returns val/mod, integer division */
// int quot(int val, int mod) {
// if (mod == 0){
// return val;
// } else{
// int rem = val % mod;
// return (val - rem) / mod;
// }
// }
int map(int sumRatings) {
int bucket = NUM_RATINGS+1;
int val = sumRatings;
int mod = NUM_RATINGS;
int absReview = val / mod;
int fraction = val % mod;
// int absReview = 2;
// int fraction = 3;
int m = INTERVALS * (absReview - 1);
int num = fraction * INTERVALS;
for (int j = 0; j < INTERVALS; j++) {
int low = j * NUM_RATINGS;
int high = (j + 1) * NUM_RATINGS;
int cond1;
if(low <= num) {
cond1 = 1;
}
else {
cond1 = 0;
}
int cond2;
if(high > num) {
cond2 = 1;
}
else {
cond2 = 0;
}
int cond = cond1 + cond2;
int newBucket;
if(cond == 2) {
newBucket = m + j;
}
else {
newBucket = bucket;
}
bucket = newBucket;
}
return bucket;
}
int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
{
int result[NUM_BUCKETS];
@@ -17,7 +74,7 @@ int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((priv
for (int j = 0; j < NUM_RATINGS; j++) {
sum = sum + reviews[i*NUM_RATINGS + j];
}
int bucket = sum;
int bucket = map(sum);
for (int j = 0; j < NUM_BUCKETS; j++) {
int temp;
if (j == bucket) {
@@ -34,4 +91,67 @@ int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((priv
sum_all += result[i];
}
return sum_all;
}
}
// int map(int sumRatings) {
// // int bucket = NUM_RATINGS+1;
// // int val = sumRatings;
// // int mod = NUM_RATINGS;
// // int absReview = val / mod;
// // int fraction = val % mod;
// // int absReview = 2;
// // int fraction = 3;
// // int absReview = sumRatings;
// // int fraction = sumRatings;
// int m = INTERVALS * (sumRatings - 1);
// int num = sumRatings * INTERVALS;
// // for (int j = 0; j < INTERVALS; j++) {
// // int low = j * NUM_RATINGS;
// // int high = (j + 1) * NUM_RATINGS;
// // int cond1;
// // if(low <= num) {
// // cond1 = 1;
// // }
// // else {
// // cond1 = 0;
// // }
// // int cond2;
// // if(high > num) {
// // cond2 = 1;
// // }
// // else {
// // cond2 = 0;
// // }
// // int cond = cond1 + cond2;
// // int newBucket;
// // if(cond == 2) {
// // newBucket = m + j;
// // }
// // else {
// // newBucket = bucket;
// // }
// // bucket = newBucket;
// // }
// return m + num;
// }
// int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
// {
// int sum_all = offset;
// for (int i = 0; i < NUM_REVIEWERS; i++) {
// int sum = reviews[i*NUM_RATINGS];
// int sum2 = reviews[0];
// int bucket = sum2*sum;
// sum_all += bucket;
// }
// return sum_all;
// }

View File

@@ -0,0 +1,82 @@
#define LEN 32
#define NUM_REVIEWERS 100
#define NUM_RATINGS 100
#define INTERVALS 2
#define NUM_BUCKETS (INTERVALS * 5) - 1
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
typedef struct
{
int result[NUM_BUCKETS];
} Output;
int map(int sumRatings) {
int bucket = NUM_RATINGS+1;
int val = sumRatings;
int mod = NUM_RATINGS;
int absReview = val / mod;
int fraction = val % mod;
// int absReview = 2;
// int fraction = 3;
int m = INTERVALS * (absReview - 1);
int num = fraction * INTERVALS;
for (int j = 0; j < INTERVALS; j++) {
int low = j * NUM_RATINGS;
int high = (j + 1) * NUM_RATINGS;
int cond1;
if(low <= num) {
cond1 = 1;
}
else {
cond1 = 0;
}
int cond2;
if(high > num) {
cond2 = 1;
}
else {
cond2 = 0;
}
int cond = cond1 + cond2;
int newBucket;
if(cond == 2) {
newBucket = m + j;
}
else {
newBucket = bucket;
}
bucket = newBucket;
}
return bucket;
}
Output main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
{
Output res;
for (int i = 0; i < NUM_REVIEWERS; i++) {
int sum = 0;
for (int j = 0; j < NUM_RATINGS; j++) {
sum = sum + reviews[i*NUM_RATINGS + j];
}
int bucket = map(sum);
for (int j = 0; j < NUM_BUCKETS; j++) {
int temp;
if (j == bucket) {
temp = res.result[j] + 1;
}
else {
temp = res.result[j];
}
res.result[j] = temp;
}
}
return res;
}

View File

@@ -1,19 +1,61 @@
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b)
#define LEN 32
#define NUM_REVIEWERS 1
#define NUM_RATINGS 1
#define INTERVALS 2
#define NUM_BUCKETS (INTERVALS * 5) - 1
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
{
int result[10];
for(int i = 0; i < 10; i++){
result[i] = 0;
}
int result[NUM_BUCKETS];
for(int i = 0; i < 10; i++){
result[i] += 1;
}
for (int i = 0; i < NUM_REVIEWERS; i++) {
int sum = 0;
for (int j = 0; j < NUM_RATINGS; j++) {
sum = sum + reviews[i*NUM_RATINGS + j];
}
int bucket = sum;
for (int j = 0; j < NUM_BUCKETS; j++) {
int temp;
if (j == bucket) {
temp = result[j] + 1;
}
else {
temp = result[j];
}
result[j] = temp;
}
}
int sum_all = offset;
for(int i = 0; i < NUM_BUCKETS; i++){
sum_all += result[i];
}
return sum_all;
}
int res = 0;
for(int i = 0; i < 10; i++){
res += result[i];
}
// int f(int a) {
// return a + 1;
// }
return res + a + b;
}
// int main( __attribute__((private(0))) int a, __attribute__((private(1))) int b)
// {
// // base input
// int c = f(a);
// int d = f(b);
// // add input
// int e = a + c;
// int g = b + d;
// int h = f(e);
// int i = f(g);
// // multiply input
// int j = a * h;
// int k = b * i;
// int l = f(j);
// int m = f(k);
// return l + m;
// }

View File

@@ -1,5 +1,5 @@
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) {
int index = a + b;
int arr[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
return arr[index];
int arr[10] = {0, a, 2, 3, 4, 5, 6, 7, 8, 9};
return arr[index] + arr[1];
}

View File

@@ -252,14 +252,14 @@ fn main() {
Opt::Sha,
Opt::ConstantFold(Box::new(ignore.clone())),
Opt::Flatten,
// The function call abstraction creates tuples
Opt::Tuple,
// // The function call abstraction creates tuples
// Opt::Tuple,
// Opt::Obliv,
// The obliv elim pass produces more tuples, that must be eliminated
// // The obliv elim pass produces more tuples, that must be eliminated
// Opt::Tuple,
// Opt::LinearScan,
// The linear scan pass produces more tuples, that must be eliminated
// Opt::Tuple,
Opt::Tuple,
Opt::ConstantFold(Box::new(ignore.clone())),
// Inline Function Calls
// Opt::Link,
@@ -300,7 +300,7 @@ fn main() {
// for (name, c) in &cs.computations {
// println!("name: {}", name);
// for t in c.terms_postorder() {
// println!("t: {}", t);
// println!("t: {}", t.op);
// }
// }

View File

@@ -4,5 +4,5 @@ from util import run_tests
from test_suite import *
if __name__ == "__main__":
tests = benchmark_tests
tests = pc_histogram_tests
run_tests('c', tests)

View File

@@ -28,7 +28,10 @@ if __name__ == "__main__":
gauss_tests + \
db_tests + \
mnist_tests + \
cryptonets_tests
cryptonets_tests + \
histogram_tests
tests = histogram_tests
tests = biomatch_tests

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -683,6 +683,21 @@ pg_tests = [
"./scripts/aby_tests/test_inputs/playground.txt",
]
]
histogram_tests = [
[
"histogram",
"histogram",
"./scripts/aby_tests/test_inputs/histogram.txt",
]
]
pc_histogram_tests = [
[
"2pc_histogram",
"2pc_histogram",
"./scripts/aby_tests/test_inputs/2pc_histogram.txt",
]
]
# ilp_benchmark_tests = [
# [

View File

@@ -66,9 +66,103 @@ function mpc_test_6 {
cpath=$2
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+y"
}
# mpc_test 2 ./examples/C/mpc/playground.c
# # build mpc arithmetic tests
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c
# # mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c
# # build nary arithmetic tests
# mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c
# # build bitwise tests
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c
# # build boolean tests
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c
# # build nary boolean tests
# mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c
# # build const tests
# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c
# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c
# # build if statement tests
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c
# # build shift tests
# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c
# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c
# # build div tests
# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c
# # build array tests
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c
# # build circ/compiler array tests
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c
# # build function tests
# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c
# # build struct tests
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c
# # build matrix tests
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c
# # build ptr tests
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c
# # build misc tests
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c
# # build hycc benchmarks
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c
# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
# mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join2.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_merge.c
# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist.c
# mpc_test 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets.c
# # build OPA benchmarks
mpc_test_4 2 ./examples/C/mpc/benchmarks/histogram/2pc_histogram.c
mpc_test_3 2 ./examples/C/mpc/playground.c
# mpc_test_3 2 ./examples/C/mpc/playground.c
# mpc_test_4 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch_.c
# mpc_test_4 2 ./examples/C/mpc/benchmarks/db/db_join2.c
# mpc_test_4 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c

View File

@@ -5,7 +5,6 @@ use crate::ir::term::*;
/// Binarize cache.
#[derive(Default)]
struct Binarizer;
impl RewritePass for Binarizer {
@@ -26,13 +25,13 @@ impl RewritePass for Binarizer {
|acc, x| term![orig.op.clone(); x.clone(), acc],
))
}
},
_ => None
}
_ => None,
}
}
}
/// Binarize (expand) n-ary terms.
/// Binarize (expand) n-ary terms.
pub fn binarize(c: &mut Computation) {
let mut pass = Binarizer;
pass.traverse(c);

View File

@@ -1,52 +0,0 @@
//! Call Site Similarity
use crate::ir::term::*;
/// Determine if call sites are similar based on input and output arguments to the call site
pub fn call_site_similarity(fs: &mut Functions) {
// Return a TermMap of (call) --> id for which calls are similar
// Maybe return a vector of vector of terms
// Map of Vec<input: Vec<Term>, output: Vec<Term>> --> Vec<Call Term>
// For each call site, (input: Vec<Term>, output: Vec<Term>) -> Term (call)
let mut call_sites: TermMap<(Vec<Term>, Vec<Term>)> = TermMap::new();
for (name, comp) in fs.computations {
// Post order traversal through each computation
for t in comp.terms_postorder() {
match t.op {
Op::Call(name, arg_names, arg_sorts, ret_sorts) => {
let input: Vec<Term> = t.cs;
let output: Vec<Term> = Vec::new();
call_sites.insert(t.clone(), (input, output));
}
_ => {
// see if the call term was used as an argument in another term
for c in t.cs {
if call_sites.contains_key(&c) {
call_sites.get(c).1.push(t.clone());
}
}
}
}
}
// For each call term
// Get a list of inputs and output terms based on mutation step size
// Store input terms into a data structure (vec?)
// Store output terms into a data structure (vec?)
// Order terms by operator
// Create key: Vec<input: Vec<Term>, output: Vec<Term>>
// loop through existing call terms:
// longest prefix matching (edit distance?)
// if match:
// append to vec
// if no match:
// add as new entry
}
}

View File

@@ -319,6 +319,18 @@ impl Display for Op {
}
}
impl Ord for Op {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.cmp(&other)
}
}
impl PartialOrd for Op {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
/// Boolean n-ary operator
pub enum BoolNaryOp {
@@ -1283,7 +1295,7 @@ impl Value {
*b
} else {
panic!("Not a bool: {}", self)
}
}
}
#[track_caller]
/// Get the underlying bit-vector constant, or panic!
@@ -2040,17 +2052,6 @@ impl Computation {
}
}
// #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
// /// A function definition.
// pub struct FuncDef {
// /// Name of function
// pub name: String,
// /// Type signature of function parameters
// pub params: BTreeMap<String, Sort>,
// /// Return type of function
// pub ret_ty: Vec<Sort>,
// }
#[derive(Clone, Debug, Default, PartialEq)]
/// A map of IR computations.
pub struct Functions {

View File

@@ -0,0 +1,84 @@
//! Call Site Similarity
use crate::ir::term::*;
use std::collections::HashMap;
/// Determine if call sites are similar based on input and output arguments to the call site
pub fn call_site_similarity(fs: &Functions) -> Vec<Vec<Term>> {
// Return a TermMap of (call) --> id for which calls are similar
// Maybe return a vector of vector of terms
// Map of Vec<input: Vec<Term>, output: Vec<Term>> --> Vec<Call Term>
// map call Term -> (input: Vec<Term>, output: Vec<Term>)
let mut call_term_map: TermMap<(Vec<Term>, Vec<Term>)> = TermMap::new();
// map field(i) Term to parent call Term
let mut field_of_calls: TermMap<Term> = TermMap::new();
for (_name, comp) in &fs.computations {
for t in comp.terms_postorder() {
// see if the call term was used as an argument in another term
for c in &t.cs {
if call_term_map.contains_key(c) {
field_of_calls.insert(t.clone(), c.clone());
}
if field_of_calls.contains_key(c) {
let call_term = field_of_calls.get(c).unwrap();
call_term_map.get_mut(call_term).unwrap().1.push(t.clone());
}
}
match &t.op {
Op::Call(..) => {
let input: Vec<Term> = t.cs.clone();
let output: Vec<Term> = Vec::new();
call_term_map.insert(t.clone(), (input, output));
}
_ => {
// do nothing
}
}
}
// For each call term
// Get a list of inputs and output terms based on mutation step size
// Store input terms into a data structure (vec?)
// Store output terms into a data structure (vec?)
// Order terms by operator
// Create key: Vec<input: Vec<Term>, output: Vec<Term>>
// SORT OUTPUT TERMS
// loop through existing call terms:
// longest prefix matching (edit distance?)
// if match:
// append to vec
// if no match:
// add as new entry
}
// Clean input and output terms
let mut call_sites: HashMap<(Vec<Op>, Vec<Op>), Vec<Term>> = HashMap::new();
for (c, (i, o)) in call_term_map {
let input_ops = i.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
let mut output_ops = o.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
output_ops.sort();
let key = (input_ops, output_ops);
// longest prefix matching?
// edit distance?
if call_sites.contains_key(&key) {
call_sites.get_mut(&key).unwrap().push(c);
} else {
call_sites.insert(key, vec![c]);
}
}
return call_sites.into_values().collect::<Vec<_>>();
}

View File

@@ -1,4 +1,5 @@
//! ABY
pub mod assignment;
pub mod call_site_similarity;
pub mod trans;
pub mod utils;

View File

@@ -18,7 +18,6 @@ use std::fmt;
use std::fs;
use std::io;
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "lp")]
use crate::target::graph::trans::*;
@@ -29,6 +28,8 @@ use super::assignment::assign_arithmetic_and_boolean;
use super::assignment::assign_arithmetic_and_yao;
use super::assignment::assign_greedy;
// use super::call_site_similarity::call_site_similarity;
const PUBLIC: u8 = 2;
const WRITE_SIZE: usize = 65536;
@@ -201,11 +202,11 @@ struct ToABY<'a> {
curr_comp: String,
// Input mapping
inputs: Vec<Term>,
// Term cache
cache: TermMap<EmbeddedTerm>,
// Term to share id
term_to_shares: TermMap<Vec<i32>>,
share_cnt: i32,
// Cache
cache: HashMap<(Op, Vec<i32>), Vec<i32>>,
// Outputs
bytecode_input: Vec<String>,
bytecode_output: Vec<String>,
@@ -221,7 +222,6 @@ impl Drop for ToABY<'_> {
// drop everything that uses a Term
// drop(take(&mut self.md));
self.inputs.clear();
self.cache.clear();
self.term_to_shares.clear();
// self.s_map.clear();
// clean up
@@ -238,9 +238,9 @@ impl<'a> ToABY<'a> {
lang: lang.to_string(),
curr_comp: "".to_string(),
inputs: Vec::new(),
cache: TermMap::new(),
term_to_shares: TermMap::new(),
share_cnt: 0,
cache: HashMap::new(),
bytecode_input: Vec::new(),
bytecode_output: Vec::new(),
const_output: Vec::new(),
@@ -314,20 +314,6 @@ impl<'a> ToABY<'a> {
}
}
fn get_var_name_from_term(t: &Term) -> String {
match &t.op {
Op::Var(name, _) => ToABY::get_var_name(name),
_ => panic!("Term {} is not of type Var", t),
}
}
fn get_sharing_map(&mut self, name: &str) -> SharingMap {
match self.s_map.get(name) {
Some(s) => s.clone(),
None => panic!("Unknown sharing map for function: {}", name),
}
}
fn write_share(&mut self, t: &Term, s: i32) {
if !self.written_const_set.contains(&s){
let s_map = self.s_map.get(&self.curr_comp).unwrap();
@@ -525,38 +511,13 @@ impl<'a> ToABY<'a> {
}
}
fn embed_eq(&mut self, t: Term, a_term: Term, b_term: Term) {
let s = self.get_share(&t);
fn embed_eq(&mut self, t: &Term) {
let s = self.get_share(t);
let a = self.get_share(&t.cs[0]);
let b = self.get_share(&t.cs[1]);
let op = "EQ";
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
self.bytecode_output.push(line);
match check(&a_term) {
Sort::Bool => {
self.check_bool(&a_term);
self.check_bool(&b_term);
self.cache.insert(t, EmbeddedTerm::Bool);
}
Sort::BitVector(_) => {
self.check_bv(&a_term);
self.check_bv(&b_term);
self.cache.insert(t, EmbeddedTerm::Bool);
}
e => panic!("Unimplemented sort for Eq: {:?}", e),
}
}
/// Given term `t`, type-check `t` is of type Bool
fn check_bool(&self, t: &Term) {
match self
.cache
.get(t)
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
{
EmbeddedTerm::Bool => (),
_ => panic!("Non-bool for {:?}", t),
}
}
fn embed_bool(&mut self, t: Term) {
@@ -565,7 +526,7 @@ impl<'a> ToABY<'a> {
Op::Var(name, Sort::Bool) => {
let md = self.get_md();
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
let term_name = ToABY::get_var_name_from_term(&t);
let term_name = ToABY::get_var_name(&name);
let vis = self.unwrap_vis(name);
let s = self.get_share(&t);
let op = "IN";
@@ -580,46 +541,31 @@ impl<'a> ToABY<'a> {
}
self.inputs.push(t.clone());
}
if !self.cache.contains_key(&t) {
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
}
Op::Const(Value::Bool(b)) => {
let op = "CONS_bool";
let line = format!("1 1 {} {} {}\n", *b as i32, s, op);
self.const_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
Op::Eq => {
self.embed_eq(t.clone(), t.cs[0].clone(), t.cs[1].clone());
self.embed_eq(&t);
}
Op::Ite => {
let op = "MUX";
self.check_bool(&t.cs[0]);
self.check_bool(&t.cs[1]);
self.check_bool(&t.cs[2]);
let sel = self.get_share(&t.cs[0]);
let a = self.get_share(&t.cs[1]);
let b = self.get_share(&t.cs[2]);
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
Op::Not => {
let op = "NOT";
self.check_bool(&t.cs[0]);
let a = self.get_share(&t.cs[0]);
let line = format!("1 1 {} {} {}\n", a, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
Op::BoolNaryOp(o) => {
if t.cs.len() == 1 {
@@ -627,7 +573,6 @@ impl<'a> ToABY<'a> {
// If t.cs len is 1, just output that term
// This is to bypass adding an AND gate with a single conditional term
// Refer to pub fn condition() in src/circify/mod.rs
self.check_bool(&t.cs[0]);
let a = self.get_share(&t.cs[0]);
match o {
BoolNaryOp::And => self.term_to_shares.insert(t.clone(), vec![a]),
@@ -635,11 +580,7 @@ impl<'a> ToABY<'a> {
unimplemented!("Single operand boolean operation");
}
};
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
} else {
self.check_bool(&t.cs[0]);
self.check_bool(&t.cs[1]);
let op = match o {
BoolNaryOp::Or => "OR",
BoolNaryOp::And => "AND",
@@ -650,8 +591,6 @@ impl<'a> ToABY<'a> {
let b = self.get_share(&t.cs[1]);
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
}
Op::BvBinPred(o) => {
@@ -663,38 +602,21 @@ impl<'a> ToABY<'a> {
_ => panic!("Non-field in bool BvBinPred: {}", o),
};
self.check_bv(&t.cs[0]);
self.check_bv(&t.cs[1]);
let a = self.get_share(&t.cs[0]);
let b = self.get_share(&t.cs[1]);
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
}
_ => panic!("Non-field in embed_bool: {}", t),
}
}
/// Given term `t`, type-check `t` is of type Bv
fn check_bv(&self, t: &Term) {
match self
.cache
.get(t)
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
{
EmbeddedTerm::Bv => (),
_ => panic!("Non-bv for {:?}", t),
}
}
fn embed_bv(&mut self, t: Term) {
match &t.op {
Op::Var(name, Sort::BitVector(_)) => {
let md = self.get_md();
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
let term_name = ToABY::get_var_name_from_term(&t);
let term_name = ToABY::get_var_name(&name);
let vis = self.unwrap_vis(name);
let s = self.get_share(&t);
let op = "IN";
@@ -709,10 +631,6 @@ impl<'a> ToABY<'a> {
}
self.inputs.push(t.clone());
}
if !self.cache.contains_key(&t) {
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
}
Op::Const(Value::BitVector(b)) => {
let s = self.get_share(&t);
@@ -725,27 +643,26 @@ impl<'a> ToABY<'a> {
}
self.const_output.push(line);
}
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
// self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
Op::Ite => {
let s = self.get_share(&t);
let op = "MUX";
self.check_bool(&t.cs[0]);
self.check_bv(&t.cs[1]);
self.check_bv(&t.cs[2]);
let sel = self.get_share(&t.cs[0]);
let a = self.get_share(&t.cs[1]);
let b = self.get_share(&t.cs[2]);
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
let key = (t.op.clone(), vec![sel, a, b]);
if self.cache.contains_key(&key) {
let s = self.cache.get(&key).unwrap().clone();
self.term_to_shares.insert(t.clone(), s);
} else {
let s = self.get_shares(&t);
self.cache.insert(key, s.clone());
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op);
self.bytecode_output.push(line);
};
}
Op::BvNaryOp(o) => {
let s = self.get_share(&t);
let op = match o {
BvNaryOp::Xor => "XOR",
BvNaryOp::Or => "OR",
@@ -753,20 +670,21 @@ impl<'a> ToABY<'a> {
BvNaryOp::Add => "ADD",
BvNaryOp::Mul => "MUL",
};
self.check_bv(&t.cs[0]);
self.check_bv(&t.cs[1]);
let a = self.get_share(&t.cs[0]);
let b = self.get_share(&t.cs[1]);
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
let key = (t.op.clone(), vec![a, b]);
if self.cache.contains_key(&key) {
let s = self.cache.get(&key).unwrap().clone();
self.term_to_shares.insert(t.clone(), s);
} else {
let s = self.get_shares(&t);
self.cache.insert(key, s.clone());
let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
self.bytecode_output.push(line);
};
}
Op::BvBinOp(o) => {
let s = self.get_share(&t);
let op = match o {
BvBinOp::Sub => "SUB",
BvBinOp::Udiv => "DIV",
@@ -778,30 +696,37 @@ impl<'a> ToABY<'a> {
match o {
BvBinOp::Sub | BvBinOp::Udiv | BvBinOp::Urem => {
self.check_bv(&t.cs[0]);
self.check_bv(&t.cs[1]);
let a = self.get_share(&t.cs[0]);
let b = self.get_share(&t.cs[1]);
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
let key = (t.op.clone(), vec![a, b]);
if self.cache.contains_key(&key) {
let s = self.cache.get(&key).unwrap().clone();
self.term_to_shares.insert(t, s);
} else {
let s = self.get_shares(&t);
self.cache.insert(key, s.clone());
let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
self.bytecode_output.push(line);
};
}
BvBinOp::Shl | BvBinOp::Lshr => {
self.check_bv(&t.cs[0]);
self.check_bv(&t.cs[1]);
let a = self.get_share(&t.cs[0]);
let const_shift_amount_term = fold(&t.cs[1], &[]);
let const_shift_amount =
const_shift_amount_term.as_bv_opt().unwrap().uint();
let line = format!("2 1 {} {} {} {}\n", a, const_shift_amount, s, op);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
let key = (t.op.clone(), vec![a, const_shift_amount.to_i32().unwrap()]);
if self.cache.contains_key(&key) {
let s = self.cache.get(&key).unwrap().clone();
self.term_to_shares.insert(t, s);
} else {
let s = self.get_shares(&t);
self.cache.insert(key, s.clone());
let line =
format!("2 1 {} {} {} {}\n", a, const_shift_amount, s[0], op);
self.bytecode_output.push(line);
};
}
_ => panic!("Binop not supported: {}", o),
};
@@ -811,7 +736,6 @@ impl<'a> ToABY<'a> {
let shares = self.get_shares(&t.cs[0]);
assert!(*i < shares.len());
self.term_to_shares.insert(t.clone(), vec![shares[*i]]);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
Op::Select => {
assert!(t.cs.len() == 2);
@@ -828,7 +752,6 @@ impl<'a> ToABY<'a> {
self.term_to_shares
.insert(t.clone(), vec![array_shares[idx]]);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
} else {
let op = "SELECT";
let num_inputs = array_shares.len() + 1;
@@ -844,7 +767,6 @@ impl<'a> ToABY<'a> {
);
self.bytecode_output.push(line);
self.term_to_shares.insert(t.clone(), vec![output]);
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
}
_ => panic!("Non-field in embed_bv: {:?}", t),
@@ -852,42 +774,51 @@ impl<'a> ToABY<'a> {
}
fn embed_scalar(&mut self, t: Term) {
let now = Instant::now();
match &t.op {
Op::Const(Value::Array(arr)) => {
let shares = self.get_shares(&t);
assert!(shares.len() == arr.size);
// let shares = self.get_shares(&t);
// assert!(shares.len() == arr.size);
for (i, s) in shares.iter().enumerate() {
let mut shares: Vec<i32> = Vec::new();
for i in 0..arr.size {
// TODO: sort of index might not be a 32-bit bitvector
let idx = Value::BitVector(BitVector::new(Integer::from(i), 32));
let v = match arr.map.get(&idx) {
Some(c) => c,
None => &*arr.default,
};
match v {
Value::BitVector(b) => {
if !self.written_const_set.contains(s){
self.written_const_set.insert(*s);
let op = "CONS_bv";
let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op);
if b.as_sint() == 99{
println!("GOtcha2: {}", t);
// TODO: sort of value might not be a 32-bit bitvector
let v_term = leaf_term(Op::Const(v.clone()));
if self.term_to_shares.contains_key(&v_term) {
// existing const
let s = self.get_share(&v_term);
shares.push(s);
} else {
// new const
let s = self.get_share(&v_term);
match v {
Value::BitVector(b) => {
if !self.written_const_set.contains(&s){
self.written_const_set.insert(s);
let op = "CONS_bv";
let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op);
if b.as_sint() == 99{
println!("GOtcha2: {}", t);
}
self.const_output.push(line);
}
self.const_output.push(line);
// self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
_ => todo!(),
}
_ => todo!(),
shares.push(s);
}
}
unsafe {
num_const_arr += 1;
dur_const_arr += now.elapsed();
};
assert!(shares.len() == arr.size);
self.term_to_shares.insert(t.clone(), shares);
}
Op::Const(Value::Tuple(tup)) => {
let shares = self.get_shares(&t);
@@ -904,16 +835,10 @@ impl<'a> ToABY<'a> {
}
self.const_output.push(line);
}
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
}
_ => todo!(),
}
}
unsafe {
num_const_tuple += 1;
dur_const_tuple += now.elapsed();
};
}
Op::Ite => {
let op = "MUX";
@@ -942,13 +867,6 @@ impl<'a> ToABY<'a> {
);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Array);
unsafe {
num_ite += 1;
dur_ite += now.elapsed();
};
}
Op::Store => {
assert!(t.cs.len() == 3);
@@ -962,7 +880,6 @@ impl<'a> ToABY<'a> {
array_shares[idx] = value_share;
self.term_to_shares.insert(t.clone(), array_shares.clone());
self.cache.insert(t.clone(), EmbeddedTerm::Array);
} else {
let op = "STORE";
let num_inputs = array_shares.len() + 2;
@@ -982,11 +899,6 @@ impl<'a> ToABY<'a> {
self.bytecode_output.push(line);
}
unsafe {
num_store += 1;
dur_store += now.elapsed();
};
}
Op::Field(i) => {
assert!(t.cs.len() == 1);
@@ -1014,12 +926,6 @@ impl<'a> ToABY<'a> {
let field_shares = &shares[offset..offset + len];
self.term_to_shares.insert(t.clone(), field_shares.to_vec());
self.cache.insert(t.clone(), EmbeddedTerm::Array);
unsafe {
num_field += 1;
dur_field += now.elapsed();
};
}
Op::Update(i) => {
assert!(t.cs.len() == 2);
@@ -1034,12 +940,6 @@ impl<'a> ToABY<'a> {
// store shares
self.term_to_shares.insert(t.clone(), tuple_shares);
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
unsafe {
num_update += 1;
dur_update += now.elapsed();
};
}
Op::Tuple => {
let mut shares: Vec<i32> = Vec::new();
@@ -1047,12 +947,6 @@ impl<'a> ToABY<'a> {
shares.append(&mut self.get_shares(c));
}
self.term_to_shares.insert(t.clone(), shares);
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
unsafe {
num_tuple += 1;
dur_tuple += now.elapsed();
};
}
Op::Call(name, _arg_names, arg_sorts, ret_sorts) => {
let shares = self.get_shares(&t);
@@ -1093,12 +987,6 @@ impl<'a> ToABY<'a> {
op
);
self.bytecode_output.push(line);
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
unsafe {
num_call += 1;
dur_call += now.elapsed();
};
}
_ => {
panic!("Non-field in embed_scalar: {}", t.op)
@@ -1144,88 +1032,25 @@ impl<'a> ToABY<'a> {
let mut write_time: std::time::Duration = std::time::Duration::new(0, 0);
for c in PostOrderIter_v2::new(t) {
if self.cache.contains_key(&c) {
if self.term_to_shares.contains_key(&c) {
continue;
}
let b_now = Instant::now(); // check for tuples are long
match check(&c) {
Sort::Bool => {
let now = Instant::now();
self.embed_bool(c);
num_bool += 1;
dur_bool += now.elapsed();
}
Sort::BitVector(_) => {
let now = Instant::now();
self.embed_bv(c);
num_bv += 1;
dur_bv += now.elapsed();
}
Sort::Array(..) | Sort::Tuple(_) => {
let now = Instant::now();
self.embed_scalar(c);
num_scalar += 1;
dur_scalar += now.elapsed();
}
e => panic!("Unsupported sort in embed: {:?}", e),
}
let now = Instant::now();
self.write_bytecode_output(false);
self.write_const_output(false);
self.write_share_output(false);
write_time += now.elapsed();
}
println!("bool: {}, bv: {}, scalar: {}", num_bool, num_bv, num_scalar);
println!(
"times: bool: {:?}, bv: {:?}, scalar: {:?}",
dur_bool, dur_bv, dur_scalar
);
println!("write time: {:?}", write_time);
if num_bool > 0 && num_bv > 0 && num_scalar > 0 {
println!(
"norm_times: bool: {:?}, bv: {:?}, scalar: {:?}\n",
dur_bool / num_bool,
dur_bv / num_bv,
dur_scalar / num_scalar
);
}
unsafe {
println!("================================");
println!("const_arr: {}, const_tuple: {}, ite: {}, store: {}, field: {}, update: {}, tuple: {}, call: {}", num_const_arr, num_const_tuple, num_ite, num_store, num_field, num_update, num_tuple, num_call);
println!("times: const_arr: {:?}, const_tuple: {:?}, ite: {:?}, store: {:?}, field: {:?}, update: {:?}, tuple: {:?}, call: {:?}", dur_const_arr, dur_const_tuple, dur_ite, dur_store, dur_field, dur_update, dur_tuple, dur_call);
if num_const_arr > 0 {
println!("norm_const_arr: {:?}", dur_const_arr / num_const_arr as u32);
}
if num_const_tuple > 0 {
println!(
"norm_const_tuple: {:?}",
dur_const_tuple / num_const_tuple as u32
);
}
if num_ite > 0 {
println!("norm_ite: {:?}", dur_ite / num_ite as u32);
}
if num_store > 0 {
println!("norm_store: {:?}", dur_store / num_store as u32);
}
if num_field > 0 {
println!("norm_field: {:?}", dur_field / num_field as u32);
}
if num_update > 0 {
println!("norm_update: {:?}", dur_update / num_update as u32);
}
if num_tuple > 0 {
println!("norm_tuple: {:?}", dur_tuple / num_tuple as u32);
}
if num_call > 0 {
println!("norm_call: {:?}", dur_call / num_call as u32);
}
println!("================================\n")
}
}
@@ -1246,7 +1071,6 @@ impl<'a> ToABY<'a> {
for (name, comp) in computations.iter() {
let mut outputs: Vec<String> = Vec::new();
let mut now = Instant::now();
// set current computation
self.curr_comp = name.to_string();
@@ -1274,10 +1098,6 @@ impl<'a> ToABY<'a> {
}
self.bytecode_output.append(&mut outputs);
println!("Time: lowering {}: {:?}", name, now.elapsed());
now = Instant::now();
// reorder inputs
let mut bytecode_input_map: HashMap<String, String> = HashMap::new();
for line in &self.bytecode_input {
@@ -1304,9 +1124,6 @@ impl<'a> ToABY<'a> {
.filter(|x| !x.is_empty())
.collect::<Vec<String>>();
self.bytecode_input = inputs;
println!("Time: reordering inputs {}: {:?}", name, now.elapsed());
now = Instant::now();
// write input bytecode
let bytecode_path =
@@ -1322,8 +1139,6 @@ impl<'a> ToABY<'a> {
);
write_lines(&bytecode_output_path, &self.bytecode_output);
println!("Time: writing {}: {:?}", name, now.elapsed());
// combine input and output bytecode files into a single file
let mut bytecode = fs::OpenOptions::new()
.append(true)
@@ -1347,7 +1162,6 @@ impl<'a> ToABY<'a> {
self.bytecode_input.clear();
self.bytecode_output.clear();
self.inputs.clear();
self.cache.clear();
}
// write remaining const variables
@@ -1356,15 +1170,6 @@ impl<'a> ToABY<'a> {
// write remaining shares
self.write_share_output(true);
}
fn convert(&mut self) {
let mut now = Instant::now();
// self.map_to_shares();
// println!("Time: map terms to shares: {:?}", now.elapsed());
now = Instant::now();
self.lower();
println!("Time: lowering: {:?}", now.elapsed());
}
}
/// Convert this (IR) `ir` to ABY.
@@ -1388,19 +1193,19 @@ pub fn to_aby(
"gglp" => {
let (fs, s_map) = inline_all_and_assign_glp(&ir, cm);
let mut converter = ToABY::new(fs, s_map, path, lang);
converter.convert();
converter.lower();
}
#[cfg(feature = "lp")]
"lp+mut" => {
let (fs, s_map) = partition_with_mut(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
let mut converter = ToABY::new(fs, s_map, path, lang);
converter.convert();
converter.lower();
}
// #[cfg(feature = "lp")]
// "mlp+mut" => {
// let (fs, s_map) = mlp_with_mut(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
// let mut converter = ToABY::new(fs, s_map, path, lang);
// converter.convert();
// converter.lower();
// }
_ =>{
// Protocal Assignments
@@ -1424,7 +1229,7 @@ pub fn to_aby(
s_map.insert(name.to_string(), assignments);
}
let mut converter = ToABY::new(ir, s_map, path, lang);
converter.convert();
converter.lower();
}
};