mirror of
https://github.com/circify/circ.git
synced 2026-04-21 03:00:54 -04:00
merge hack but still slow
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ __pycache__
|
||||
.mode.txt
|
||||
scripts/aby_tests/tests
|
||||
/flamegraph*.svg
|
||||
.vscode/
|
||||
|
||||
@@ -1,11 +1,68 @@
|
||||
#define LEN 32
|
||||
#define NUM_REVIEWERS 1
|
||||
#define NUM_RATINGS 1
|
||||
#define NUM_REVIEWERS 100
|
||||
#define NUM_RATINGS 100
|
||||
#define INTERVALS 2
|
||||
#define NUM_BUCKETS (INTERVALS * 5) - 1
|
||||
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
|
||||
|
||||
|
||||
/* returns val/mod, integer division */
|
||||
// int quot(int val, int mod) {
|
||||
// if (mod == 0){
|
||||
// return val;
|
||||
// } else{
|
||||
// int rem = val % mod;
|
||||
// return (val - rem) / mod;
|
||||
// }
|
||||
// }
|
||||
|
||||
int map(int sumRatings) {
|
||||
|
||||
int bucket = NUM_RATINGS+1;
|
||||
|
||||
int val = sumRatings;
|
||||
int mod = NUM_RATINGS;
|
||||
|
||||
int absReview = val / mod;
|
||||
int fraction = val % mod;
|
||||
// int absReview = 2;
|
||||
// int fraction = 3;
|
||||
|
||||
int m = INTERVALS * (absReview - 1);
|
||||
int num = fraction * INTERVALS;
|
||||
for (int j = 0; j < INTERVALS; j++) {
|
||||
int low = j * NUM_RATINGS;
|
||||
int high = (j + 1) * NUM_RATINGS;
|
||||
int cond1;
|
||||
if(low <= num) {
|
||||
cond1 = 1;
|
||||
}
|
||||
else {
|
||||
cond1 = 0;
|
||||
}
|
||||
int cond2;
|
||||
if(high > num) {
|
||||
cond2 = 1;
|
||||
}
|
||||
else {
|
||||
cond2 = 0;
|
||||
}
|
||||
int cond = cond1 + cond2;
|
||||
|
||||
int newBucket;
|
||||
if(cond == 2) {
|
||||
newBucket = m + j;
|
||||
}
|
||||
else {
|
||||
newBucket = bucket;
|
||||
}
|
||||
|
||||
bucket = newBucket;
|
||||
}
|
||||
|
||||
return bucket;
|
||||
}
|
||||
|
||||
int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
|
||||
{
|
||||
int result[NUM_BUCKETS];
|
||||
@@ -17,7 +74,7 @@ int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((priv
|
||||
for (int j = 0; j < NUM_RATINGS; j++) {
|
||||
sum = sum + reviews[i*NUM_RATINGS + j];
|
||||
}
|
||||
int bucket = sum;
|
||||
int bucket = map(sum);
|
||||
for (int j = 0; j < NUM_BUCKETS; j++) {
|
||||
int temp;
|
||||
if (j == bucket) {
|
||||
@@ -34,4 +91,67 @@ int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((priv
|
||||
sum_all += result[i];
|
||||
}
|
||||
return sum_all;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// int map(int sumRatings) {
|
||||
|
||||
// // int bucket = NUM_RATINGS+1;
|
||||
|
||||
// // int val = sumRatings;
|
||||
// // int mod = NUM_RATINGS;
|
||||
|
||||
// // int absReview = val / mod;
|
||||
// // int fraction = val % mod;
|
||||
// // int absReview = 2;
|
||||
// // int fraction = 3;
|
||||
// // int absReview = sumRatings;
|
||||
// // int fraction = sumRatings;
|
||||
|
||||
// int m = INTERVALS * (sumRatings - 1);
|
||||
// int num = sumRatings * INTERVALS;
|
||||
// // for (int j = 0; j < INTERVALS; j++) {
|
||||
// // int low = j * NUM_RATINGS;
|
||||
// // int high = (j + 1) * NUM_RATINGS;
|
||||
// // int cond1;
|
||||
// // if(low <= num) {
|
||||
// // cond1 = 1;
|
||||
// // }
|
||||
// // else {
|
||||
// // cond1 = 0;
|
||||
// // }
|
||||
// // int cond2;
|
||||
// // if(high > num) {
|
||||
// // cond2 = 1;
|
||||
// // }
|
||||
// // else {
|
||||
// // cond2 = 0;
|
||||
// // }
|
||||
// // int cond = cond1 + cond2;
|
||||
|
||||
// // int newBucket;
|
||||
// // if(cond == 2) {
|
||||
// // newBucket = m + j;
|
||||
// // }
|
||||
// // else {
|
||||
// // newBucket = bucket;
|
||||
// // }
|
||||
|
||||
// // bucket = newBucket;
|
||||
// // }
|
||||
|
||||
// return m + num;
|
||||
// }
|
||||
|
||||
// int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
|
||||
// {
|
||||
// int sum_all = offset;
|
||||
// for (int i = 0; i < NUM_REVIEWERS; i++) {
|
||||
// int sum = reviews[i*NUM_RATINGS];
|
||||
// int sum2 = reviews[0];
|
||||
// int bucket = sum2*sum;
|
||||
// sum_all += bucket;
|
||||
// }
|
||||
|
||||
// return sum_all;
|
||||
// }
|
||||
82
examples/C/mpc/benchmarks/histogram/histogram.c
Normal file
82
examples/C/mpc/benchmarks/histogram/histogram.c
Normal file
@@ -0,0 +1,82 @@
|
||||
#define LEN 32
|
||||
#define NUM_REVIEWERS 100
|
||||
#define NUM_RATINGS 100
|
||||
#define INTERVALS 2
|
||||
#define NUM_BUCKETS (INTERVALS * 5) - 1
|
||||
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
|
||||
|
||||
typedef struct
|
||||
{
|
||||
int result[NUM_BUCKETS];
|
||||
} Output;
|
||||
|
||||
int map(int sumRatings) {
|
||||
|
||||
int bucket = NUM_RATINGS+1;
|
||||
|
||||
int val = sumRatings;
|
||||
int mod = NUM_RATINGS;
|
||||
|
||||
int absReview = val / mod;
|
||||
int fraction = val % mod;
|
||||
// int absReview = 2;
|
||||
// int fraction = 3;
|
||||
|
||||
int m = INTERVALS * (absReview - 1);
|
||||
int num = fraction * INTERVALS;
|
||||
for (int j = 0; j < INTERVALS; j++) {
|
||||
int low = j * NUM_RATINGS;
|
||||
int high = (j + 1) * NUM_RATINGS;
|
||||
int cond1;
|
||||
if(low <= num) {
|
||||
cond1 = 1;
|
||||
}
|
||||
else {
|
||||
cond1 = 0;
|
||||
}
|
||||
int cond2;
|
||||
if(high > num) {
|
||||
cond2 = 1;
|
||||
}
|
||||
else {
|
||||
cond2 = 0;
|
||||
}
|
||||
int cond = cond1 + cond2;
|
||||
|
||||
int newBucket;
|
||||
if(cond == 2) {
|
||||
newBucket = m + j;
|
||||
}
|
||||
else {
|
||||
newBucket = bucket;
|
||||
}
|
||||
|
||||
bucket = newBucket;
|
||||
}
|
||||
|
||||
return bucket;
|
||||
}
|
||||
|
||||
Output main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
|
||||
{
|
||||
Output res;
|
||||
|
||||
for (int i = 0; i < NUM_REVIEWERS; i++) {
|
||||
int sum = 0;
|
||||
for (int j = 0; j < NUM_RATINGS; j++) {
|
||||
sum = sum + reviews[i*NUM_RATINGS + j];
|
||||
}
|
||||
int bucket = map(sum);
|
||||
for (int j = 0; j < NUM_BUCKETS; j++) {
|
||||
int temp;
|
||||
if (j == bucket) {
|
||||
temp = res.result[j] + 1;
|
||||
}
|
||||
else {
|
||||
temp = res.result[j];
|
||||
}
|
||||
res.result[j] = temp;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -1,19 +1,61 @@
|
||||
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b)
|
||||
#define LEN 32
|
||||
#define NUM_REVIEWERS 1
|
||||
#define NUM_RATINGS 1
|
||||
#define INTERVALS 2
|
||||
#define NUM_BUCKETS (INTERVALS * 5) - 1
|
||||
#define TOTAL_REV (NUM_REVIEWERS * NUM_RATINGS)
|
||||
|
||||
|
||||
int main(__attribute__((private(0))) int reviews[TOTAL_REV], __attribute__((private(1))) int offset)
|
||||
{
|
||||
int result[10];
|
||||
for(int i = 0; i < 10; i++){
|
||||
result[i] = 0;
|
||||
}
|
||||
int result[NUM_BUCKETS];
|
||||
|
||||
for(int i = 0; i < 10; i++){
|
||||
result[i] += 1;
|
||||
}
|
||||
for (int i = 0; i < NUM_REVIEWERS; i++) {
|
||||
int sum = 0;
|
||||
for (int j = 0; j < NUM_RATINGS; j++) {
|
||||
sum = sum + reviews[i*NUM_RATINGS + j];
|
||||
}
|
||||
int bucket = sum;
|
||||
for (int j = 0; j < NUM_BUCKETS; j++) {
|
||||
int temp;
|
||||
if (j == bucket) {
|
||||
temp = result[j] + 1;
|
||||
}
|
||||
else {
|
||||
temp = result[j];
|
||||
}
|
||||
result[j] = temp;
|
||||
}
|
||||
}
|
||||
int sum_all = offset;
|
||||
for(int i = 0; i < NUM_BUCKETS; i++){
|
||||
sum_all += result[i];
|
||||
}
|
||||
return sum_all;
|
||||
}
|
||||
|
||||
int res = 0;
|
||||
|
||||
for(int i = 0; i < 10; i++){
|
||||
res += result[i];
|
||||
}
|
||||
// int f(int a) {
|
||||
// return a + 1;
|
||||
// }
|
||||
|
||||
return res + a + b;
|
||||
}
|
||||
// int main( __attribute__((private(0))) int a, __attribute__((private(1))) int b)
|
||||
// {
|
||||
// // base input
|
||||
// int c = f(a);
|
||||
// int d = f(b);
|
||||
|
||||
// // add input
|
||||
// int e = a + c;
|
||||
// int g = b + d;
|
||||
// int h = f(e);
|
||||
// int i = f(g);
|
||||
|
||||
// // multiply input
|
||||
// int j = a * h;
|
||||
// int k = b * i;
|
||||
// int l = f(j);
|
||||
// int m = f(k);
|
||||
|
||||
// return l + m;
|
||||
// }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) {
|
||||
int index = a + b;
|
||||
int arr[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
return arr[index];
|
||||
int arr[10] = {0, a, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
return arr[index] + arr[1];
|
||||
}
|
||||
@@ -252,14 +252,14 @@ fn main() {
|
||||
Opt::Sha,
|
||||
Opt::ConstantFold(Box::new(ignore.clone())),
|
||||
Opt::Flatten,
|
||||
// The function call abstraction creates tuples
|
||||
Opt::Tuple,
|
||||
// // The function call abstraction creates tuples
|
||||
// Opt::Tuple,
|
||||
// Opt::Obliv,
|
||||
// The obliv elim pass produces more tuples, that must be eliminated
|
||||
// // The obliv elim pass produces more tuples, that must be eliminated
|
||||
// Opt::Tuple,
|
||||
// Opt::LinearScan,
|
||||
// The linear scan pass produces more tuples, that must be eliminated
|
||||
// Opt::Tuple,
|
||||
Opt::Tuple,
|
||||
Opt::ConstantFold(Box::new(ignore.clone())),
|
||||
// Inline Function Calls
|
||||
// Opt::Link,
|
||||
@@ -300,7 +300,7 @@ fn main() {
|
||||
// for (name, c) in &cs.computations {
|
||||
// println!("name: {}", name);
|
||||
// for t in c.terms_postorder() {
|
||||
// println!("t: {}", t);
|
||||
// println!("t: {}", t.op);
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
@@ -4,5 +4,5 @@ from util import run_tests
|
||||
from test_suite import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = benchmark_tests
|
||||
tests = pc_histogram_tests
|
||||
run_tests('c', tests)
|
||||
|
||||
@@ -28,7 +28,10 @@ if __name__ == "__main__":
|
||||
gauss_tests + \
|
||||
db_tests + \
|
||||
mnist_tests + \
|
||||
cryptonets_tests
|
||||
cryptonets_tests + \
|
||||
histogram_tests
|
||||
|
||||
tests = histogram_tests
|
||||
|
||||
tests = biomatch_tests
|
||||
|
||||
|
||||
3
scripts/aby_tests/test_inputs/2pc_histogram.txt
Normal file
3
scripts/aby_tests/test_inputs/2pc_histogram.txt
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -683,6 +683,21 @@ pg_tests = [
|
||||
"./scripts/aby_tests/test_inputs/playground.txt",
|
||||
]
|
||||
]
|
||||
histogram_tests = [
|
||||
[
|
||||
"histogram",
|
||||
"histogram",
|
||||
"./scripts/aby_tests/test_inputs/histogram.txt",
|
||||
]
|
||||
]
|
||||
|
||||
pc_histogram_tests = [
|
||||
[
|
||||
"2pc_histogram",
|
||||
"2pc_histogram",
|
||||
"./scripts/aby_tests/test_inputs/2pc_histogram.txt",
|
||||
]
|
||||
]
|
||||
|
||||
# ilp_benchmark_tests = [
|
||||
# [
|
||||
|
||||
@@ -66,9 +66,103 @@ function mpc_test_6 {
|
||||
cpath=$2
|
||||
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+y"
|
||||
}
|
||||
# mpc_test 2 ./examples/C/mpc/playground.c
|
||||
|
||||
# # build mpc arithmetic tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c
|
||||
# # mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c
|
||||
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c
|
||||
|
||||
# # build nary arithmetic tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c
|
||||
|
||||
# # build bitwise tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c
|
||||
|
||||
# # build boolean tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c
|
||||
|
||||
# # build nary boolean tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c
|
||||
|
||||
# # build const tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c
|
||||
|
||||
# # build if statement tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c
|
||||
|
||||
# # build shift tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c
|
||||
|
||||
# # build div tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c
|
||||
|
||||
# # build array tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c
|
||||
|
||||
# # build circ/compiler array tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c
|
||||
|
||||
# # build function tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c
|
||||
|
||||
# # build struct tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c
|
||||
|
||||
# # build matrix tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c
|
||||
|
||||
# # build ptr tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c
|
||||
|
||||
# # build misc tests
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c
|
||||
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c
|
||||
|
||||
# # build hycc benchmarks
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join2.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_merge.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist.c
|
||||
# mpc_test 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets.c
|
||||
|
||||
# # build OPA benchmarks
|
||||
mpc_test_4 2 ./examples/C/mpc/benchmarks/histogram/2pc_histogram.c
|
||||
|
||||
|
||||
mpc_test_3 2 ./examples/C/mpc/playground.c
|
||||
# mpc_test_3 2 ./examples/C/mpc/playground.c
|
||||
# mpc_test_4 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch_.c
|
||||
# mpc_test_4 2 ./examples/C/mpc/benchmarks/db/db_join2.c
|
||||
# mpc_test_4 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
|
||||
|
||||
@@ -5,7 +5,6 @@ use crate::ir::term::*;
|
||||
|
||||
/// Binarize cache.
|
||||
#[derive(Default)]
|
||||
|
||||
struct Binarizer;
|
||||
|
||||
impl RewritePass for Binarizer {
|
||||
@@ -26,13 +25,13 @@ impl RewritePass for Binarizer {
|
||||
|acc, x| term![orig.op.clone(); x.clone(), acc],
|
||||
))
|
||||
}
|
||||
},
|
||||
_ => None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Binarize (expand) n-ary terms.
|
||||
/// Binarize (expand) n-ary terms.
|
||||
pub fn binarize(c: &mut Computation) {
|
||||
let mut pass = Binarizer;
|
||||
pass.traverse(c);
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
//! Call Site Similarity
|
||||
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// Determine if call sites are similar based on input and output arguments to the call site
|
||||
pub fn call_site_similarity(fs: &mut Functions) {
|
||||
// Return a TermMap of (call) --> id for which calls are similar
|
||||
// Maybe return a vector of vector of terms
|
||||
|
||||
// Map of Vec<input: Vec<Term>, output: Vec<Term>> --> Vec<Call Term>
|
||||
|
||||
// For each call site, (input: Vec<Term>, output: Vec<Term>) -> Term (call)
|
||||
let mut call_sites: TermMap<(Vec<Term>, Vec<Term>)> = TermMap::new();
|
||||
|
||||
for (name, comp) in fs.computations {
|
||||
// Post order traversal through each computation
|
||||
|
||||
for t in comp.terms_postorder() {
|
||||
match t.op {
|
||||
Op::Call(name, arg_names, arg_sorts, ret_sorts) => {
|
||||
let input: Vec<Term> = t.cs;
|
||||
let output: Vec<Term> = Vec::new();
|
||||
call_sites.insert(t.clone(), (input, output));
|
||||
}
|
||||
_ => {
|
||||
// see if the call term was used as an argument in another term
|
||||
for c in t.cs {
|
||||
if call_sites.contains_key(&c) {
|
||||
call_sites.get(c).1.push(t.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each call term
|
||||
// Get a list of inputs and output terms based on mutation step size
|
||||
// Store input terms into a data structure (vec?)
|
||||
// Store output terms into a data structure (vec?)
|
||||
// Order terms by operator
|
||||
|
||||
// Create key: Vec<input: Vec<Term>, output: Vec<Term>>
|
||||
|
||||
// loop through existing call terms:
|
||||
// longest prefix matching (edit distance?)
|
||||
|
||||
// if match:
|
||||
// append to vec
|
||||
// if no match:
|
||||
// add as new entry
|
||||
}
|
||||
}
|
||||
@@ -319,6 +319,18 @@ impl Display for Op {
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Op {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.cmp(&other)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Op {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
/// Boolean n-ary operator
|
||||
pub enum BoolNaryOp {
|
||||
@@ -1283,7 +1295,7 @@ impl Value {
|
||||
*b
|
||||
} else {
|
||||
panic!("Not a bool: {}", self)
|
||||
}
|
||||
}
|
||||
}
|
||||
#[track_caller]
|
||||
/// Get the underlying bit-vector constant, or panic!
|
||||
@@ -2040,17 +2052,6 @@ impl Computation {
|
||||
}
|
||||
}
|
||||
|
||||
// #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
|
||||
// /// A function definition.
|
||||
// pub struct FuncDef {
|
||||
// /// Name of function
|
||||
// pub name: String,
|
||||
// /// Type signature of function parameters
|
||||
// pub params: BTreeMap<String, Sort>,
|
||||
// /// Return type of function
|
||||
// pub ret_ty: Vec<Sort>,
|
||||
// }
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
/// A map of IR computations.
|
||||
pub struct Functions {
|
||||
|
||||
84
src/target/aby/call_site_similarity.rs
Normal file
84
src/target/aby/call_site_similarity.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
//! Call Site Similarity
|
||||
|
||||
use crate::ir::term::*;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Determine if call sites are similar based on input and output arguments to the call site
|
||||
pub fn call_site_similarity(fs: &Functions) -> Vec<Vec<Term>> {
|
||||
// Return a TermMap of (call) --> id for which calls are similar
|
||||
// Maybe return a vector of vector of terms
|
||||
|
||||
// Map of Vec<input: Vec<Term>, output: Vec<Term>> --> Vec<Call Term>
|
||||
|
||||
// map call Term -> (input: Vec<Term>, output: Vec<Term>)
|
||||
let mut call_term_map: TermMap<(Vec<Term>, Vec<Term>)> = TermMap::new();
|
||||
|
||||
// map field(i) Term to parent call Term
|
||||
let mut field_of_calls: TermMap<Term> = TermMap::new();
|
||||
|
||||
for (_name, comp) in &fs.computations {
|
||||
for t in comp.terms_postorder() {
|
||||
// see if the call term was used as an argument in another term
|
||||
for c in &t.cs {
|
||||
if call_term_map.contains_key(c) {
|
||||
field_of_calls.insert(t.clone(), c.clone());
|
||||
}
|
||||
if field_of_calls.contains_key(c) {
|
||||
let call_term = field_of_calls.get(c).unwrap();
|
||||
call_term_map.get_mut(call_term).unwrap().1.push(t.clone());
|
||||
}
|
||||
}
|
||||
match &t.op {
|
||||
Op::Call(..) => {
|
||||
let input: Vec<Term> = t.cs.clone();
|
||||
let output: Vec<Term> = Vec::new();
|
||||
call_term_map.insert(t.clone(), (input, output));
|
||||
}
|
||||
_ => {
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each call term
|
||||
// Get a list of inputs and output terms based on mutation step size
|
||||
// Store input terms into a data structure (vec?)
|
||||
// Store output terms into a data structure (vec?)
|
||||
// Order terms by operator
|
||||
|
||||
// Create key: Vec<input: Vec<Term>, output: Vec<Term>>
|
||||
// SORT OUTPUT TERMS
|
||||
|
||||
// loop through existing call terms:
|
||||
// longest prefix matching (edit distance?)
|
||||
|
||||
// if match:
|
||||
// append to vec
|
||||
// if no match:
|
||||
// add as new entry
|
||||
}
|
||||
|
||||
// Clean input and output terms
|
||||
let mut call_sites: HashMap<(Vec<Op>, Vec<Op>), Vec<Term>> = HashMap::new();
|
||||
|
||||
for (c, (i, o)) in call_term_map {
|
||||
let input_ops = i.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
|
||||
let mut output_ops = o.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
|
||||
output_ops.sort();
|
||||
|
||||
let key = (input_ops, output_ops);
|
||||
|
||||
// longest prefix matching?
|
||||
|
||||
// edit distance?
|
||||
|
||||
if call_sites.contains_key(&key) {
|
||||
call_sites.get_mut(&key).unwrap().push(c);
|
||||
} else {
|
||||
call_sites.insert(key, vec![c]);
|
||||
}
|
||||
}
|
||||
|
||||
return call_sites.into_values().collect::<Vec<_>>();
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
//! ABY
|
||||
pub mod assignment;
|
||||
pub mod call_site_similarity;
|
||||
pub mod trans;
|
||||
pub mod utils;
|
||||
|
||||
@@ -18,7 +18,6 @@ use std::fmt;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "lp")]
|
||||
use crate::target::graph::trans::*;
|
||||
@@ -29,6 +28,8 @@ use super::assignment::assign_arithmetic_and_boolean;
|
||||
use super::assignment::assign_arithmetic_and_yao;
|
||||
use super::assignment::assign_greedy;
|
||||
|
||||
// use super::call_site_similarity::call_site_similarity;
|
||||
|
||||
const PUBLIC: u8 = 2;
|
||||
const WRITE_SIZE: usize = 65536;
|
||||
|
||||
@@ -201,11 +202,11 @@ struct ToABY<'a> {
|
||||
curr_comp: String,
|
||||
// Input mapping
|
||||
inputs: Vec<Term>,
|
||||
// Term cache
|
||||
cache: TermMap<EmbeddedTerm>,
|
||||
// Term to share id
|
||||
term_to_shares: TermMap<Vec<i32>>,
|
||||
share_cnt: i32,
|
||||
// Cache
|
||||
cache: HashMap<(Op, Vec<i32>), Vec<i32>>,
|
||||
// Outputs
|
||||
bytecode_input: Vec<String>,
|
||||
bytecode_output: Vec<String>,
|
||||
@@ -221,7 +222,6 @@ impl Drop for ToABY<'_> {
|
||||
// drop everything that uses a Term
|
||||
// drop(take(&mut self.md));
|
||||
self.inputs.clear();
|
||||
self.cache.clear();
|
||||
self.term_to_shares.clear();
|
||||
// self.s_map.clear();
|
||||
// clean up
|
||||
@@ -238,9 +238,9 @@ impl<'a> ToABY<'a> {
|
||||
lang: lang.to_string(),
|
||||
curr_comp: "".to_string(),
|
||||
inputs: Vec::new(),
|
||||
cache: TermMap::new(),
|
||||
term_to_shares: TermMap::new(),
|
||||
share_cnt: 0,
|
||||
cache: HashMap::new(),
|
||||
bytecode_input: Vec::new(),
|
||||
bytecode_output: Vec::new(),
|
||||
const_output: Vec::new(),
|
||||
@@ -314,20 +314,6 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_var_name_from_term(t: &Term) -> String {
|
||||
match &t.op {
|
||||
Op::Var(name, _) => ToABY::get_var_name(name),
|
||||
_ => panic!("Term {} is not of type Var", t),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_sharing_map(&mut self, name: &str) -> SharingMap {
|
||||
match self.s_map.get(name) {
|
||||
Some(s) => s.clone(),
|
||||
None => panic!("Unknown sharing map for function: {}", name),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_share(&mut self, t: &Term, s: i32) {
|
||||
if !self.written_const_set.contains(&s){
|
||||
let s_map = self.s_map.get(&self.curr_comp).unwrap();
|
||||
@@ -525,38 +511,13 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn embed_eq(&mut self, t: Term, a_term: Term, b_term: Term) {
|
||||
let s = self.get_share(&t);
|
||||
fn embed_eq(&mut self, t: &Term) {
|
||||
let s = self.get_share(t);
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let b = self.get_share(&t.cs[1]);
|
||||
let op = "EQ";
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
match check(&a_term) {
|
||||
Sort::Bool => {
|
||||
self.check_bool(&a_term);
|
||||
self.check_bool(&b_term);
|
||||
self.cache.insert(t, EmbeddedTerm::Bool);
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
self.check_bv(&a_term);
|
||||
self.check_bv(&b_term);
|
||||
self.cache.insert(t, EmbeddedTerm::Bool);
|
||||
}
|
||||
e => panic!("Unimplemented sort for Eq: {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Given term `t`, type-check `t` is of type Bool
|
||||
fn check_bool(&self, t: &Term) {
|
||||
match self
|
||||
.cache
|
||||
.get(t)
|
||||
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
|
||||
{
|
||||
EmbeddedTerm::Bool => (),
|
||||
_ => panic!("Non-bool for {:?}", t),
|
||||
}
|
||||
}
|
||||
|
||||
fn embed_bool(&mut self, t: Term) {
|
||||
@@ -565,7 +526,7 @@ impl<'a> ToABY<'a> {
|
||||
Op::Var(name, Sort::Bool) => {
|
||||
let md = self.get_md();
|
||||
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
|
||||
let term_name = ToABY::get_var_name_from_term(&t);
|
||||
let term_name = ToABY::get_var_name(&name);
|
||||
let vis = self.unwrap_vis(name);
|
||||
let s = self.get_share(&t);
|
||||
let op = "IN";
|
||||
@@ -580,46 +541,31 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
self.inputs.push(t.clone());
|
||||
}
|
||||
|
||||
if !self.cache.contains_key(&t) {
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
}
|
||||
Op::Const(Value::Bool(b)) => {
|
||||
let op = "CONS_bool";
|
||||
let line = format!("1 1 {} {} {}\n", *b as i32, s, op);
|
||||
self.const_output.push(line);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
Op::Eq => {
|
||||
self.embed_eq(t.clone(), t.cs[0].clone(), t.cs[1].clone());
|
||||
self.embed_eq(&t);
|
||||
}
|
||||
Op::Ite => {
|
||||
let op = "MUX";
|
||||
|
||||
self.check_bool(&t.cs[0]);
|
||||
self.check_bool(&t.cs[1]);
|
||||
self.check_bool(&t.cs[2]);
|
||||
|
||||
let sel = self.get_share(&t.cs[0]);
|
||||
let a = self.get_share(&t.cs[1]);
|
||||
let b = self.get_share(&t.cs[2]);
|
||||
|
||||
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
Op::Not => {
|
||||
let op = "NOT";
|
||||
|
||||
self.check_bool(&t.cs[0]);
|
||||
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let line = format!("1 1 {} {} {}\n", a, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
Op::BoolNaryOp(o) => {
|
||||
if t.cs.len() == 1 {
|
||||
@@ -627,7 +573,6 @@ impl<'a> ToABY<'a> {
|
||||
// If t.cs len is 1, just output that term
|
||||
// This is to bypass adding an AND gate with a single conditional term
|
||||
// Refer to pub fn condition() in src/circify/mod.rs
|
||||
self.check_bool(&t.cs[0]);
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
match o {
|
||||
BoolNaryOp::And => self.term_to_shares.insert(t.clone(), vec![a]),
|
||||
@@ -635,11 +580,7 @@ impl<'a> ToABY<'a> {
|
||||
unimplemented!("Single operand boolean operation");
|
||||
}
|
||||
};
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
} else {
|
||||
self.check_bool(&t.cs[0]);
|
||||
self.check_bool(&t.cs[1]);
|
||||
|
||||
let op = match o {
|
||||
BoolNaryOp::Or => "OR",
|
||||
BoolNaryOp::And => "AND",
|
||||
@@ -650,8 +591,6 @@ impl<'a> ToABY<'a> {
|
||||
let b = self.get_share(&t.cs[1]);
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
}
|
||||
Op::BvBinPred(o) => {
|
||||
@@ -663,38 +602,21 @@ impl<'a> ToABY<'a> {
|
||||
_ => panic!("Non-field in bool BvBinPred: {}", o),
|
||||
};
|
||||
|
||||
self.check_bv(&t.cs[0]);
|
||||
self.check_bv(&t.cs[1]);
|
||||
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let b = self.get_share(&t.cs[1]);
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool);
|
||||
}
|
||||
_ => panic!("Non-field in embed_bool: {}", t),
|
||||
}
|
||||
}
|
||||
|
||||
/// Given term `t`, type-check `t` is of type Bv
|
||||
fn check_bv(&self, t: &Term) {
|
||||
match self
|
||||
.cache
|
||||
.get(t)
|
||||
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
|
||||
{
|
||||
EmbeddedTerm::Bv => (),
|
||||
_ => panic!("Non-bv for {:?}", t),
|
||||
}
|
||||
}
|
||||
|
||||
fn embed_bv(&mut self, t: Term) {
|
||||
match &t.op {
|
||||
Op::Var(name, Sort::BitVector(_)) => {
|
||||
let md = self.get_md();
|
||||
if !self.inputs.contains(&t) && md.input_vis.contains_key(name) {
|
||||
let term_name = ToABY::get_var_name_from_term(&t);
|
||||
let term_name = ToABY::get_var_name(&name);
|
||||
let vis = self.unwrap_vis(name);
|
||||
let s = self.get_share(&t);
|
||||
let op = "IN";
|
||||
@@ -709,10 +631,6 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
self.inputs.push(t.clone());
|
||||
}
|
||||
|
||||
if !self.cache.contains_key(&t) {
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
}
|
||||
Op::Const(Value::BitVector(b)) => {
|
||||
let s = self.get_share(&t);
|
||||
@@ -725,27 +643,26 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
self.const_output.push(line);
|
||||
}
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
// self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
Op::Ite => {
|
||||
let s = self.get_share(&t);
|
||||
let op = "MUX";
|
||||
|
||||
self.check_bool(&t.cs[0]);
|
||||
self.check_bv(&t.cs[1]);
|
||||
self.check_bv(&t.cs[2]);
|
||||
|
||||
let sel = self.get_share(&t.cs[0]);
|
||||
let a = self.get_share(&t.cs[1]);
|
||||
let b = self.get_share(&t.cs[2]);
|
||||
|
||||
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
let key = (t.op.clone(), vec![sel, a, b]);
|
||||
if self.cache.contains_key(&key) {
|
||||
let s = self.cache.get(&key).unwrap().clone();
|
||||
self.term_to_shares.insert(t.clone(), s);
|
||||
} else {
|
||||
let s = self.get_shares(&t);
|
||||
self.cache.insert(key, s.clone());
|
||||
let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op);
|
||||
self.bytecode_output.push(line);
|
||||
};
|
||||
}
|
||||
Op::BvNaryOp(o) => {
|
||||
let s = self.get_share(&t);
|
||||
let op = match o {
|
||||
BvNaryOp::Xor => "XOR",
|
||||
BvNaryOp::Or => "OR",
|
||||
@@ -753,20 +670,21 @@ impl<'a> ToABY<'a> {
|
||||
BvNaryOp::Add => "ADD",
|
||||
BvNaryOp::Mul => "MUL",
|
||||
};
|
||||
|
||||
self.check_bv(&t.cs[0]);
|
||||
self.check_bv(&t.cs[1]);
|
||||
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let b = self.get_share(&t.cs[1]);
|
||||
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
let key = (t.op.clone(), vec![a, b]);
|
||||
if self.cache.contains_key(&key) {
|
||||
let s = self.cache.get(&key).unwrap().clone();
|
||||
self.term_to_shares.insert(t.clone(), s);
|
||||
} else {
|
||||
let s = self.get_shares(&t);
|
||||
self.cache.insert(key, s.clone());
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
|
||||
self.bytecode_output.push(line);
|
||||
};
|
||||
}
|
||||
Op::BvBinOp(o) => {
|
||||
let s = self.get_share(&t);
|
||||
let op = match o {
|
||||
BvBinOp::Sub => "SUB",
|
||||
BvBinOp::Udiv => "DIV",
|
||||
@@ -778,30 +696,37 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
match o {
|
||||
BvBinOp::Sub | BvBinOp::Udiv | BvBinOp::Urem => {
|
||||
self.check_bv(&t.cs[0]);
|
||||
self.check_bv(&t.cs[1]);
|
||||
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let b = self.get_share(&t.cs[1]);
|
||||
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
let key = (t.op.clone(), vec![a, b]);
|
||||
if self.cache.contains_key(&key) {
|
||||
let s = self.cache.get(&key).unwrap().clone();
|
||||
self.term_to_shares.insert(t, s);
|
||||
} else {
|
||||
let s = self.get_shares(&t);
|
||||
self.cache.insert(key, s.clone());
|
||||
let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
|
||||
self.bytecode_output.push(line);
|
||||
};
|
||||
}
|
||||
BvBinOp::Shl | BvBinOp::Lshr => {
|
||||
self.check_bv(&t.cs[0]);
|
||||
self.check_bv(&t.cs[1]);
|
||||
|
||||
let a = self.get_share(&t.cs[0]);
|
||||
let const_shift_amount_term = fold(&t.cs[1], &[]);
|
||||
let const_shift_amount =
|
||||
const_shift_amount_term.as_bv_opt().unwrap().uint();
|
||||
|
||||
let line = format!("2 1 {} {} {} {}\n", a, const_shift_amount, s, op);
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
let key = (t.op.clone(), vec![a, const_shift_amount.to_i32().unwrap()]);
|
||||
if self.cache.contains_key(&key) {
|
||||
let s = self.cache.get(&key).unwrap().clone();
|
||||
self.term_to_shares.insert(t, s);
|
||||
} else {
|
||||
let s = self.get_shares(&t);
|
||||
self.cache.insert(key, s.clone());
|
||||
let line =
|
||||
format!("2 1 {} {} {} {}\n", a, const_shift_amount, s[0], op);
|
||||
self.bytecode_output.push(line);
|
||||
};
|
||||
}
|
||||
_ => panic!("Binop not supported: {}", o),
|
||||
};
|
||||
@@ -811,7 +736,6 @@ impl<'a> ToABY<'a> {
|
||||
let shares = self.get_shares(&t.cs[0]);
|
||||
assert!(*i < shares.len());
|
||||
self.term_to_shares.insert(t.clone(), vec![shares[*i]]);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
Op::Select => {
|
||||
assert!(t.cs.len() == 2);
|
||||
@@ -828,7 +752,6 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
self.term_to_shares
|
||||
.insert(t.clone(), vec![array_shares[idx]]);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
} else {
|
||||
let op = "SELECT";
|
||||
let num_inputs = array_shares.len() + 1;
|
||||
@@ -844,7 +767,6 @@ impl<'a> ToABY<'a> {
|
||||
);
|
||||
self.bytecode_output.push(line);
|
||||
self.term_to_shares.insert(t.clone(), vec![output]);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
}
|
||||
_ => panic!("Non-field in embed_bv: {:?}", t),
|
||||
@@ -852,42 +774,51 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
|
||||
fn embed_scalar(&mut self, t: Term) {
|
||||
let now = Instant::now();
|
||||
match &t.op {
|
||||
Op::Const(Value::Array(arr)) => {
|
||||
let shares = self.get_shares(&t);
|
||||
assert!(shares.len() == arr.size);
|
||||
// let shares = self.get_shares(&t);
|
||||
// assert!(shares.len() == arr.size);
|
||||
|
||||
for (i, s) in shares.iter().enumerate() {
|
||||
let mut shares: Vec<i32> = Vec::new();
|
||||
|
||||
for i in 0..arr.size {
|
||||
// TODO: sort of index might not be a 32-bit bitvector
|
||||
let idx = Value::BitVector(BitVector::new(Integer::from(i), 32));
|
||||
let v = match arr.map.get(&idx) {
|
||||
Some(c) => c,
|
||||
|
||||
None => &*arr.default,
|
||||
};
|
||||
|
||||
match v {
|
||||
Value::BitVector(b) => {
|
||||
if !self.written_const_set.contains(s){
|
||||
self.written_const_set.insert(*s);
|
||||
let op = "CONS_bv";
|
||||
let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op);
|
||||
if b.as_sint() == 99{
|
||||
println!("GOtcha2: {}", t);
|
||||
// TODO: sort of value might not be a 32-bit bitvector
|
||||
let v_term = leaf_term(Op::Const(v.clone()));
|
||||
if self.term_to_shares.contains_key(&v_term) {
|
||||
// existing const
|
||||
let s = self.get_share(&v_term);
|
||||
shares.push(s);
|
||||
} else {
|
||||
// new const
|
||||
let s = self.get_share(&v_term);
|
||||
match v {
|
||||
Value::BitVector(b) => {
|
||||
if !self.written_const_set.contains(&s){
|
||||
self.written_const_set.insert(s);
|
||||
let op = "CONS_bv";
|
||||
let line = format!("1 1 {} {} {}\n", b.as_sint(), s, op);
|
||||
if b.as_sint() == 99{
|
||||
println!("GOtcha2: {}", t);
|
||||
}
|
||||
self.const_output.push(line);
|
||||
}
|
||||
self.const_output.push(line);
|
||||
// self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
_ => todo!(),
|
||||
}
|
||||
_ => todo!(),
|
||||
shares.push(s);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
num_const_arr += 1;
|
||||
dur_const_arr += now.elapsed();
|
||||
};
|
||||
assert!(shares.len() == arr.size);
|
||||
self.term_to_shares.insert(t.clone(), shares);
|
||||
}
|
||||
Op::Const(Value::Tuple(tup)) => {
|
||||
let shares = self.get_shares(&t);
|
||||
@@ -904,16 +835,10 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
self.const_output.push(line);
|
||||
}
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bv);
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe {
|
||||
num_const_tuple += 1;
|
||||
dur_const_tuple += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Ite => {
|
||||
let op = "MUX";
|
||||
@@ -942,13 +867,6 @@ impl<'a> ToABY<'a> {
|
||||
);
|
||||
|
||||
self.bytecode_output.push(line);
|
||||
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Array);
|
||||
|
||||
unsafe {
|
||||
num_ite += 1;
|
||||
dur_ite += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Store => {
|
||||
assert!(t.cs.len() == 3);
|
||||
@@ -962,7 +880,6 @@ impl<'a> ToABY<'a> {
|
||||
array_shares[idx] = value_share;
|
||||
|
||||
self.term_to_shares.insert(t.clone(), array_shares.clone());
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Array);
|
||||
} else {
|
||||
let op = "STORE";
|
||||
let num_inputs = array_shares.len() + 2;
|
||||
@@ -982,11 +899,6 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
self.bytecode_output.push(line);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
num_store += 1;
|
||||
dur_store += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Field(i) => {
|
||||
assert!(t.cs.len() == 1);
|
||||
@@ -1014,12 +926,6 @@ impl<'a> ToABY<'a> {
|
||||
let field_shares = &shares[offset..offset + len];
|
||||
|
||||
self.term_to_shares.insert(t.clone(), field_shares.to_vec());
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Array);
|
||||
|
||||
unsafe {
|
||||
num_field += 1;
|
||||
dur_field += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Update(i) => {
|
||||
assert!(t.cs.len() == 2);
|
||||
@@ -1034,12 +940,6 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
// store shares
|
||||
self.term_to_shares.insert(t.clone(), tuple_shares);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
|
||||
|
||||
unsafe {
|
||||
num_update += 1;
|
||||
dur_update += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Tuple => {
|
||||
let mut shares: Vec<i32> = Vec::new();
|
||||
@@ -1047,12 +947,6 @@ impl<'a> ToABY<'a> {
|
||||
shares.append(&mut self.get_shares(c));
|
||||
}
|
||||
self.term_to_shares.insert(t.clone(), shares);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
|
||||
|
||||
unsafe {
|
||||
num_tuple += 1;
|
||||
dur_tuple += now.elapsed();
|
||||
};
|
||||
}
|
||||
Op::Call(name, _arg_names, arg_sorts, ret_sorts) => {
|
||||
let shares = self.get_shares(&t);
|
||||
@@ -1093,12 +987,6 @@ impl<'a> ToABY<'a> {
|
||||
op
|
||||
);
|
||||
self.bytecode_output.push(line);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Tuple);
|
||||
|
||||
unsafe {
|
||||
num_call += 1;
|
||||
dur_call += now.elapsed();
|
||||
};
|
||||
}
|
||||
_ => {
|
||||
panic!("Non-field in embed_scalar: {}", t.op)
|
||||
@@ -1144,88 +1032,25 @@ impl<'a> ToABY<'a> {
|
||||
let mut write_time: std::time::Duration = std::time::Duration::new(0, 0);
|
||||
|
||||
for c in PostOrderIter_v2::new(t) {
|
||||
if self.cache.contains_key(&c) {
|
||||
if self.term_to_shares.contains_key(&c) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let b_now = Instant::now(); // check for tuples are long
|
||||
match check(&c) {
|
||||
Sort::Bool => {
|
||||
let now = Instant::now();
|
||||
self.embed_bool(c);
|
||||
num_bool += 1;
|
||||
dur_bool += now.elapsed();
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
let now = Instant::now();
|
||||
self.embed_bv(c);
|
||||
num_bv += 1;
|
||||
dur_bv += now.elapsed();
|
||||
}
|
||||
Sort::Array(..) | Sort::Tuple(_) => {
|
||||
let now = Instant::now();
|
||||
self.embed_scalar(c);
|
||||
num_scalar += 1;
|
||||
dur_scalar += now.elapsed();
|
||||
}
|
||||
e => panic!("Unsupported sort in embed: {:?}", e),
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
self.write_bytecode_output(false);
|
||||
self.write_const_output(false);
|
||||
self.write_share_output(false);
|
||||
write_time += now.elapsed();
|
||||
}
|
||||
|
||||
println!("bool: {}, bv: {}, scalar: {}", num_bool, num_bv, num_scalar);
|
||||
println!(
|
||||
"times: bool: {:?}, bv: {:?}, scalar: {:?}",
|
||||
dur_bool, dur_bv, dur_scalar
|
||||
);
|
||||
println!("write time: {:?}", write_time);
|
||||
|
||||
if num_bool > 0 && num_bv > 0 && num_scalar > 0 {
|
||||
println!(
|
||||
"norm_times: bool: {:?}, bv: {:?}, scalar: {:?}\n",
|
||||
dur_bool / num_bool,
|
||||
dur_bv / num_bv,
|
||||
dur_scalar / num_scalar
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
println!("================================");
|
||||
println!("const_arr: {}, const_tuple: {}, ite: {}, store: {}, field: {}, update: {}, tuple: {}, call: {}", num_const_arr, num_const_tuple, num_ite, num_store, num_field, num_update, num_tuple, num_call);
|
||||
println!("times: const_arr: {:?}, const_tuple: {:?}, ite: {:?}, store: {:?}, field: {:?}, update: {:?}, tuple: {:?}, call: {:?}", dur_const_arr, dur_const_tuple, dur_ite, dur_store, dur_field, dur_update, dur_tuple, dur_call);
|
||||
if num_const_arr > 0 {
|
||||
println!("norm_const_arr: {:?}", dur_const_arr / num_const_arr as u32);
|
||||
}
|
||||
if num_const_tuple > 0 {
|
||||
println!(
|
||||
"norm_const_tuple: {:?}",
|
||||
dur_const_tuple / num_const_tuple as u32
|
||||
);
|
||||
}
|
||||
if num_ite > 0 {
|
||||
println!("norm_ite: {:?}", dur_ite / num_ite as u32);
|
||||
}
|
||||
if num_store > 0 {
|
||||
println!("norm_store: {:?}", dur_store / num_store as u32);
|
||||
}
|
||||
if num_field > 0 {
|
||||
println!("norm_field: {:?}", dur_field / num_field as u32);
|
||||
}
|
||||
if num_update > 0 {
|
||||
println!("norm_update: {:?}", dur_update / num_update as u32);
|
||||
}
|
||||
if num_tuple > 0 {
|
||||
println!("norm_tuple: {:?}", dur_tuple / num_tuple as u32);
|
||||
}
|
||||
if num_call > 0 {
|
||||
println!("norm_call: {:?}", dur_call / num_call as u32);
|
||||
}
|
||||
println!("================================\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1246,7 +1071,6 @@ impl<'a> ToABY<'a> {
|
||||
|
||||
for (name, comp) in computations.iter() {
|
||||
let mut outputs: Vec<String> = Vec::new();
|
||||
let mut now = Instant::now();
|
||||
|
||||
// set current computation
|
||||
self.curr_comp = name.to_string();
|
||||
@@ -1274,10 +1098,6 @@ impl<'a> ToABY<'a> {
|
||||
}
|
||||
self.bytecode_output.append(&mut outputs);
|
||||
|
||||
println!("Time: lowering {}: {:?}", name, now.elapsed());
|
||||
|
||||
now = Instant::now();
|
||||
|
||||
// reorder inputs
|
||||
let mut bytecode_input_map: HashMap<String, String> = HashMap::new();
|
||||
for line in &self.bytecode_input {
|
||||
@@ -1304,9 +1124,6 @@ impl<'a> ToABY<'a> {
|
||||
.filter(|x| !x.is_empty())
|
||||
.collect::<Vec<String>>();
|
||||
self.bytecode_input = inputs;
|
||||
println!("Time: reordering inputs {}: {:?}", name, now.elapsed());
|
||||
|
||||
now = Instant::now();
|
||||
|
||||
// write input bytecode
|
||||
let bytecode_path =
|
||||
@@ -1322,8 +1139,6 @@ impl<'a> ToABY<'a> {
|
||||
);
|
||||
write_lines(&bytecode_output_path, &self.bytecode_output);
|
||||
|
||||
println!("Time: writing {}: {:?}", name, now.elapsed());
|
||||
|
||||
// combine input and output bytecode files into a single file
|
||||
let mut bytecode = fs::OpenOptions::new()
|
||||
.append(true)
|
||||
@@ -1347,7 +1162,6 @@ impl<'a> ToABY<'a> {
|
||||
self.bytecode_input.clear();
|
||||
self.bytecode_output.clear();
|
||||
self.inputs.clear();
|
||||
self.cache.clear();
|
||||
}
|
||||
|
||||
// write remaining const variables
|
||||
@@ -1356,15 +1170,6 @@ impl<'a> ToABY<'a> {
|
||||
// write remaining shares
|
||||
self.write_share_output(true);
|
||||
}
|
||||
|
||||
fn convert(&mut self) {
|
||||
let mut now = Instant::now();
|
||||
// self.map_to_shares();
|
||||
// println!("Time: map terms to shares: {:?}", now.elapsed());
|
||||
now = Instant::now();
|
||||
self.lower();
|
||||
println!("Time: lowering: {:?}", now.elapsed());
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this (IR) `ir` to ABY.
|
||||
@@ -1388,19 +1193,19 @@ pub fn to_aby(
|
||||
"gglp" => {
|
||||
let (fs, s_map) = inline_all_and_assign_glp(&ir, cm);
|
||||
let mut converter = ToABY::new(fs, s_map, path, lang);
|
||||
converter.convert();
|
||||
converter.lower();
|
||||
}
|
||||
#[cfg(feature = "lp")]
|
||||
"lp+mut" => {
|
||||
let (fs, s_map) = partition_with_mut(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
|
||||
let mut converter = ToABY::new(fs, s_map, path, lang);
|
||||
converter.convert();
|
||||
converter.lower();
|
||||
}
|
||||
// #[cfg(feature = "lp")]
|
||||
// "mlp+mut" => {
|
||||
// let (fs, s_map) = mlp_with_mut(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
|
||||
// let mut converter = ToABY::new(fs, s_map, path, lang);
|
||||
// converter.convert();
|
||||
// converter.lower();
|
||||
// }
|
||||
_ =>{
|
||||
// Protocal Assignments
|
||||
@@ -1424,7 +1229,7 @@ pub fn to_aby(
|
||||
s_map.insert(name.to_string(), assignments);
|
||||
}
|
||||
let mut converter = ToABY::new(ir, s_map, path, lang);
|
||||
converter.convert();
|
||||
converter.lower();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user