css but buggy

This commit is contained in:
Clive2312
2022-08-09 05:14:18 +00:00
parent 58c08d6851
commit 80296cc86e
27 changed files with 1297 additions and 564 deletions

View File

@@ -2,8 +2,8 @@
* Example on how to merge two data sets and to perform various analyses
*/
#define LEN_A 50
#define LEN_B 50
#define LEN_A 5
#define LEN_B 5
#define ATT_A 2 //Number of attributes
#define ATT_B 2

View File

@@ -9,7 +9,7 @@
#include "db.h"
void merge(DT *OUTPUT_db, DT *a, DT *b, unsigned len_a, unsigned len_b) {
void merge(DT *OUTPUT_db, DT *a, DT *b) {
// memcpy(OUTPUT_db, a, len_a * sizeof(DT));
for (int i = 0; i < LEN_A; i++) {
OUTPUT_db[i] = a[i];
@@ -58,7 +58,7 @@ Output main(__attribute__((private(0))) int a[LEN_A*ATT_A], __attribute__((priva
DT db[LEN];
// merge databases
merge(db, INPUT_A.db, INPUT_B.db, LEN_A, LEN_B);
merge(db, INPUT_A.db, INPUT_B.db);
// compute? histogram, correlation or
res.joined = LEN;

View File

@@ -1,44 +1,10 @@
#define N 64
#define K 4 // currently fixed, do not change
#define INNER 16
#define OUTER (N/64)
int match_fix(int x1, int x2,int x3, int x4, int y1, int y2, int y3, int y4) {
int r = 0;
int i;
int t1 = (x1-y1);
int t2 = (x2-y2);
int t3 = (x3-y3);
int t4 = (x4-y4);
r = t1*t1 + t2*t2 + t3*t3 + t4*t4;
return r;
void foo(int* x){
x[0] += 1;
x[1] += 1;
}
int min(int *data, int len) {
int best = data[0];
for (int i = 0; i < N; i++){
if (data[i] < best){
best = data[i];
}
}
return best;
}
void match_decomposed(int *db, int *OUTPUT_matches, int len, int *sample) {
for(int i = 0; i < N; i++) {
OUTPUT_matches[i] = match_fix(db[i*K], db[i*K+1], db[i*K+2], db[i*K+3], sample[0], sample[1], sample[2], sample[3]);
}
}
int main( __attribute__((private(0))) int db[1024], __attribute__((private(1))) int sample[4])
{
//int matches[4];
int matches[N];
match_decomposed(db, matches, N, sample);
// Compute minimum
int best_match = min(matches, N);
return best_match;
int main(__attribute__((private(0))) int a[2], __attribute__((private(1))) int b[2]) {
foo(a);
foo(b);
return a[0] + b[0];
}

View File

@@ -110,8 +110,8 @@ enum Backend {
cost_model: String,
#[structopt(long, default_value = "lp", name = "selection_scheme")]
selection_scheme: String,
#[structopt(long, default_value = "8", name = "num_parts")]
num_parts: usize,
#[structopt(long, default_value = "4000", name = "part_size")]
part_size: usize,
#[structopt(long, default_value = "4", name = "mut_level")]
mut_level: usize,
#[structopt(long, default_value = "1", name = "mut_step_size")]
@@ -310,6 +310,8 @@ fn main() {
// }
// }
// todo!("hello");
now = Instant::now();
match options.backend {
#[cfg(feature = "r1cs")]
@@ -350,11 +352,11 @@ fn main() {
Backend::Mpc {
cost_model,
selection_scheme,
num_parts,
part_size,
mut_level,
mut_step_size,
graph_type,
imbalance
imbalance,
} => {
println!("Converting to aby");
let lang_str = match language {
@@ -370,7 +372,7 @@ fn main() {
&lang_str,
&cost_model,
&selection_scheme,
&num_parts,
&part_size,
&mut_level,
&mut_step_size,
&graph_type,

View File

@@ -4,5 +4,9 @@ from util import run_tests
from test_suite import *
if __name__ == "__main__":
tests = cryptonets_tests
tests = mnist_tests
# benchmark_tests \
# + db_tests \
# + mnist_tests \
# + cryptonets_tests
run_tests('c', tests)

View File

@@ -1,3 +1,3 @@
a 0 1 2 3 4 5 6 7 8 9
b 0 1 2 3 4 5 6 7 8 9
res 10 10 5
a 0 1
b 0 1
res 1 1 5

View File

@@ -597,21 +597,21 @@ gauss_tests = [
]
db_tests = [
[
"db join",
"db_join",
"./scripts/aby_tests/test_inputs/db_join.txt",
],
# [
# "db join",
# "db_join",
# "./scripts/aby_tests/test_inputs/db_join_50.txt",
# ],
# [
# "db join 2",
# "db_join2",
# "./scripts/aby_tests/test_inputs/join2.txt",
# ],
# [
# "db merge",
# "db_merge",
# "./scripts/aby_tests/test_inputs/merge.txt",
# ],
[
"db merge",
"db_merge",
"./scripts/aby_tests/test_inputs/merge.txt",
],
]
benchmark_tests = [
@@ -625,16 +625,16 @@ benchmark_tests = [
"2pc_biomatch_",
"./scripts/aby_tests/test_inputs/biomatch_benchmark_1.txt",
],
# [
# "kmeans - 1",
# "2pc_kmeans_",
# "./scripts/aby_tests/test_inputs/kmeans.txt",
# ],
# [
# "gauss",
# "2pc_gauss_inline",
# "./scripts/aby_tests/test_inputs/gauss.txt",
# ],
[
"kmeans - 1",
"2pc_kmeans_",
"./scripts/aby_tests/test_inputs/kmeans.txt",
],
[
"gauss",
"2pc_gauss_inline",
"./scripts/aby_tests/test_inputs/gauss.txt",
],
# [
# "kmeans - 1",
# "2pc_kmeans",
@@ -662,16 +662,16 @@ mnist_tests = [
"mnist_16",
"./scripts/aby_tests/test_inputs/mnist_16.txt",
],
[
"mnist decomp main 16",
"mnist_decomp_main_16",
"./scripts/aby_tests/test_inputs/mnist_decomp_main_16.txt",
],
[
"mnist decomp convolution",
"mnist_decomp_convolution",
"./scripts/aby_tests/test_inputs/mnist_decomp_convolution.txt",
]
# [
# "mnist decomp main 16",
# "mnist_decomp_main_16",
# "./scripts/aby_tests/test_inputs/mnist_decomp_main_16.txt",
# ],
# [
# "mnist decomp convolution",
# "mnist_decomp_convolution",
# "./scripts/aby_tests/test_inputs/mnist_decomp_convolution.txt",
# ]
]
cryptonets_tests = [

View File

@@ -21,7 +21,13 @@ esac
function mpc_test {
parties=$1
cpath=$2
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "smart_glp"
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "smart_lp" --part-size 3000 --mut-level 4 --mut-step-size 1 --graph-type 0
}
function mpc_test_css {
parties=$1
cpath=$2
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "css" --part-size 3000 --mut-level 4 --mut-step-size 1 --graph-type 0
}
function mpc_test_2 {
@@ -40,13 +46,13 @@ function mpc_test_3 {
function mpc_test_4 {
parties=$1
cpath=$2
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "lp+mut" --num-parts 12 --mut-level 4 --mut-step-size 1 --graph-type 0
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "lp+mut" --mut-level 4 --mut-step-size 1 --graph-type 0
}
function mpc_test_5 {
parties=$1
cpath=$2
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "lp+mut" --num-parts 48 --mut-level 4 --mut-step-size 1 --graph-type 0
RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "lp+mut" --mut-level 4 --mut-step-size 1 --graph-type 0
}
function mpc_test_bool {
@@ -101,55 +107,55 @@ function mpc_test_9 {
# mpc_test_9 2 ./examples/C/mpc/benchmarks/db/db_join.c
# build div tests
mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c
# # build div tests
# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c
# build array tests
mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c
mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c
mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c
mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_3.c
# # build array tests
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index.c
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_2.c
# mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_index_3.c
# build circ/compiler array tests
mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c
mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c
mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c
mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c
mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c
# # build circ/compiler array tests
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_1.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_2.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c
# mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c
# build function tests
mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c
# # build function tests
# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c
# build struct tests
mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c
mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c
mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c
# # build struct tests
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/ret_struct.c
# build matrix tests
mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c
mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c
mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c
# # build matrix tests
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c
# build ptr tests
mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c
mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c
# # build ptr tests
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c
# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_arith.c
# build misc tests
mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c
mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c
# # build misc tests
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c
# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c
# build hycc benchmarks
mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c
mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c
mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch_.c
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c
# mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
# mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join2.c
# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_merge.c
# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist_16.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_main_16.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_convolution.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets_16.c
# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_main_16.c
# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_convolution.c
# mpc_test 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets_16.c
# build OPA benchmarks
# mpc_test_2 2 ./examples/C/mpc/benchmarks/histogram/histogram.c
@@ -168,3 +174,20 @@ mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_8.c
# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_9.c
# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench.c
# mpc_test_css 2 ./examples/C/mpc/playground.c
# build hycc benchmarks
# mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch_.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/biomatch/biomatch.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss_inline.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/db/db_join2.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/db/db_merge.c
mpc_test 2 ./examples/C/mpc/benchmarks/mnist/mnist_16.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_main_16.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/mnist/mnist_decomp_convolution.c
# mpc_test_css 2 ./examples/C/mpc/benchmarks/cryptonets/cryptonets_16.c

View File

@@ -36,14 +36,17 @@ pub fn link_one(arg_names: &Vec<String>, arg_values: Vec<Term>, callee: &Computa
for (n, v) in arg_names.into_iter().zip(arg_values) {
let ssa_names = callee.metadata.input_ssa_name_from_nice_name(n);
// println!("{:?}", ssa_names);
if ssa_names.len() == 1{
if ssa_names.len() == 1 {
let s = callee.metadata.input_sort(&ssa_names[0].0).clone();
sub_map.insert(leaf_term(Op::Var(ssa_names[0].0.clone(), s)), v);
} else{
} else {
for (s_name, index) in ssa_names {
let s = callee.metadata.input_sort(&s_name).clone();
sub_map.insert(leaf_term(Op::Var(s_name, s)), term![Op::Select; v.clone(), bv_lit(index, 32)]);
}
sub_map.insert(
leaf_term(Op::Var(s_name, s)),
term![Op::Select; v.clone(), bv_lit(index, 32)],
);
}
}
}
term(
@@ -308,21 +311,15 @@ mod test {
cache.insert(t.clone(), 1);
}
for t in comp.outputs.iter() {
let get_children = || -> Vec<Term> {
t.cs
.iter()
.cloned()
.collect()
};
if cache.contains_key(&term(t.op.clone(), get_children())){
let get_children = || -> Vec<Term> { t.cs.iter().cloned().collect() };
if cache.contains_key(&term(t.op.clone(), get_children())) {
println!("Got you1!!!!!");
}
}
link_all_function_calls(&mut fs);
let c = fs.get_comp("main").unwrap().clone();
for t in c.outputs.iter() {
if cache.contains_key(t){
if cache.contains_key(t) {
println!("Got you2!!!!!");
}
}

View File

@@ -133,7 +133,8 @@ mod test {
#[test]
fn select_ite_stores() {
let before = text::parse_computation(b"
let before = text::parse_computation(
b"
(computation
(metadata () () ())
(let ((z (#a (bv 4) #x0 6 ())))

View File

@@ -43,7 +43,12 @@ fn array_to_tuple(t: &Term) -> Term {
.map(|c| array_to_tuple(&c))
.collect(),
),
Sort::Tuple(..) => term(Op::Tuple, extras::tuple_elements(t).map(|c| array_to_tuple(&c)).collect()),
Sort::Tuple(..) => term(
Op::Tuple,
extras::tuple_elements(t)
.map(|c| array_to_tuple(&c))
.collect(),
),
_ => t.clone(),
}
}

View File

@@ -1789,7 +1789,11 @@ impl ComputationMetadata {
2 => (var_name[0].to_string(), 0),
3.. => {
let l = var_name.len();
(var_name[0..l - 2].to_vec().join("_").to_string(), var_name[l - 1].parse::<usize>().unwrap())
// println!("var: {:?}", var_name);
(
var_name[0..l - 2].to_vec().join("_").to_string(),
var_name[l - 1].parse::<usize>().unwrap(),
)
}
_ => {
panic!("Invalid variable name: {:?}", var_name);
@@ -1800,7 +1804,7 @@ impl ComputationMetadata {
ssa_names.push((k.to_string(), index));
}
}
if ssa_names.is_empty(){
if ssa_names.is_empty() {
println!("ssa-keys: {:?}", self.input_vis.keys());
panic!("ssa-name not found for nice name: {}", input_name);
}

View File

@@ -1,9 +1,9 @@
//! Defines a textual serialization format for [Term]s.
//!
//! Includes a parser ([parse_term]) and serializer ([serialize_term]) for [Term]s.
//!
//!
//! Includes a parser ([parse_value_map]) and serializer ([serialize_value_map]) for value maps.
//!
//!
//! Includes a parser ([parse_computation]) and serializer ([serialize_computation]) for [Computation]s.
//!
//!
@@ -752,7 +752,6 @@ pub fn serialize_functions(fns: &Functions) -> String {
out
}
#[cfg(test)]
mod test {
use super::*;

View File

@@ -1,10 +1,12 @@
use rug::Integer;
use fxhash::FxHashSet;
use fxhash::FxHashMap;
use fxhash::FxHashSet;
use crate::ir::term::*;
use std::collections::HashMap;
use std::collections::HashSet;
use std::time::Instant;
/// A post order iterater that skip the const index of select/store
@@ -31,17 +33,17 @@ impl std::iter::Iterator for PostOrderIterV2 {
if self.visited.contains(t) {
self.stack.pop();
} else if !children_pushed {
if let Op::Select = t.op{
if let Op::Select = t.op {
if let Op::Const(Value::BitVector(_)) = &t.cs[1].op {
self.stack.last_mut().unwrap().0 = true;
let last = self.stack.last().unwrap().1.clone();
self.stack.push((false, last.cs[0].clone()));
continue;
}
} else if let Op::Store = t.op{
} else if let Op::Store = t.op {
if let Op::Const(Value::BitVector(_)) = &t.cs[1].op {
self.stack.last_mut().unwrap().0 = true;
let last = self.stack.last().unwrap().1.clone();
let last = self.stack.last().unwrap().1.clone();
self.stack.push((false, last.cs[0].clone()));
self.stack.push((false, last.cs[2].clone()));
continue;
@@ -82,7 +84,7 @@ fn get_sort_len(s: &Sort) -> usize {
#[derive(Clone)]
/// A structure that maps the actual terms inside of array and tuple
pub struct DefUsesSubGraph{
pub struct DefUsesSubGraph {
/// List of terms in subgraph
pub nodes: TermSet,
/// Adjacency list of edges in subgraph
@@ -96,7 +98,7 @@ pub struct DefUsesSubGraph{
pub def_uses: FxHashMap<Term, Vec<Term>>,
}
impl DefUsesSubGraph{
impl DefUsesSubGraph {
/// default constructor
pub fn new() -> Self {
Self {
@@ -105,7 +107,7 @@ impl DefUsesSubGraph{
outs: TermSet::new(),
ins: TermSet::new(),
def_use: FxHashSet::default(),
def_uses: FxHashMap::default(),
def_uses: FxHashMap::default(),
}
}
@@ -143,49 +145,58 @@ impl DefUsesSubGraph{
}
for (d, u) in self.def_use.iter() {
self.def_uses.entry(d.clone()).or_insert_with(Vec::new).push(u.clone());
self.def_uses
.entry(d.clone())
.or_insert_with(Vec::new)
.push(u.clone());
}
}
}
/// Extend current dug to outer n level
pub fn extend_dusg(dusg: &DefUsesSubGraph, dug: &DefUsesGraph, n: usize) -> DefUsesSubGraph{
pub fn extend_dusg(dusg: &DefUsesSubGraph, dug: &DefUsesGraph, n: usize) -> DefUsesSubGraph {
let mut old_g: DefUsesSubGraph = dusg.clone();
let mut new_g: DefUsesSubGraph = DefUsesSubGraph::new();
for _ in 0..n{
for t in old_g.nodes.iter(){
for _ in 0..n {
for t in old_g.nodes.iter() {
new_g.insert_node(t);
for u in dug.def_uses.get(t).unwrap().iter(){
for u in dug.def_uses.get(t).unwrap().iter() {
new_g.insert_node(u);
}
for d in dug.use_defs.get(t).unwrap().iter(){
for d in dug.use_defs.get(t).unwrap().iter() {
new_g.insert_node(d);
}
}
old_g = new_g;
new_g = DefUsesSubGraph::new();
}
old_g.insert_edges(dug);
old_g
}
/// Def Use Graph for a computation
#[derive(Clone)]
pub struct DefUsesGraph {
pub term_to_terms: TermMap<Vec<Term>>,
pub term_to_terms: TermMap<Vec<(Term, usize)>>,
// pub term_to_terms_idx: TermMap<Vec<(Term, usize)>>,
pub def_use: FxHashSet<(Term, Term)>,
pub def_uses: FxHashMap<Term, FxHashSet<Term>>,
pub use_defs: FxHashMap<Term, FxHashSet<Term>>,
pub const_terms: TermSet,
pub good_terms: TermSet,
pub call_args: TermMap<Vec<FxHashSet<Op>>>,
pub call_rets: TermMap<Vec<FxHashSet<Op>>>,
pub call_args: TermMap<Vec<FxHashSet<usize>>>,
pub call_rets: TermMap<Vec<FxHashSet<usize>>>,
pub call_args_terms: TermMap<Vec<Vec<Term>>>,
pub call_rets_terms: TermMap<Vec<Vec<Term>>>,
pub ret_good_terms: Vec<Term>,
}
impl DefUsesGraph {
pub fn new(c: &Computation) -> Self{
pub fn new(c: &Computation) -> Self {
let mut now = Instant::now();
let mut dug = Self {
term_to_terms: TermMap::new(),
// term_to_terms_idx: TermMap::new(),
def_use: FxHashSet::default(),
def_uses: FxHashMap::default(),
use_defs: FxHashMap::default(),
@@ -193,6 +204,9 @@ impl DefUsesGraph {
good_terms: TermSet::new(),
call_args: TermMap::new(),
call_rets: TermMap::new(),
call_args_terms: TermMap::new(),
call_rets_terms: TermMap::new(),
ret_good_terms: Vec::new(),
};
dug.construct_def_use(c);
dug.construct_mapping();
@@ -200,10 +214,11 @@ impl DefUsesGraph {
dug
}
pub fn for_call_site(c: &Computation) -> Self{
pub fn for_call_site(c: &Computation) -> Self {
let mut now = Instant::now();
let mut dug = Self {
term_to_terms: TermMap::new(),
// term_to_terms_idx: TermMap::new(),
def_use: FxHashSet::default(),
def_uses: FxHashMap::default(),
use_defs: FxHashMap::default(),
@@ -211,28 +226,32 @@ impl DefUsesGraph {
good_terms: TermSet::new(),
call_args: TermMap::new(),
call_rets: TermMap::new(),
call_args_terms: TermMap::new(),
call_rets_terms: TermMap::new(),
ret_good_terms: Vec::new(),
};
dug.construct_def_use(c);
// moved this after insert context
dug.construct_mapping();
println!("Time: Def Use Graph: {:?}", now.elapsed());
dug
}
fn construct_def_use(&mut self, c: &Computation){
fn construct_def_use(&mut self, c: &Computation) {
for out in c.outputs.iter() {
for t in PostOrderIterV2::new(out.clone()) {
match &t.op{
match &t.op {
Op::Const(Value::Tuple(tup)) => {
let mut terms: Vec<Term> = Vec::new();
let mut terms: Vec<(Term, usize)> = Vec::new();
for val in tup.iter() {
terms.push(leaf_term(Op::Const(val.clone())));
terms.push((leaf_term(Op::Const(val.clone())), 0));
self.const_terms.insert(leaf_term(Op::Const(val.clone())));
self.add_term(&leaf_term(Op::Const(val.clone())));
}
self.term_to_terms.insert(t.clone(), terms);
}
Op::Tuple => {
let mut terms: Vec<Term> = Vec::new();
let mut terms: Vec<(Term, usize)> = Vec::new();
for c in t.cs.iter() {
terms.extend(self.term_to_terms.get(&c).unwrap().clone());
}
@@ -240,7 +259,7 @@ impl DefUsesGraph {
}
Op::Field(i) => {
let tuple_terms = self.term_to_terms.get(&t.cs[0]).unwrap().clone();
let tuple_sort = check(&t.cs[0]);
let (offset, len) = match tuple_sort {
Sort::Tuple(t) => {
@@ -252,7 +271,7 @@ impl DefUsesGraph {
}
// find len
let len = get_sort_len(&t[*i]);
(offset, len)
}
_ => panic!("Field op on non-tuple"),
@@ -268,22 +287,22 @@ impl DefUsesGraph {
self.term_to_terms.insert(t.clone(), tuple_terms);
}
Op::Const(Value::Array(arr)) => {
let mut terms: Vec<Term> = Vec::new();
let mut terms: Vec<(Term, usize)> = Vec::new();
let sort = check(&t);
if let Sort::Array(_, _, n) = sort{
if let Sort::Array(_, _, n) = sort {
// println!("Create a {} size array.", n);
let n = n as i32;
for i in 0..n{
for i in 0..n {
let idx = Value::BitVector(BitVector::new(Integer::from(i), 32));
let v = match arr.map.get(&idx) {
Some(c) => c,
None => &*arr.default,
};
terms.push(leaf_term(Op::Const(v.clone())));
terms.push((leaf_term(Op::Const(v.clone())), 0));
self.const_terms.insert(leaf_term(Op::Const(v.clone())));
self.add_term(&leaf_term(Op::Const(v.clone())));
}
} else{
} else {
todo!("Const array sort not array????")
}
self.term_to_terms.insert(t.clone(), terms);
@@ -297,12 +316,12 @@ impl DefUsesGraph {
// println!("Store the {} value on a {} size array.",idx , array_terms.len());
array_terms[idx] = value_terms[0].clone();
self.term_to_terms.insert(t.clone(), array_terms);
} else{
for idx in 0..array_terms.len(){
self.def_use.insert((array_terms[idx].clone(), t.clone()));
array_terms[idx] = t.clone();
} else {
for idx in 0..array_terms.len() {
self.def_use.insert((array_terms[idx].0.clone(), t.clone()));
array_terms[idx] = (t.clone(), 0);
}
self.def_use.insert((value_terms[0].clone(), t.clone()));
self.def_use.insert((value_terms[0].0.clone(), t.clone()));
self.term_to_terms.insert(t.clone(), array_terms);
self.add_term(&t);
}
@@ -312,88 +331,118 @@ impl DefUsesGraph {
if let Op::Const(Value::BitVector(bv)) = &t.cs[1].op {
// constant indexing
let idx = bv.uint().to_usize().unwrap().clone();
if array_terms.len() == 1 && idx == 1 {
println!("dad op: {:?}", t.cs[0].op);
println!("grandpa op: {:?}", t.cs[0].cs[0].op);
self.term_to_terms
.insert(t.clone(), vec![array_terms[idx].clone()]);
} else {
for idx in 0..array_terms.len() {
self.def_use.insert((array_terms[idx].0.clone(), t.clone()));
}
self.term_to_terms.insert(t.clone(), vec![array_terms[idx].clone()]);
} else{
for idx in 0..array_terms.len(){
self.def_use.insert((array_terms[idx].clone(), t.clone()));
}
self.term_to_terms.insert(t.clone(), vec![t.clone()]);
self.term_to_terms.insert(t.clone(), vec![(t.clone(), 0)]);
self.add_term(&t);
}
}
Op::Call(_, _, _, ret_sorts) => {
// Use call term itself as the placeholder
// Call term will be ignore by the ilp solver later
let mut ret_terms: Vec<Term> = Vec::new();
let mut ret_terms: Vec<(Term, usize)> = Vec::new();
let num_rets: usize = ret_sorts.iter().map(|ret| get_sort_len(ret)).sum();
let mut args: Vec<FxHashSet<Op>> = Vec::new();
let mut rets: Vec<FxHashSet<Op>> = Vec::new();
for c in t.cs.iter(){
let mut args: Vec<FxHashSet<usize>> = Vec::new();
let mut rets: Vec<FxHashSet<usize>> = Vec::new();
let mut args_t: Vec<Vec<Term>> = Vec::new();
let mut rets_t: Vec<Vec<Term>> = Vec::new();
for c in t.cs.iter() {
let arg_terms = self.term_to_terms.get(c).unwrap();
let mut arg_set: FxHashSet<Op> = FxHashSet::default();
for arg in arg_terms.iter(){
arg_set.insert(arg.op.clone());
let mut arg_set: FxHashSet<usize> = FxHashSet::default();
let mut arg_term: Vec<Term> = Vec::new();
for arg in arg_terms.iter() {
arg_set.insert(get_op_id(&arg.0.op));
arg_term.push(arg.0.clone());
}
args_t.push(arg_term);
args.push(arg_set);
}
for _ in 0..num_rets{
for idx in 0..num_rets {
rets.push(FxHashSet::default());
ret_terms.push(t.clone());
ret_terms.push((t.clone(), idx));
rets_t.push(Vec::new());
}
self.term_to_terms.insert(t.clone(), ret_terms);
self.call_args.insert(t.clone(), args);
self.call_rets.insert(t.clone(), rets);
self.call_args_terms.insert(t.clone(), args_t);
self.call_rets_terms.insert(t.clone(), rets_t);
}
Op::Ite =>{
if let Op::Store = t.cs[1].op{
Op::Ite => {
if let Op::Store = t.cs[1].op {
// assert_eq!(t.cs[2].op, Op::Store);
let cond_terms = self.term_to_terms.get(&t.cs[0]).unwrap().clone();
assert_eq!(cond_terms.len(), 1);
self.def_use.insert((cond_terms[0].clone(), t.clone()));
self.def_use.insert((cond_terms[0].0.clone(), t.clone()));
// true branch
let mut t_terms = self.term_to_terms.get(&t.cs[1]).unwrap().clone();
// false branch
let f_terms = self.term_to_terms.get(&t.cs[2]).unwrap().clone();
assert_eq!(t_terms.len(), f_terms.len());
for idx in 0..t_terms.len(){
self.def_use.insert((t_terms[idx].clone(), t.clone()));
self.def_use.insert((f_terms[idx].clone(), t.clone()));
t_terms[idx] = t.clone();
for idx in 0..t_terms.len() {
self.def_use.insert((t_terms[idx].0.clone(), t.clone()));
self.def_use.insert((f_terms[idx].0.clone(), t.clone()));
t_terms[idx] = (t.clone(), 0);
}
self.term_to_terms.insert(t.clone(), t_terms);
} else{
for c in t.cs.iter(){
if let Op::Call(..) = t.op{
} else {
for c in t.cs.iter() {
if let Op::Call(..) = t.op {
continue;
} else{
} else {
let terms = self.term_to_terms.get(c).unwrap();
assert_eq!(terms.len(), 1);
self.def_use.insert((terms[0].clone(), t.clone()));
if let Op::Call(..) = terms[0].0.op {
// insert op to ret set
let rets = self.call_rets.get_mut(&terms[0].0).unwrap();
rets.get_mut(terms[0].1).unwrap().insert(get_op_id(&t.op));
// insert term to ret terms
let rets_t =
self.call_rets_terms.get_mut(&terms[0].0).unwrap();
rets_t.get_mut(terms[0].1).unwrap().push(t.clone());
} else {
self.def_use.insert((terms[0].0.clone(), t.clone()));
}
}
}
self.term_to_terms.insert(t.clone(), vec![t.clone()]);
self.term_to_terms.insert(t.clone(), vec![(t.clone(), 0)]);
}
self.add_term(&t);
}
_ =>{
for c in t.cs.iter(){
if let Op::Call(..) = t.op{
_ => {
for c in t.cs.iter() {
if let Op::Call(..) = c.op {
continue;
} else{
let terms = self.term_to_terms.get(c).unwrap();
} else {
let terms = self.term_to_terms.get(c).unwrap().clone();
assert_eq!(terms.len(), 1);
self.def_use.insert((terms[0].clone(), t.clone()));
if let Op::Call(..) = terms[0].0.op {
// insert op to ret set
let rets = self.call_rets.get_mut(&terms[0].0).unwrap();
rets.get_mut(terms[0].1).unwrap().insert(get_op_id(&t.op));
// insert term to ret terms
let rets_t = self.call_rets_terms.get_mut(&terms[0].0).unwrap();
rets_t.get_mut(terms[0].1).unwrap().push(t.clone());
} else {
self.def_use.insert((terms[0].0.clone(), t.clone()));
}
}
}
self.term_to_terms.insert(t.clone(), vec![t.clone()]);
self.term_to_terms.insert(t.clone(), vec![(t.clone(), 0)]);
self.add_term(&t);
}
}
}
let out_terms = self.term_to_terms.get(out).unwrap().clone();
for (t, _) in out_terms.iter() {
// v.push(t.clone());
self.ret_good_terms.push(t.clone());
}
// This is for the case when out term is not a good term, we still need it.
// if !self.good_terms.contains(out){
// let this_terms = self.term_to_terms.get(out).unwrap();
@@ -405,16 +454,16 @@ impl DefUsesGraph {
}
}
fn construct_mapping(&mut self){
for (def, _use) in self.def_use.iter(){
if self.def_uses.contains_key(def){
fn construct_mapping(&mut self) {
for (def, _use) in self.def_use.iter() {
if self.def_uses.contains_key(def) {
self.def_uses.get_mut(def).unwrap().insert(_use.clone());
} else {
let mut uses: FxHashSet<Term> = FxHashSet::default();
uses.insert(_use.clone());
self.def_uses.insert(def.clone(), uses);
}
if self.use_defs.contains_key(_use){
if self.use_defs.contains_key(_use) {
self.use_defs.get_mut(_use).unwrap().insert(def.clone());
} else {
let mut defs: FxHashSet<Term> = FxHashSet::default();
@@ -424,26 +473,175 @@ impl DefUsesGraph {
}
}
fn add_term(&mut self, t: &Term){
/// Out put the call site from this function's computation
pub fn get_call_site(
&mut self,
) -> Vec<(
String,
Vec<usize>,
Vec<Vec<Term>>,
Vec<usize>,
Vec<Vec<Term>>,
Term,
)> {
let mut call_sites: Vec<(
String,
Vec<usize>,
Vec<Vec<Term>>,
Vec<usize>,
Vec<Vec<Term>>,
Term,
)> = Vec::new();
for (t, args_set) in self.call_args.iter() {
// Stupid implementation, Should fix this
if let Op::Call(fname, _, _, _) = &t.op {
let rets_set = self.call_rets.get(t).unwrap();
let mut rets: Vec<usize> = Vec::new();
let mut args: Vec<usize> = Vec::new();
for s in rets_set.iter() {
let mut v: Vec<usize> = s.clone().into_iter().collect();
v.sort();
rets.extend(v);
}
for s in args_set.iter() {
let mut v: Vec<usize> = s.clone().into_iter().collect();
v.sort();
args.extend(v);
}
let args_t = self.call_args_terms.get(t).unwrap().clone();
let rets_t = self.call_rets_terms.get(t).unwrap().clone();
call_sites.push((fname.clone(), args, args_t, rets, rets_t, t.clone()));
}
}
call_sites
}
/// insert the caller's context
pub fn insert_context(
&mut self,
arg_names: &Vec<String>,
arg_values: &Vec<Vec<Term>>,
rets: &Vec<Vec<Term>>,
caller_dug: &DefUsesGraph,
callee: &Computation,
extra_level: usize,
) {
let mut input_set: TermSet = TermSet::new();
let mut output_set: TermSet = TermSet::new();
// insert def of args
for (n, v) in arg_names.into_iter().zip(arg_values) {
let ssa_names = callee.metadata.input_ssa_name_from_nice_name(n);
for (sname, index) in ssa_names.iter() {
let s = callee.metadata.input_sort(&sname).clone();
// println!("Def: {}, Use: {}", v.get(*index).unwrap(), leaf_term(Op::Var(sname.clone(), s.clone())));
let def_t = v.get(*index).unwrap();
let use_t = leaf_term(Op::Var(sname.clone(), s));
if let Op::Call(..) = def_t.op {
continue;
}
if !self.good_terms.contains(&use_t) {
// println!("FIX: {}", use_t.op);
// This is because the function doesn't use this arg
//todo!("Fix this...");
continue;
}
self.add_term(&def_t);
self.def_use.insert((def_t.clone(), use_t));
input_set.insert(def_t.clone());
}
}
// insert use of rets
let outs = self.ret_good_terms.clone();
// for tt in rets.iter(){
// println!("[");
// for t in tt.iter(){
// println!("rets op: {}", t);
// }
// println!("]");
// }
// println!("=====");
// for tt in outs.iter(){
// println!("[");
// for t in tt.iter(){
// println!("outs op: {}", t);
// }
// println!("]");
// }
assert_eq!(outs.len(), rets.len());
for (d, uses) in outs.into_iter().zip(rets) {
for u in uses.iter() {
self.add_term(u);
self.def_use.insert((d.clone(), u.clone()));
}
}
// kind of mutation?
for i in 1..extra_level {
// insert def of def
for def in input_set.clone().iter() {
let def_defs = caller_dug.def_uses.get(def).unwrap();
for def_def in def_defs.iter() {
self.add_term(def_def);
self.def_use.insert((def_def.clone(), def.clone()));
input_set.insert(def_def.clone());
}
}
// insert use of use
for _use in output_set.clone().iter() {
let use_uses = caller_dug.def_uses.get(_use).unwrap();
for use_use in use_uses.iter() {
self.add_term(use_use);
self.def_use.insert((_use.clone(), use_use.clone()));
input_set.insert(use_use.clone());
}
}
}
self.construct_mapping();
}
fn add_term(&mut self, t: &Term) {
self.good_terms.insert(t.clone());
let defs: FxHashSet<Term> = FxHashSet::default();
let uses: FxHashSet<Term> = FxHashSet::default();
self.def_uses.insert(t.clone(), uses);
self.use_defs.insert(t.clone(), defs);
}
}
pub fn is_good_term(t: &Term) -> bool{
match t.op{
pub fn is_good_term(t: &Term) -> bool {
match t.op {
Op::Const(Value::Tuple(_))
| Op::Tuple
| Op::Field(_)
| Op::Update(_)
| Op::Const(Value::Array(_))
| Op::Store
| Op::Store
| Op::Select
| Op::Call(..) => false,
_ => true
_ => true,
}
}
}
pub fn get_op_id(op: &Op) -> usize {
match op {
Op::Var(..) => 1,
Op::Const(_) => 2,
Op::Eq => 3,
Op::Ite => 4,
Op::Not => 5,
Op::BoolNaryOp(o) => 6,
Op::BvBinPred(o) => 7,
Op::BvNaryOp(o) => 8,
Op::BvBinOp(o) => 9,
Op::Select => 10,
Op::Store => 11,
Op::Call(..) => 12,
_ => todo!("What op?"),
}
}

View File

@@ -78,7 +78,7 @@ pub fn assign_mut(c: &ComputationSubgraph, cm: &str, co: &ComputationSubgraph) -
while smap.len() == 0 {
// A hack for empty result during multi-threading
// Simply retry until get a non-empty result
if cnt > 5{
if cnt > 5 {
panic!("MT BUG: Dead loop.")
}
smap = build_ilp(c, &costs);
@@ -97,7 +97,11 @@ pub fn assign_mut(c: &ComputationSubgraph, cm: &str, co: &ComputationSubgraph) -
}
/// Uses an ILP to assign and abandon the outer assignments
pub fn assign_mut_smart(dusg: &DefUsesSubGraph, cm: &str, dusg_ref: &DefUsesSubGraph) -> SharingMap {
pub fn assign_mut_smart(
dusg: &DefUsesSubGraph,
cm: &str,
dusg_ref: &DefUsesSubGraph,
) -> SharingMap {
let base_dir = match cm {
"opa" => "opa",
"hycc" => "hycc",
@@ -115,7 +119,7 @@ pub fn assign_mut_smart(dusg: &DefUsesSubGraph, cm: &str, dusg_ref: &DefUsesSubG
while smap.len() == 0 {
// A hack for empty result during multi-threading
// Simply retry until get a non-empty result
if cnt > 5{
if cnt > 5 {
panic!("MT BUG: Dead loop.")
}
smap = build_smart_ilp(dusg.nodes.clone(), &dusg.def_use, &costs);
@@ -134,7 +138,11 @@ pub fn assign_mut_smart(dusg: &DefUsesSubGraph, cm: &str, dusg_ref: &DefUsesSubG
}
/// Uses an ILP to assign...
pub fn smart_global_assign(terms: &TermSet, def_uses: &FxHashSet<(Term,Term)>, cm: &str) -> SharingMap {
pub fn smart_global_assign(
terms: &TermSet,
def_uses: &FxHashSet<(Term, Term)>,
cm: &str,
) -> SharingMap {
let base_dir = match cm {
"opa" => "opa",
"hycc" => "hycc",
@@ -150,10 +158,16 @@ pub fn smart_global_assign(terms: &TermSet, def_uses: &FxHashSet<(Term,Term)>, c
build_smart_ilp(terms.clone(), def_uses, &costs)
}
fn build_smart_ilp(term_set: TermSet, def_uses: &FxHashSet<(Term,Term)>, costs: &CostModel) -> SharingMap {
let terms: FxHashMap<Term, usize> =
term_set.into_iter().enumerate().map(|(i, t)| (t, i)).collect();
fn build_smart_ilp(
term_set: TermSet,
def_uses: &FxHashSet<(Term, Term)>,
costs: &CostModel,
) -> SharingMap {
let terms: FxHashMap<Term, usize> = term_set
.into_iter()
.enumerate()
.map(|(i, t)| (t, i))
.collect();
let mut term_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default();
let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> =
FxHashMap::default();
@@ -230,69 +244,84 @@ fn build_smart_ilp(term_set: TermSet, def_uses: &FxHashSet<(Term,Term)>, costs:
for use_ in uses {
for from_ty in &SHARE_TYPES {
for to_ty in &SHARE_TYPES {
let ilp_version = false;
if ilp_version{
let ilp_version = true;
if ilp_version {
conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
term_vars.get(&(def.clone(), *from_ty)).map(|t_from| {
// c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1
term_vars
.get(&(use_.clone(), *to_ty))
.map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)))
term_vars.get(&(use_.clone(), *to_ty)).map(|t_to| {
ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0))
})
})
});
} else{
} else {
// hardcoding here
// a2b > y2b
// y2a > b2a
// a2y > b2y
if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean{
if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean {
let cheap_ty = ShareType::Yao;
conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
term_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
term_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
term_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic{
} else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic {
let cheap_ty = ShareType::Boolean;
conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
term_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
term_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
term_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao{
} else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao {
let cheap_ty = ShareType::Boolean;
conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
term_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
term_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
term_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else{
} else {
conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
term_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
term_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
term_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0)))
term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0),
)
})
})
})
});
@@ -908,76 +937,91 @@ fn build_comb_ilp_smart(
for use_ in uses {
for from_ty in &SHARE_TYPES {
for to_ty in &SHARE_TYPES {
let ilp_version = false;
if ilp_version{
let ilp_version = true;
if ilp_version {
e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
v_vars.get(&(def.clone(), *from_ty)).map(|t_from| {
// c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1
v_vars
.get(&(use_.clone(), *to_ty))
.map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)))
v_vars.get(&(use_.clone(), *to_ty)).map(|t_to| {
ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0))
})
})
});
} else{
} else {
// hardcoding here
// a2b > y2b
// y2a > b2a
// a2y > b2y
if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean{
if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean {
let cheap_ty = ShareType::Yao;
e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
v_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
v_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
v_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic{
} else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic {
let cheap_ty = ShareType::Boolean;
e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
v_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
v_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
v_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao{
} else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao {
let cheap_ty = ShareType::Boolean;
e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
v_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
v_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| {
v_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0 - d_ch.0)))
v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0
- 1.0
- d_to.0
- d_ch.0),
)
})
})
})
})
});
} else{
} else {
e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| {
v_vars.get(&(def.clone(), *from_ty)).map(|d_from| {
v_vars.get(&(def.clone(), *to_ty)).map(|d_to| {
v_vars
.get(&(use_.clone(), *to_ty))
.map(|u_to| ilp.new_constraint(c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0)))
v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| {
ilp.new_constraint(
c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0),
)
})
})
})
});
}
}
}
}
}
}
}
@@ -1012,7 +1056,12 @@ pub fn calculate_cost_smart(smap: &SharingMap, costs: &CostModel, dusg: &DefUses
let mut conv_cost: HashMap<(Term, ShareType), f64> = HashMap::new();
for (t, to_ty) in smap {
match &t.op {
Op::Var(..) | Op::Const(_) | Op::BvConcat | Op::BvExtract(..) | Op::BoolToBv | Op::BvBit(_) => {
Op::Var(..)
| Op::Const(_)
| Op::BvConcat
| Op::BvExtract(..)
| Op::BoolToBv
| Op::BvBit(_) => {
cost = cost + 0.0;
}
_ => {
@@ -1041,7 +1090,12 @@ pub fn calculate_cost(smap: &SharingMap, costs: &CostModel) -> f64 {
let mut conv_cost: HashMap<(Term, ShareType), f64> = HashMap::new();
for (t, to_ty) in smap {
match &t.op {
Op::Var(..) | Op::Const(_) | Op::BvConcat | Op::BvExtract(..) | Op::BoolToBv | Op::BvBit(_) => {
Op::Var(..)
| Op::Const(_)
| Op::BvConcat
| Op::BvExtract(..)
| Op::BoolToBv
| Op::BvBit(_) => {
cost = cost + 0.0;
}
_ => {
@@ -1068,7 +1122,12 @@ pub fn calculate_node_cost(smap: &SharingMap, costs: &CostModel) -> f64 {
let mut cost: f64 = 0.0;
for (t, to_ty) in smap {
match &t.op {
Op::Var(..) | Op::Const(_) | Op::BvConcat | Op::BvExtract(..) | Op::BoolToBv | Op::BvBit(_) => {
Op::Var(..)
| Op::Const(_)
| Op::BvConcat
| Op::BvExtract(..)
| Op::BoolToBv
| Op::BvBit(_) => {
cost = cost + 0.0;
}
_ => {
@@ -1116,19 +1175,11 @@ mod tests {
);
assert_eq!(
&1731.0,
c
.get(&BV_MUL)
.unwrap()
.get(&ShareType::Boolean)
.unwrap()
c.get(&BV_MUL).unwrap().get(&ShareType::Boolean).unwrap()
);
assert_eq!(
&7.0,
c
.get(&BV_XOR)
.unwrap()
.get(&ShareType::Boolean)
.unwrap()
c.get(&BV_XOR).unwrap().get(&ShareType::Boolean).unwrap()
);
}

View File

@@ -131,7 +131,7 @@ impl CostModel {
// | Op::Store
| Op::Call(..)
| Op::Const(..)=> {
todo!("Op get cost: Should not reach here.");
todo!("Op get cost: Should not reach here: {}", op);
}
Op::Field(_)
| Op::Update(..)

View File

@@ -1,81 +1,335 @@
//! Call Site Similarity
use crate::ir::term::*;
use crate::target::aby::assignment::def_uses::*;
use fxhash::{FxHashMap, FxHashSet};
use std::collections::HashMap;
use std::collections::HashSet;
/// What do we need for call site?
///
/// Call sites:
/// HashMap<(String, Vec<usize>, Vec<usize>), Vec<Term>>
/// - Each entry of {call_sites} will become a copy of function
///
/// Computations:
/// HashMap<String, HashMap<usize, Computation>>
/// - String: fname
/// - usize: version id
///
/// DefUseGraph:
/// HashMap<String, DefUseGraph>
///
/// Surrounding info:
/// Two type:
/// 1. For inner calls:
/// - Per call
/// 2. For outer calls:
/// - Per call site
///
/// args: HashMap<String, Vec<Term>>
/// rets: HashMap<String, Vec<Term>>
#[derive(Clone)]
/// A structure that stores the context and all the call terms in one call site
struct CallSite {
// Context's fname
pub caller: String,
pub callee: String,
pub arg_names: Vec<String>,
pub args: Vec<Vec<Term>>,
pub rets: Vec<Vec<Term>>,
pub calls: Vec<Term>,
pub caller_dug: DefUsesGraph,
}
impl CallSite {
pub fn new(
caller: &String,
callee: &String,
args: &Vec<Vec<Term>>,
arg_names: &Vec<String>,
rets: &Vec<Vec<Term>>,
t: &Term,
caller_dug: &DefUsesGraph,
) -> Self {
Self {
caller: caller.clone(),
callee: callee.clone(),
arg_names: arg_names.clone(),
args: args.clone(),
rets: rets.clone(),
calls: vec![t.clone()],
caller_dug: caller_dug.clone(),
}
}
pub fn insert(&mut self, t: &Term) {
self.calls.push(t.clone());
}
}
/// Determine if call sites are similar based on input and output arguments to the call site
pub fn call_site_similarity(fs: &Functions) -> Vec<Vec<Term>> {
// Return a TermMap of (call) --> id for which calls are similar
// Maybe return a vector of vector of terms
pub fn call_site_similarity(fs: &Functions) -> (Functions, HashMap<String, DefUsesGraph>) {
let mut call_sites: HashMap<(String, Vec<usize>, Vec<usize>), CallSite> = HashMap::new();
let mut func_to_cs: HashMap<String, HashMap<usize, CallSite>> = HashMap::new();
// Map of Vec<input: Vec<Term>, output: Vec<Term>> --> Vec<Call Term>
let mut dup_per_func: HashMap<String, usize> = HashMap::new();
// map call Term -> (input: Vec<Term>, output: Vec<Term>)
let mut call_term_map: TermMap<(Vec<Term>, Vec<Term>)> = TermMap::new();
// Mapping of callee-caller pair
let mut callee_caller: HashSet<(String, String)> = HashSet::new();
// Functions that have more than one call site
let mut duplicated_f: HashSet<String> = HashSet::new();
// Functions that need to be rewrote for calling to duplicated f
// If a callee is duplicated, the caller need to be rewrote
let mut rewriting_f: HashSet<String> = HashSet::new();
// map field(i) Term to parent call Term
let mut field_of_calls: TermMap<Term> = TermMap::new();
for (_name, comp) in &fs.computations {
for t in comp.terms_postorder() {
// see if the call term was used as an argument in another term
for c in &t.cs {
if call_term_map.contains_key(c) {
field_of_calls.insert(t.clone(), c.clone());
}
if field_of_calls.contains_key(c) {
let call_term = field_of_calls.get(c).unwrap();
call_term_map.get_mut(call_term).unwrap().1.push(t.clone());
}
}
match &t.op {
Op::Call(..) => {
let input: Vec<Term> = t.cs.clone();
let output: Vec<Term> = Vec::new();
call_term_map.insert(t.clone(), (input, output));
}
_ => {
// do nothing
// Iterate all the comp and retrieve call site info
for (caller, comp) in fs.computations.iter() {
let mut dug = DefUsesGraph::for_call_site(comp);
let cs: Vec<(
String,
Vec<usize>,
Vec<Vec<Term>>,
Vec<usize>,
Vec<Vec<Term>>,
Term,
)> = dug.get_call_site();
// dugs.insert(caller.clone(), dug.clone());
for (callee, args, args_t, rets, rets_t, t) in cs.iter() {
let key: (String, Vec<usize>, Vec<usize>) =
(callee.clone(), args.clone(), rets.clone());
if call_sites.contains_key(&key) {
call_sites.get_mut(&key).unwrap().insert(t);
} else {
// Use the first context
if let Op::Call(_, arg_names, _, _) = &t.op {
let cs = CallSite::new(caller, callee, args_t, arg_names, rets_t, t, &dug);
call_sites.insert(key, cs);
}
}
// recording callee-caller
callee_caller.insert((callee.clone(), caller.clone()));
}
// For each call term
// Get a list of inputs and output terms based on mutation step size
// Store input terms into a data structure (vec?)
// Store output terms into a data structure (vec?)
// Order terms by operator
// Create key: Vec<input: Vec<Term>, output: Vec<Term>>
// SORT OUTPUT TERMS
// loop through existing call terms:
// longest prefix matching (edit distance?)
// if match:
// append to vec
// if no match:
// add as new entry
dup_per_func.insert(caller.clone(), 0);
func_to_cs.insert(caller.clone(), HashMap::new());
// // HACK: for main func:
// if caller == "main"{
// new_dugs.get_mut(caller).unwrap().insert(0, dug.clone());
// }
}
// Clean input and output terms
let mut call_sites: HashMap<(Vec<Op>, Vec<Op>), Vec<Term>> = HashMap::new();
for (c, (i, o)) in call_term_map {
let input_ops = i.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
let output_ops = o.iter().map(|x| x.op.clone()).collect::<Vec<Op>>();
let key = (input_ops, output_ops);
let mut call_map: TermMap<usize> = TermMap::new();
// longest prefix matching?
// Generating duplicate set
for (key, cs) in call_sites.iter() {
let call_id: usize = dup_per_func.get(&key.0).unwrap().clone();
// edit distance?
if call_id > 0 {
// indicate this function need to be rewrote
duplicated_f.insert(key.0.clone());
}
if call_sites.contains_key(&key) {
call_sites.get_mut(&key).unwrap().push(c);
} else {
call_sites.insert(key, vec![c]);
for t in cs.calls.iter() {
call_map.insert(t.clone(), call_id);
}
dup_per_func.insert(key.0.clone(), call_id + 1);
let id_to_cs = func_to_cs.get_mut(&key.0).unwrap();
id_to_cs.insert(call_id, cs.clone());
}
// Generating rewriting set
for (callee, caller) in callee_caller.iter() {
if duplicated_f.contains(callee) {
rewriting_f.insert(caller.clone());
}
}
return call_sites.into_values().collect::<Vec<_>>();
remap(fs, &rewriting_f, &duplicated_f, &call_map, &func_to_cs)
}
/// Rewriting the call term to new call
fn rewrite_call(c: &mut Computation, call_map: &TermMap<usize>, duplicate_set: &HashSet<String>) {
let mut cache = TermMap::<Term>::new();
let mut children_added = TermSet::new();
let mut stack = Vec::new();
stack.extend(c.outputs.iter().cloned());
while let Some(top) = stack.pop() {
if !cache.contains_key(&top) {
// was it missing?
if children_added.insert(top.clone()) {
stack.push(top.clone());
stack.extend(top.cs.iter().filter(|c| !cache.contains_key(c)).cloned());
} else {
let get_children = || -> Vec<Term> {
top.cs
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect()
};
let new_t_op: Op = match &top.op {
Op::Call(name, arg_names, arg_sorts, ret_sorts) => {
let mut new_t = top.op.clone();
if duplicate_set.contains(name) {
if let Some(cid) = call_map.get(&top) {
let new_n = format_dup_call(name, cid);
let mut new_arg_names: Vec<String> = Vec::new();
for an in arg_names.iter() {
new_arg_names.push(an.replace(name, &new_n));
}
new_t = Op::Call(
new_n,
new_arg_names,
arg_sorts.clone(),
ret_sorts.clone(),
);
}
}
new_t
}
_ => top.op.clone(),
};
let new_t = term(new_t_op, get_children());
cache.insert(top.clone(), new_t);
}
}
}
c.outputs = c
.outputs
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
}
/// Rewriting the var term to new name
fn rewrite_var(c: &mut Computation, fname: &String, cid: &usize) {
let mut cache = TermMap::<Term>::new();
let mut children_added = TermSet::new();
let mut stack = Vec::new();
stack.extend(c.outputs.iter().cloned());
while let Some(top) = stack.pop() {
if !cache.contains_key(&top) {
// was it missing?
if children_added.insert(top.clone()) {
stack.push(top.clone());
stack.extend(top.cs.iter().filter(|c| !cache.contains_key(c)).cloned());
} else {
let get_children = || -> Vec<Term> {
top.cs
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect()
};
let new_t_op: Op = match &top.op {
Op::Var(name, sort) => {
let new_call_n = format_dup_call(fname, cid);
let new_var_n = name.replace(fname, &new_call_n);
Op::Var(new_var_n.clone(), sort.clone())
}
_ => top.op.clone(),
};
let new_t = term(new_t_op, get_children());
cache.insert(top.clone(), new_t);
}
}
}
c.outputs = c
.outputs
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
}
fn remap(
fs: &Functions,
rewriting_set: &HashSet<String>,
duplicate_set: &HashSet<String>,
call_map: &TermMap<usize>,
func_to_cs: &HashMap<String, HashMap<usize, CallSite>>,
) -> (Functions, HashMap<String, DefUsesGraph>) {
let mut n_fs = Functions::new();
let mut n_dugs: HashMap<String, DefUsesGraph> = HashMap::new();
for (fname, comp) in fs.computations.iter() {
let mut ncomp: Computation = comp.clone();
let id_to_cs = func_to_cs.get(fname).unwrap();
if rewriting_set.contains(fname) {
rewrite_call(&mut ncomp, call_map, duplicate_set);
}
if duplicate_set.contains(fname) {
for (cid, cs) in id_to_cs.iter() {
let new_n: String = format_dup_call(fname, cid);
let mut dup_comp: Computation = Computation{
outputs: ncomp.outputs().clone(),
metadata: rewrite_metadata(&ncomp.metadata, fname, &new_n),
precomputes: ncomp.precomputes.clone(),
};
rewrite_var(&mut dup_comp, fname, cid);
let mut dug = DefUsesGraph::new(&dup_comp);
dug.insert_context(
&cs.arg_names,
&cs.args,
&cs.rets,
&cs.caller_dug,
&dup_comp,
4,
);
n_fs.insert(new_n.clone(), dup_comp);
n_dugs.insert(new_n.clone(), dug);
}
} else {
let mut dug = DefUsesGraph::new(&ncomp);
// println!("fname {}", fname);
// Main function might not have any call site info
if let Some(cs) = id_to_cs.get(&0) {
dug.insert_context(&cs.arg_names, &cs.args, &cs.rets, &cs.caller_dug, &ncomp, 4);
}
n_fs.insert(fname.clone(), ncomp);
n_dugs.insert(fname.clone(), dug.clone());
}
}
(n_fs, n_dugs)
}
fn format_dup_call(fname: &String, cid: &usize) -> String {
format!("{}_circ_v_{}", fname, cid).clone()
}
fn rewrite_metadata(md: &ComputationMetadata, fname: &String, n_fname: &String) -> ComputationMetadata {
let mut input_vis: FxHashMap<String, (Term, Option<PartyId>)> = FxHashMap::default();
let mut computation_inputs: FxHashSet<String> = FxHashSet::default();
let mut computation_arg_names: Vec<String> = Vec::new();
for (s, tu) in md.input_vis.iter(){
let s = s.clone();
let new_s = s.replace(fname, n_fname);
input_vis.insert(new_s, tu.clone());
}
for s in md.computation_inputs.iter(){
let s = s.clone();
let new_s = s.replace(fname, n_fname);
computation_inputs.insert(new_s);
}
for s in md.computation_arg_names.iter(){
let s = s.clone();
let new_s = s.replace(fname, n_fname);
computation_arg_names.push(new_s);
}
ComputationMetadata {
party_ids: md.party_ids.clone(),
next_party_id: md.next_party_id.clone(),
input_vis,
computation_inputs,
computation_arg_names,
}
}

View File

@@ -34,7 +34,7 @@ use super::assignment::ShareType;
use std::time::Instant;
// use super::call_site_similarity::call_site_similarity;
use super::call_site_similarity::call_site_similarity;
const PUBLIC: u8 = 2;
const WRITE_SIZE: usize = 65536;
@@ -312,10 +312,10 @@ impl<'a> ToABY<'a> {
self.term_to_shares.insert(t.clone(), s);
self.share_cnt += 1;
// Write share
self.write_share(t, s);
// Write share
self.write_share(t, s);
s
s
}
}
}
@@ -921,6 +921,7 @@ impl<'a> ToABY<'a> {
get_path(self.path, &self.lang, "share_map", true);
for (name, comp) in computations.iter() {
println!("function name: {}", name);
let mut outputs: Vec<String> = Vec::new();
// set current computation
@@ -936,9 +937,13 @@ impl<'a> ToABY<'a> {
for t in comp.outputs.iter() {
self.embed(t.clone());
println!("out op: {}", t.op);
// println!("out op: {}", t.op);
let op = "OUT";
let to_share_type = self.get_term_share_type(&t);
let mut to_share_type = self.get_term_share_type(&t);
// HACK
if to_share_type == ShareType::None{
to_share_type = ShareType::Yao;
}
let share = self.get_share(&t, to_share_type);
let line = format!("1 0 {} {}\n", share, op);
outputs.push(line);
@@ -1027,7 +1032,7 @@ pub fn to_aby(
lang: &str,
cm: &str,
ss: &str,
#[allow(unused_variables)] np: &usize,
#[allow(unused_variables)] ps: &usize,
#[allow(unused_variables)] ml: &usize,
#[allow(unused_variables)] mss: &usize,
#[allow(unused_variables)] hyper: &usize,
@@ -1035,8 +1040,17 @@ pub fn to_aby(
) {
// TODO: change ILP to take in Functions instead of individual computations
// call_site_similarity(&ir);
// todo!("Hello");
match ss{
#[cfg(feature = "lp")]
"css" => {
let (fs, dugs) = call_site_similarity(&ir);
let s_map = css_partition_with_mut_smart(&fs, &dugs, cm, path, lang, ps, *hyper==1, ml, mss, imbalance);
let mut converter = ToABY::new(fs, s_map, path, lang);
converter.lower();
}
#[cfg(feature = "lp")]
"gglp" => {
let (fs, s_map) = inline_all_and_assign_glp(&ir, cm);
@@ -1045,7 +1059,7 @@ pub fn to_aby(
}
#[cfg(feature = "lp")]
"lp+mut" => {
let (fs, s_map) = partition_with_mut(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
let (fs, s_map) = partition_with_mut(&ir, cm, path, lang, ps, *hyper==1, ml, mss, imbalance);
let mut converter = ToABY::new(fs, s_map, path, lang);
converter.lower();
}
@@ -1057,7 +1071,7 @@ pub fn to_aby(
}
#[cfg(feature = "lp")]
"smart_lp" => {
let (fs, s_map) = partition_with_mut_smart(&ir, cm, path, lang, np, *hyper==1, ml, mss, imbalance);
let (fs, s_map) = partition_with_mut_smart(&ir, cm, path, lang, ps, *hyper==1, ml, mss, imbalance);
let mut converter = ToABY::new(fs, s_map, path, lang);
converter.lower();
}

View File

@@ -1,6 +1,6 @@
//! Multi-level Partitioning Implementation
//!
//!
//!
//!
use crate::ir::term::*;
@@ -20,7 +20,7 @@ pub struct CoarsenMap {
num_nodes_per_level: Vec<usize>,
}
impl CoarsenMap{
impl CoarsenMap {
fn new() -> Self {
let mut g = Self {
num_nodes: 0,
@@ -38,7 +38,7 @@ impl CoarsenMap{
/// add a node to current coarsen map
/// Nodes added by this function are not coarsened
fn add_node_to_coarsen_map(&mut self, t: &Term){
fn add_node_to_coarsen_map(&mut self, t: &Term) {
self.num_nodes_per_level[0] += 1;
self.num_nodes += 1;
let node_id = self.num_nodes_per_level[0];
@@ -47,32 +47,33 @@ impl CoarsenMap{
}
/// merge callee's coarsen map to caller's
fn merge_coarsen_map(&mut self, g: &CoarsenMap, sub_map: &TermMap<Term>){
fn merge_coarsen_map(&mut self, g: &CoarsenMap, sub_map: &TermMap<Term>) {
// extend the coarsen level if needed
if g.coarsen_level > self.coarsen_level{
for _ in self.coarsen_level..g.coarsen_level{
if g.coarsen_level > self.coarsen_level {
for _ in self.coarsen_level..g.coarsen_level {
self.num_nodes_per_level.push(self.num_nodes);
}
self.coarsen_level = g.coarsen_level;
}
// merge the map into
for (t, v) in g.coarsen_map.iter(){
// merge the map into
for (t, v) in g.coarsen_map.iter() {
let new_t = sub_map.get(t).unwrap();
let new_v: Vec<usize> = (0..v.len()).map(|i| v[i] + self.num_nodes_per_level[i]).collect();
let new_v: Vec<usize> = (0..v.len())
.map(|i| v[i] + self.num_nodes_per_level[i])
.collect();
self.coarsen_map.insert(new_t.clone(), new_v);
self.num_nodes += 1;
}
// update the number of nodes of each level
for i in 0..g.coarsen_level{
for i in 0..g.coarsen_level {
self.num_nodes_per_level[i] += g.num_nodes_per_level[i];
}
}
}
pub struct MultiLevelPartition{
pub struct MultiLevelPartition {
partitioner: Partitioner,
gwriter: GraphWriter,
fs: Functions,
@@ -85,8 +86,16 @@ pub struct MultiLevelPartition{
hyper_mode: bool,
}
impl MultiLevelPartition{
pub fn new(fs: &Functions, coarsen_threshold: usize, num_coarsen_node: usize, path: &String, time_limit: usize, imbalance: usize, hyper_mode: bool) -> Self{
impl MultiLevelPartition {
pub fn new(
fs: &Functions,
coarsen_threshold: usize,
num_coarsen_node: usize,
path: &String,
time_limit: usize,
imbalance: usize,
hyper_mode: bool,
) -> Self {
let mlp = Self {
partitioner: Partitioner::new(time_limit, imbalance, hyper_mode),
gwriter: GraphWriter::new(hyper_mode),
@@ -103,16 +112,16 @@ impl MultiLevelPartition{
}
/// muti-level coarsening
fn multilevel_coarsen(&mut self, fname: &String) -> bool{
fn multilevel_coarsen(&mut self, fname: &String) -> bool {
let mut coarsened = false;
if !self.comp_history.contains_key(fname){
if !self.comp_history.contains_key(fname) {
let c = self.fs.get_comp(fname).unwrap().clone();
let mut cnt = 0;
for t in c.terms_postorder() {
if let Op::Call(callee, ..) = &t.op {
coarsened |= self.multilevel_coarsen(callee);
cnt += self.func_comp_size.get(callee).unwrap();
} else{
} else {
cnt += 1;
}
}
@@ -121,7 +130,7 @@ impl MultiLevelPartition{
self.comp_history.insert(fname.into(), new_c);
self.graph_history.insert(fname.into(), new_g);
if cnt > self.coarsen_threshold{
if cnt > self.coarsen_threshold {
// perform coarsened
coarsened = true;
self.coarsening_by_partition(fname);
@@ -136,20 +145,23 @@ impl MultiLevelPartition{
let cs = self.comp_history.get(fname).unwrap();
let num_nodes = cm.num_nodes_per_level.get(0).unwrap().clone();
for (t, v) in cm.coarsen_map.iter(){
for (t, v) in cm.coarsen_map.iter() {
t_map.insert(t.clone(), v.get(0).unwrap().clone());
}
let mut gw: GraphWriter = GraphWriter::new(self.hyper_mode);
gw.build_from_tm(cs, &t_map, num_nodes);
let coarsen_graph_path = format!("{}.{}.coarsen{}.graph", self.path, fname, cm.coarsen_level);
let coarsen_graph_path =
format!("{}.{}.coarsen{}.graph", self.path, fname, cm.coarsen_level);
gw.write(&coarsen_graph_path);
let num_parts = num_nodes / self.num_coarsen_node;
let partition = self.partitioner.do_partition(&coarsen_graph_path, &num_parts);
let partition = self
.partitioner
.do_partition(&coarsen_graph_path, &num_parts);
cm.num_nodes_per_level.insert(0, num_parts);
let tmp = cm.coarsen_map.clone();
for t in tmp.keys(){
for t in tmp.keys() {
let mut v = cm.coarsen_map.get_mut(t).unwrap();
let tid = t_map.get(t).unwrap();
v.insert(0, partition.get(tid).unwrap().clone());
@@ -157,7 +169,7 @@ impl MultiLevelPartition{
}
/// Merge the function call inside this function and generate graph
fn merge_and_graph(&mut self, fname: &str) -> (Computation, CoarsenMap){
fn merge_and_graph(&mut self, fname: &str) -> (Computation, CoarsenMap) {
let mut cache = TermMap::<Term>::new();
let mut children_added = TermSet::new();
let mut is_arg = TermSet::new();
@@ -184,26 +196,28 @@ impl MultiLevelPartition{
.collect()
};
if let Op::Call(fn_name, arg_names, _, _) = &top.op {
let callee = self.comp_history.get(fn_name).expect("missing inlined callee");
let callee = self
.comp_history
.get(fn_name)
.expect("missing inlined callee");
let coarsened = self.graph_history.contains_key(fn_name);
let (new_t, sub_map) = link_one_sub(arg_names, get_children(), callee);
if coarsened{
if coarsened {
// coarsened function, take care of mapping
let callee_g = self.graph_history.get(fn_name).unwrap();
cache.insert(top.clone(), new_t.clone());
g.merge_coarsen_map(&callee_g, &sub_map);
g.add_node_to_coarsen_map(&new_t);
} else{
} else {
cache.insert(top.clone(), new_t.clone());
stack.push(new_t.clone());
}
} else {
let new_t = term(top.op.clone(), get_children());
// arg nodes will be handle later by call node
if !is_arg.contains(&top){
if !is_arg.contains(&top) {
g.add_node_to_coarsen_map(&new_t);
}
cache.insert(top.clone(), new_t);
@@ -219,34 +233,49 @@ impl MultiLevelPartition{
(caller, g)
}
fn multilevel_uncoarsen(&mut self, fname: &String, partition: &HashMap<usize, usize>, num_parts: usize) -> TermMap<usize>{
fn multilevel_uncoarsen(
&mut self,
fname: &String,
partition: &HashMap<usize, usize>,
num_parts: usize,
) -> TermMap<usize> {
let cm = self.graph_history.get(fname).unwrap();
let cs = self.comp_history.get(fname).unwrap();
let mut cur_part = partition.clone();
for l in 1..cm.coarsen_level {
let mut gw: GraphWriter = GraphWriter::new(self.hyper_mode);
gw.build(cs, &cm.coarsen_map, l, cm.num_nodes_per_level.get(l).unwrap().clone());
gw.build(
cs,
&cm.coarsen_map,
l,
cm.num_nodes_per_level.get(l).unwrap().clone(),
);
let part_graph_path = format!("{}.{}.part.graph", self.path, fname);
let prev_part_path = format!("{}.{}.refine_{}.part", self.path, fname, l);
gw.write(&part_graph_path);
// coarsen the partition
let mut tmp: HashMap<usize, usize> = HashMap::new();
for (t, v) in cm.coarsen_map.iter(){
for (t, v) in cm.coarsen_map.iter() {
// fix this
let prev_id = *(v.get(l-1).unwrap_or_else(|| v.last().unwrap()));
let prev_id = *(v.get(l - 1).unwrap_or_else(|| v.last().unwrap()));
let cur_id = *(v.get(l).unwrap_or_else(|| v.last().unwrap()));
tmp.insert(cur_id, cur_part.get(&prev_id).unwrap().clone());
}
cur_part = tmp;
write_partition(&prev_part_path, &cur_part);
let placeholder = format!("Path_404");
cur_part = self.partitioner.do_refinement(&part_graph_path, &prev_part_path, &placeholder, &num_parts);
cur_part = self.partitioner.do_refinement(
&part_graph_path,
&prev_part_path,
&placeholder,
&num_parts,
);
}
let mut part_result: TermMap<usize> = TermMap::new();
let finest_l = cm.coarsen_level;
for (t, v) in cm.coarsen_map.iter(){
for (t, v) in cm.coarsen_map.iter() {
let cur_id = v.get(finest_l).unwrap_or_else(|| v.last().unwrap());
part_result.insert(t.clone(), cur_part.get(cur_id).unwrap().clone());
}
@@ -254,8 +283,12 @@ impl MultiLevelPartition{
part_result
}
pub fn run(&mut self, fname: &String, path: &String, num_parts: usize) -> (Computation, TermMap<usize>){
pub fn run(
&mut self,
fname: &String,
path: &String,
num_parts: usize,
) -> (Computation, TermMap<usize>) {
// Coarsening
self.multilevel_coarsen(fname);
@@ -265,7 +298,7 @@ impl MultiLevelPartition{
let cs = self.comp_history.get(fname).unwrap();
let num_nodes = cm.num_nodes_per_level.get(0).unwrap().clone();
for (t, v) in cm.coarsen_map.iter(){
for (t, v) in cm.coarsen_map.iter() {
t_map.insert(t.clone(), v.get(0).unwrap().clone());
}
@@ -276,15 +309,19 @@ impl MultiLevelPartition{
let partition = self.partitioner.do_partition(&part_graph_path, &num_parts);
// Uncoarsening
(self.comp_history.get(fname).unwrap().clone(),self.multilevel_uncoarsen(fname, &partition, num_parts))
(
self.comp_history.get(fname).unwrap().clone(),
self.multilevel_uncoarsen(fname, &partition, num_parts),
)
}
}
/// Copy of link_one function with sub_map for coarsen node mapping
fn link_one_sub(arg_names: &Vec<String>, arg_values: Vec<Term>, callee: &Computation) -> (Term, TermMap<Term>) {
fn link_one_sub(
arg_names: &Vec<String>,
arg_values: Vec<Term>,
callee: &Computation,
) -> (Term, TermMap<Term>) {
let mut sub_map: TermMap<Term> = arg_names
.into_iter()
.zip(arg_values)
@@ -303,4 +340,4 @@ fn link_one_sub(arg_names: &Vec<String>, arg_values: Vec<Term>, callee: &Computa
.collect(),
);
(t, sub_map)
}
}

View File

@@ -1,7 +1,7 @@
//! Graph partitioning backend
// #[cfg(feature = "lp")]
pub mod trans;
pub mod utils;
pub mod mlp;
pub mod tp;
pub mod trans;
pub mod utils;

View File

@@ -1,32 +1,32 @@
//! Multi-level Partitioning Implementation
//!
//!
//!
//!
use crate::ir::term::*;
use crate::ir::opt::link::link_one;
use crate::ir::term::*;
use crate::target::aby::assignment::def_uses::*;
use crate::target::graph::utils::graph_utils::*;
use crate::target::graph::utils::part::*;
use crate::target::aby::assignment::def_uses::*;
use std::collections::HashMap;
pub struct TrivialPartition{
pub struct TrivialPartition {
partitioner: Partitioner,
gwriter: GraphWriter,
fs: Functions,
comp_history: HashMap<String, Computation>,
}
impl TrivialPartition{
pub fn new(fs: &Functions, time_limit: usize, imbalance: usize, hyper_mode: bool) -> Self{
impl TrivialPartition {
pub fn new(fs: &Functions, time_limit: usize, imbalance: usize, hyper_mode: bool) -> Self {
let mut tp = Self {
partitioner: Partitioner::new(time_limit, imbalance, hyper_mode),
gwriter: GraphWriter::new(hyper_mode),
fs: fs.clone(),
comp_history: HashMap::new(),
};
for fname in fs.computations.keys(){
for fname in fs.computations.keys() {
tp.traverse(fname);
}
tp
@@ -34,7 +34,7 @@ impl TrivialPartition{
/// traverse the comp and combine
fn traverse(&mut self, fname: &String) {
if !self.comp_history.contains_key(fname){
if !self.comp_history.contains_key(fname) {
let mut c = self.fs.get_comp(fname).unwrap().clone();
let mut cnt = 0;
for t in c.terms_postorder() {
@@ -87,7 +87,10 @@ impl TrivialPartition{
) -> Option<Term> {
if let Op::Call(fn_name, arg_names, _, _) = &orig.op {
// println!("Rewritten children: {:?}", rewritten_children());
let callee = self.comp_history.get(fn_name).expect("missing inlined callee");
let callee = self
.comp_history
.get(fn_name)
.expect("missing inlined callee");
let term = link_one(arg_names, rewritten_children(), callee);
Some(term)
} else {
@@ -95,24 +98,59 @@ impl TrivialPartition{
}
}
pub fn inline_all(&mut self, fname: &String) -> (Computation, DefUsesGraph){
pub fn inline_all(&mut self, fname: &String) -> (Computation, DefUsesGraph) {
let c = self.comp_history.get(fname).unwrap().clone();
let dug = DefUsesGraph::new(&c);
(c, dug)
}
pub fn run(&mut self, fname: &String, path: &String, num_parts: usize) -> (Computation, DefUsesGraph ,TermMap<usize>){
pub fn run(
&mut self,
fname: &String,
path: &String,
ps: usize,
) -> (Computation, DefUsesGraph, TermMap<usize>, usize) {
let mut part_map = TermMap::new();
self.traverse(fname);
let c = self.comp_history.get(fname).unwrap();
let dug = DefUsesGraph::new(&c);
let t_map = self.gwriter.build_from_dug(&dug);
self.gwriter.write(path);
let partition = self.partitioner.do_partition(path, &num_parts);
for (t, tid) in t_map.iter(){
part_map.insert(t.clone(), *partition.get(tid).unwrap());
let num_parts = dug.good_terms.len() / ps + 1;
println!("LOG: Number of Partitions: {}", num_parts);
if num_parts > 1 {
let t_map = self.gwriter.build_from_dug(&dug);
self.gwriter.write(path);
let partition = self.partitioner.do_partition(path, &num_parts);
for (t, tid) in t_map.iter() {
part_map.insert(t.clone(), *partition.get(tid).unwrap());
}
}
(self.comp_history.get(fname).unwrap().clone(), dug, part_map)
(
self.comp_history.get(fname).unwrap().clone(),
dug,
part_map,
num_parts,
)
}
}
pub fn run_from_dug(
&mut self,
fname: &String,
dug: &DefUsesGraph,
path: &String,
ps: usize,
) -> (TermMap<usize>, usize) {
let mut part_map = TermMap::new();
let c = self.fs.get_comp(fname);
let num_parts = dug.good_terms.len() / ps + 1;
println!("LOG: Number of Partitions: {}", num_parts);
if num_parts > 1 {
let t_map = self.gwriter.build_from_dug(&dug);
self.gwriter.write(path);
let partition = self.partitioner.do_partition(path, &num_parts);
for (t, tid) in t_map.iter() {
part_map.insert(t.clone(), *partition.get(tid).unwrap());
}
}
(part_map, num_parts)
}
}

View File

@@ -1,25 +1,25 @@
use crate::ir::term::*;
use crate::target::graph::tp::*;
use crate::target::graph::mlp::*;
#[cfg(feature = "lp")]
use crate::target::aby::assignment::ilp::assign;
#[cfg(feature = "lp")]
use crate::target::aby::assignment::ilp::smart_global_assign;
use crate::target::graph::mlp::*;
use crate::target::graph::tp::*;
#[cfg(feature = "lp")]
use crate::target::graph::utils::mutation::*;
use crate::target::aby::assignment::ShareType;
use crate::target::aby::assignment::SharingMap;
use crate::target::aby::assignment::def_uses::*;
use std::path::Path;
use std::collections::HashMap;
use std::path::Path;
use std::time::Instant;
use std::fs;
// Get file path to write Chaco graph to
fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
let filename = Path::new(&path.iter().last().unwrap().to_os_string())
@@ -30,7 +30,7 @@ fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
.unwrap();
let name = format!("{}_{}", filename, lang);
let mut path = format!("scripts/aby_tests/tests/{}.graph", name);
if hyper_mode{
if hyper_mode {
path = format!("scripts/aby_tests/tests/{}_hyper.graph", name);
}
if Path::new(&path).exists() {
@@ -39,11 +39,10 @@ fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
path
}
// #[cfg(feature = "lp")]
// /// inline all function into main
// pub fn partition_with_mut(
// fs: &Functions,
// fs: &Functions,
// cm: &str,
// path: &Path,
// lang: &str,
@@ -58,8 +57,8 @@ fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
// 20000000,
// 10,
// path,
// 0,
// imbalance.clone(),
// 0,
// imbalance.clone(),
// false
// );
// let main = "main";
@@ -92,27 +91,27 @@ fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
// s_map.insert(main.to_string(), assignment);
// let mut fs = Functions::new();
// fs.insert(main.to_string(), c);
// (fs, s_map)
// (fs, s_map)
// }
#[cfg(feature = "lp")]
/// inline all function into main
pub fn partition_with_mut(
fs: &Functions,
fs: &Functions,
cm: &str,
path: &Path,
lang: &str,
num_parts: &usize,
ps: &usize,
hyper_mode: bool,
ml: &usize,
mss: &usize,
imbalance: &usize,
) -> (Functions, HashMap<String, SharingMap>){
) -> (Functions, HashMap<String, SharingMap>) {
let mut now = Instant::now();
let mut tp = TrivialPartition::new(fs, 0, imbalance.clone(), hyper_mode);
let main = "main";
let graph_path = get_graph_path(path, lang, hyper_mode);
let (c, d, partition) = tp.run(&main.to_string(), &graph_path, *num_parts);
let (c, d, partition, num_parts) = tp.run(&main.to_string(), &graph_path, *ps);
println!("Time: Partition: {:?}", now.elapsed());
now = Instant::now();
@@ -121,11 +120,11 @@ pub fn partition_with_mut(
let mut tmp_css: HashMap<usize, ComputationSubgraph> = HashMap::new();
let mut css: HashMap<usize, ComputationSubgraph> = HashMap::new();
for part_id in 0..*num_parts{
for part_id in 0..num_parts {
tmp_css.insert(part_id, ComputationSubgraph::new());
}
for (t, part_id) in partition.iter(){
for (t, part_id) in partition.iter() {
if let Some(subgraph) = tmp_css.get_mut(&part_id) {
subgraph.insert_node(t);
} else {
@@ -133,7 +132,7 @@ pub fn partition_with_mut(
}
}
for (part_id, mut cs) in tmp_css.into_iter(){
for (part_id, mut cs) in tmp_css.into_iter() {
cs.insert_edges();
css.insert(part_id, cs.clone());
}
@@ -146,72 +145,158 @@ pub fn partition_with_mut(
s_map.insert(main.to_string(), assignment);
let mut fs = Functions::new();
fs.insert(main.to_string(), c);
(fs, s_map)
(fs, s_map)
}
#[cfg(feature = "lp")]
/// inline all function into main
pub fn partition_with_mut_smart(
fs: &Functions,
pub fn css_partition_with_mut_smart(
fs: &Functions,
dugs: &HashMap<String, DefUsesGraph>,
cm: &str,
path: &Path,
lang: &str,
num_parts: &usize,
ps: &usize,
hyper_mode: bool,
ml: &usize,
mss: &usize,
imbalance: &usize,
) -> (Functions, HashMap<String, SharingMap>){
) -> HashMap<String, SharingMap> {
let mut now = Instant::now();
let mut s_map: HashMap<String, SharingMap> = HashMap::new();
for (fname, comp) in fs.computations.iter() {
println!("Partitioning: {}", fname);
let mut tp = TrivialPartition::new(fs, 0, imbalance.clone(), hyper_mode);
let graph_path = get_graph_path(path, lang, hyper_mode);
let d = dugs.get(fname).unwrap();
let (partition, num_parts) = tp.run_from_dug(fname, d, &graph_path, *ps);
println!("Time: Partition: {:?}", now.elapsed());
let mut assignment: SharingMap;
if num_parts == 1 {
// No need to partition
now = Instant::now();
assignment = smart_global_assign(&d.good_terms, &d.def_use, cm);
println!("Time: ILP : {:?}", now.elapsed());
} else {
// Construct DefUsesSubGraph
now = Instant::now();
let mut tmp_dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
let mut dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
for part_id in 0..num_parts {
tmp_dusg.insert(part_id, DefUsesSubGraph::new());
}
for t in d.good_terms.iter() {
let part_id = partition.get(t).unwrap();
if let Some(du) = tmp_dusg.get_mut(&part_id) {
du.insert_node(t);
} else {
panic!("Subgraph not found for index: {}", num_parts);
}
}
for (part_id, mut du) in tmp_dusg.into_iter() {
du.insert_edges(&d);
dusg.insert(part_id, du.clone());
}
println!("Time: To Subgraph: {:?}", now.elapsed());
now = Instant::now();
assignment = get_share_map_with_mutation_smart(&d, cm, &dusg, &partition, ml, mss);
println!("Time: ILP : {:?}", now.elapsed());
}
// HACK: Assign sharetype to out gate
for out in comp.outputs.iter() {
if !assignment.contains_key(&out) {
let ref_t = d.term_to_terms.get(&out).unwrap().get(0).unwrap().clone().0;
println!("ref_t: op {} ", ref_t.op);
// Parent is a call term
let s_type = assignment.get(&ref_t).unwrap_or(&ShareType::Arithmetic).clone();
assignment.insert(out.clone(), s_type);
}
}
s_map.insert(fname.clone(), assignment);
}
s_map
}
#[cfg(feature = "lp")]
/// inline all function into main
pub fn partition_with_mut_smart(
fs: &Functions,
cm: &str,
path: &Path,
lang: &str,
ps: &usize,
hyper_mode: bool,
ml: &usize,
mss: &usize,
imbalance: &usize,
) -> (Functions, HashMap<String, SharingMap>) {
let mut now = Instant::now();
let mut tp = TrivialPartition::new(fs, 0, imbalance.clone(), hyper_mode);
let main = "main";
let graph_path = get_graph_path(path, lang, hyper_mode);
let (c, d, partition) = tp.run(&main.to_string(), &graph_path, *num_parts);
let (c, d, partition, num_parts) = tp.run(&main.to_string(), &graph_path, *ps);
println!("Time: Partition: {:?}", now.elapsed());
now = Instant::now();
// Construct DefUsesSubGraph
let mut tmp_dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
let mut dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
let assignment: SharingMap;
if num_parts == 1 {
// No need to partition
now = Instant::now();
assignment = smart_global_assign(&d.good_terms, &d.def_use, cm);
println!("Time: ILP : {:?}", now.elapsed());
} else {
// Construct DefUsesSubGraph
now = Instant::now();
let mut tmp_dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
let mut dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
for part_id in 0..*num_parts{
tmp_dusg.insert(part_id, DefUsesSubGraph::new());
}
for t in d.good_terms.iter(){
let part_id = partition.get(t).unwrap();
if let Some(du) = tmp_dusg.get_mut(&part_id) {
du.insert_node(t);
} else {
panic!("Subgraph not found for index: {}", num_parts);
for part_id in 0..num_parts {
tmp_dusg.insert(part_id, DefUsesSubGraph::new());
}
for t in d.good_terms.iter() {
println!("op: {}", t.op);
let part_id = partition.get(t).unwrap();
if let Some(du) = tmp_dusg.get_mut(&part_id) {
du.insert_node(t);
} else {
panic!("Subgraph not found for index: {}", num_parts);
}
}
for (part_id, mut du) in tmp_dusg.into_iter() {
du.insert_edges(&d);
dusg.insert(part_id, du.clone());
}
println!("Time: To Subgraph: {:?}", now.elapsed());
now = Instant::now();
assignment = get_share_map_with_mutation_smart(&d, cm, &dusg, &partition, ml, mss);
println!("Time: ILP : {:?}", now.elapsed());
}
for (part_id, mut du) in tmp_dusg.into_iter(){
du.insert_edges(&d);
dusg.insert(part_id, du.clone());
}
println!("Time: To Subgraph: {:?}", now.elapsed());
now = Instant::now();
let assignment = get_share_map_with_mutation_smart(&d, cm, &dusg, &partition, ml, mss);
println!("Time: ILP : {:?}", now.elapsed());
let mut s_map: HashMap<String, SharingMap> = HashMap::new();
s_map.insert(main.to_string(), assignment);
let mut fs = Functions::new();
fs.insert(main.to_string(), c);
(fs, s_map)
(fs, s_map)
}
#[cfg(feature = "lp")]
/// inline all function into main
pub fn inline_all_and_assign_glp(
fs: &Functions,
cm: &str
) -> (Functions, HashMap<String, SharingMap>){
pub fn inline_all_and_assign_glp(
fs: &Functions,
cm: &str,
) -> (Functions, HashMap<String, SharingMap>) {
let mut tp = TrivialPartition::new(fs, 0, 0, false);
let main = "main";
let (c, dug) = tp.inline_all(&main.to_string());
@@ -232,16 +317,19 @@ pub fn inline_all_and_assign_glp(
#[cfg(feature = "lp")]
/// inline all function into main
pub fn inline_all_and_assign_smart_glp(
fs: &Functions,
cm: &str
) -> (Functions, HashMap<String, SharingMap>){
pub fn inline_all_and_assign_smart_glp(
fs: &Functions,
cm: &str,
) -> (Functions, HashMap<String, SharingMap>) {
let mut now = Instant::now();
let mut tp = TrivialPartition::new(fs, 0, 0, false);
let main = "main";
let (c, dug) = tp.inline_all(&main.to_string());
println!("Time: Inline and construction def uses: {:?}", now.elapsed());
println!(
"Time: Inline and construction def uses: {:?}",
now.elapsed()
);
now = Instant::now();
let assignment = smart_global_assign(&dug.good_terms, &dug.def_use, cm);
@@ -252,4 +340,4 @@ pub fn inline_all_and_assign_smart_glp(
let mut fs = Functions::new();
fs.insert(main.to_string(), c);
(fs, s_map)
}
}

View File

@@ -2,8 +2,8 @@
//! This input format can be found in [Jostle User Guide](https://chriswalshaw.co.uk/jostle/jostle-exe.pdf)
//!
//!
//!
//!
//!
//!
use crate::ir::term::*;
use crate::target::aby::assignment::def_uses::*;
@@ -28,7 +28,6 @@ struct Edges<T> {
vec: Vec<T>,
}
impl<T: PartialEq> Edges<T> {
fn add(&mut self, item: T) -> bool {
if !self.vec.contains(&item) {
@@ -39,13 +38,13 @@ impl<T: PartialEq> Edges<T> {
}
}
fn coarse_map_get(cm: &HashMap<Term, Vec<usize>>, t: &Term ,level: usize) -> usize{
fn coarse_map_get(cm: &HashMap<Term, Vec<usize>>, t: &Term, level: usize) -> usize {
let v = cm.get(t).unwrap();
*(v.get(level).unwrap_or_else(|| v.last().unwrap()))
}
///
pub struct GraphWriter{
///
pub struct GraphWriter {
num_nodes: usize,
num_edges: usize,
num_hyper_edges: usize,
@@ -58,7 +57,7 @@ pub struct GraphWriter{
}
impl GraphWriter {
pub fn new(hyper_mode: bool) -> Self{
pub fn new(hyper_mode: bool) -> Self {
let gw = Self {
num_nodes: 0,
num_edges: 0,
@@ -73,7 +72,13 @@ impl GraphWriter {
gw
}
pub fn build(&mut self, cs: &Computation, coarsen_map: &HashMap<Term, Vec<usize>>, level: usize, num_nodes: usize){
pub fn build(
&mut self,
cs: &Computation,
coarsen_map: &HashMap<Term, Vec<usize>>,
level: usize,
num_nodes: usize,
) {
self.num_nodes = num_nodes;
for t in cs.terms_postorder() {
match &t.op {
@@ -91,8 +96,8 @@ impl GraphWriter {
let t_id = coarse_map_get(coarsen_map, &t, level);
for cs in t.cs.iter() {
let cs_id = coarse_map_get(coarsen_map, &cs, level);
if cs_id != t_id{
if self.hyper_mode{
if cs_id != t_id {
if self.hyper_mode {
self.insert_hyper_edge(&cs_id, &t_id);
} else {
self.insert_edge(&cs_id, &t_id);
@@ -106,7 +111,7 @@ impl GraphWriter {
}
}
pub fn build_from_tm(&mut self, cs: &Computation, tm: &TermMap<usize>, num_nodes: usize){
pub fn build_from_tm(&mut self, cs: &Computation, tm: &TermMap<usize>, num_nodes: usize) {
self.num_nodes = num_nodes;
for t in cs.terms_postorder() {
match &t.op {
@@ -124,8 +129,8 @@ impl GraphWriter {
let t_id = tm.get(&t).unwrap();
for cs in t.cs.iter() {
let cs_id = tm.get(&cs).unwrap();
if cs_id != t_id{
if self.hyper_mode{
if cs_id != t_id {
if self.hyper_mode {
self.insert_hyper_edge(&cs_id, &t_id);
} else {
self.insert_edge(&cs_id, &t_id);
@@ -139,17 +144,17 @@ impl GraphWriter {
}
}
fn get_tid_or_assign(&mut self, t: &Term) -> usize{
if self.term_to_id.contains_key(t){
fn get_tid_or_assign(&mut self, t: &Term) -> usize {
if self.term_to_id.contains_key(t) {
return *(self.term_to_id.get(t).unwrap());
} else{
} else {
self.num_nodes += 1;
self.term_to_id.insert(t.clone(), self.num_nodes);
return self.num_nodes;
}
}
pub fn build_from_cs(&mut self, cs: &Computation) -> HashMap<Term, usize>{
pub fn build_from_cs(&mut self, cs: &Computation) -> HashMap<Term, usize> {
for t in cs.terms_postorder() {
match &t.op {
Op::Var(_, _) | Op::Const(_) => {
@@ -169,8 +174,8 @@ impl GraphWriter {
let t_id = self.get_tid_or_assign(&t);
for cs in t.cs.iter() {
let cs_id = self.get_tid_or_assign(&cs);
if cs_id != t_id{
if self.hyper_mode{
if cs_id != t_id {
if self.hyper_mode {
self.insert_hyper_edge(&cs_id, &t_id);
} else {
self.insert_edge(&cs_id, &t_id);
@@ -185,7 +190,7 @@ impl GraphWriter {
self.term_to_id.clone()
}
pub fn build_from_dug(&mut self, dug: &DefUsesGraph) -> HashMap<Term, usize>{
pub fn build_from_dug(&mut self, dug: &DefUsesGraph) -> HashMap<Term, usize> {
for t in dug.good_terms.iter() {
match &t.op {
Op::Var(_, _) | Op::Const(_) => {
@@ -205,8 +210,8 @@ impl GraphWriter {
let t_id = self.get_tid_or_assign(&t);
for def in dug.use_defs.get(t).unwrap().iter() {
let def_id = self.get_tid_or_assign(&def);
if def_id != t_id{
if self.hyper_mode{
if def_id != t_id {
if self.hyper_mode {
self.insert_hyper_edge(&def_id, &t_id);
} else {
self.insert_edge(&def_id, &t_id);
@@ -221,20 +226,18 @@ impl GraphWriter {
self.term_to_id.clone()
}
pub fn write(&mut self, path: &String){
if self.hyper_mode{
pub fn write(&mut self, path: &String) {
if self.hyper_mode {
self.write_hyper_graph(path);
} else{
} else {
self.write_graph(path);
}
}
// Insert edge into PartitionGraph
fn insert_edge(&mut self, from: &usize, to: &usize) {
if !self.edges.contains_key(&from) {
self.edges
.insert(from.clone(), Edges { vec: Vec::new() });
self.edges.insert(from.clone(), Edges { vec: Vec::new() });
}
let added = self.edges.get_mut(&from).unwrap().add(*to);
if added {
@@ -244,22 +247,26 @@ impl GraphWriter {
// Insert hyper edge into PartitionGraph
fn insert_hyper_edge(&mut self, from: &usize, to: &usize) {
// Assume each node will only have one output
// Assume each node will only have one output
// TODO: fix this?
if !self.node_to_hyper_edge.contains_key(from) {
self.num_hyper_edges += 1;
let new_hyper_edge = HyperEdge {idx: self.num_hyper_edges};
let new_hyper_edge = HyperEdge {
idx: self.num_hyper_edges,
};
self.node_to_hyper_edge
.insert(from.clone(), new_hyper_edge.clone());
self.hyper_edges
.insert(new_hyper_edge.clone(), Edges { vec: Vec::new() });
// Add from node itself
self.hyper_edges.get_mut(&new_hyper_edge).unwrap().add(*from);
self.hyper_edges
.get_mut(&new_hyper_edge)
.unwrap()
.add(*from);
self.hyper_edges.get_mut(&new_hyper_edge).unwrap().add(*to);
} else{
} else {
let hyper_edge = self.node_to_hyper_edge.get(&from).unwrap();
self.hyper_edges.get_mut(&hyper_edge).unwrap().add(*to);
}
@@ -283,7 +290,7 @@ impl GraphWriter {
// for Nodes 1..N, write their neighbors
for i in 0..(self.num_nodes) {
let id = i+1;
let id = i + 1;
match self.edges.get(&id) {
Some(edges) => {
@@ -352,8 +359,7 @@ impl GraphWriter {
}
}
pub fn write_partition(path: &String, partition: &HashMap<usize, usize>){
pub fn write_partition(path: &String, partition: &HashMap<usize, usize>) {
if !Path::new(path).exists() {
println!("partition_path: {}", path);
fs::File::create(path).expect("Failed to create hyper graph file");
@@ -366,9 +372,8 @@ pub fn write_partition(path: &String, partition: &HashMap<usize, usize>){
// for Nodes 1..N, write their neighbors
for i in 0..partition.keys().len() {
let line = format!("{}\n", partition.get(&i).unwrap());
file.write_all(line.as_bytes())
.expect("Failed to write to graph file");
}
}
}

View File

@@ -1,6 +1,4 @@
pub mod part;
pub mod graph_utils;
#[cfg(feature = "lp")]
pub mod mutation;
pub mod graph_utils;
pub mod part;

View File

@@ -17,7 +17,6 @@ use crate::target::aby::assignment::def_uses::*;
use std::thread;
fn get_outer_n(cs: &ComputationSubgraph, n: usize) -> ComputationSubgraph {
let mut last_cs = cs.clone();
for _ in 0..n {
@@ -38,7 +37,7 @@ fn get_outer_n(cs: &ComputationSubgraph, n: usize) -> ComputationSubgraph {
/// Mutations with multi threading
fn mutate_partitions_mp_step(
cs: &HashMap<usize, ComputationSubgraph>,
cs: &HashMap<usize, ComputationSubgraph>,
cm: &str,
outer_level: usize,
step: usize,
@@ -66,7 +65,9 @@ fn mutate_partitions_mp_step(
let j = j.clone();
let c = c.clone();
let c_ref = c_ref.clone();
children.push(thread::spawn(move || (i, j, assign_mut(&c, &costm, &c_ref))));
children.push(thread::spawn(move || {
(i, j, assign_mut(&c, &costm, &c_ref))
}));
}
for child in children {
@@ -79,7 +80,7 @@ fn mutate_partitions_mp_step(
/// Mutations with multi threading
fn mutate_partitions_mp_step_smart(
dug: &DefUsesGraph,
dusg: &HashMap<usize, DefUsesSubGraph>,
dusg: &HashMap<usize, DefUsesSubGraph>,
cm: &str,
outer_level: usize,
step: usize,
@@ -87,8 +88,7 @@ fn mutate_partitions_mp_step_smart(
// TODO: merge and stop
let mut mut_smaps: HashMap<usize, HashMap<usize, SharingMap>> = HashMap::new();
let mut mut_sets: HashMap<(usize, usize), (DefUsesSubGraph, DefUsesSubGraph)> =
HashMap::new();
let mut mut_sets: HashMap<(usize, usize), (DefUsesSubGraph, DefUsesSubGraph)> = HashMap::new();
for (i, du) in dusg.iter() {
mut_smaps.insert(*i, HashMap::new());
@@ -107,7 +107,9 @@ fn mutate_partitions_mp_step_smart(
let j = j.clone();
let du = du.clone();
let du_ref = du_ref.clone();
children.push(thread::spawn(move || (i, j, assign_mut_smart(&du, &costm, &du_ref))));
children.push(thread::spawn(move || {
(i, j, assign_mut_smart(&du, &costm, &du_ref))
}));
}
for child in children {
@@ -117,7 +119,11 @@ fn mutate_partitions_mp_step_smart(
mut_smaps
}
fn get_global_assignments(cs: &Computation, term_to_part: &TermMap<usize>, local_smaps: &HashMap<usize, SharingMap>) -> SharingMap {
fn get_global_assignments(
cs: &Computation,
term_to_part: &TermMap<usize>,
local_smaps: &HashMap<usize, SharingMap>,
) -> SharingMap {
let mut global_smap: SharingMap = SharingMap::new();
let Computation { outputs, .. } = cs.clone();
@@ -138,7 +144,11 @@ fn get_global_assignments(cs: &Computation, term_to_part: &TermMap<usize>, local
global_smap
}
fn get_global_assignments_smart(dug: &DefUsesGraph, term_to_part: &TermMap<usize>, local_smaps: &HashMap<usize, SharingMap>) -> SharingMap {
fn get_global_assignments_smart(
dug: &DefUsesGraph,
term_to_part: &TermMap<usize>,
local_smaps: &HashMap<usize, SharingMap>,
) -> SharingMap {
let mut global_smap: SharingMap = SharingMap::new();
for t in dug.good_terms.iter() {
// get term partition assignment
@@ -153,15 +163,35 @@ fn get_global_assignments_smart(dug: &DefUsesGraph, term_to_part: &TermMap<usize
global_smap
}
pub fn get_share_map_with_mutation(cs: &Computation, cm: &str, partitions: &HashMap<usize, ComputationSubgraph>, term_to_part: &TermMap<usize>, mut_level: &usize, mut_step_size: &usize) -> SharingMap{
let mutation_smaps = mutate_partitions_mp_step(partitions, cm, mut_level.clone(), mut_step_size.clone());
pub fn get_share_map_with_mutation(
cs: &Computation,
cm: &str,
partitions: &HashMap<usize, ComputationSubgraph>,
term_to_part: &TermMap<usize>,
mut_level: &usize,
mut_step_size: &usize,
) -> SharingMap {
let mutation_smaps =
mutate_partitions_mp_step(partitions, cm, mut_level.clone(), mut_step_size.clone());
let selected_mut_maps = comb_selection(&mutation_smaps, &partitions, cm);
get_global_assignments(cs, term_to_part, &selected_mut_maps)
}
pub fn get_share_map_with_mutation_smart(dug: &DefUsesGraph, cm: &str, partitions: &HashMap<usize, DefUsesSubGraph>, term_to_part: &TermMap<usize>, mut_level: &usize, mut_step_size: &usize) -> SharingMap{
let mutation_smaps = mutate_partitions_mp_step_smart(dug, partitions, cm, mut_level.clone(), mut_step_size.clone());
pub fn get_share_map_with_mutation_smart(
dug: &DefUsesGraph,
cm: &str,
partitions: &HashMap<usize, DefUsesSubGraph>,
term_to_part: &TermMap<usize>,
mut_level: &usize,
mut_step_size: &usize,
) -> SharingMap {
let mutation_smaps = mutate_partitions_mp_step_smart(
dug,
partitions,
cm,
mut_level.clone(),
mut_step_size.clone(),
);
let selected_mut_maps = comb_selection_smart(dug, &mutation_smaps, &partitions, cm);
get_global_assignments_smart(dug, term_to_part, &selected_mut_maps)
}

View File

@@ -8,7 +8,6 @@ use std::process::{Command, Stdio};
use std::time::Instant;
pub struct Partitioner {
time_limit: usize,
imbalance: usize,
imbalance_f32: f32,
@@ -26,22 +25,38 @@ impl Partitioner {
graph
}
pub fn do_refinement(&self, graph_path: &String, input_part_path: &String, output_part_path: &String, num_parts: &usize) -> HashMap<usize, usize>{
if self.hyper_mode{
let part_path = format!("{}.part{}.epsilon{}.seed-1.KaHyPar", graph_path, num_parts, self.imbalance_f32.to_string());
pub fn do_refinement(
&self,
graph_path: &String,
input_part_path: &String,
output_part_path: &String,
num_parts: &usize,
) -> HashMap<usize, usize> {
if self.hyper_mode {
let part_path = format!(
"{}.part{}.epsilon{}.seed-1.KaHyPar",
graph_path,
num_parts,
self.imbalance_f32.to_string()
);
self.call_hyper_graph_refiner(graph_path, input_part_path, num_parts);
self.parse_partition(&part_path)
} else{
} else {
unimplemented!("Refinement using KaHIP not implemented. ");
}
}
pub fn do_partition(&self, graph_path: &String, num_parts: &usize) -> HashMap<usize, usize>{
if self.hyper_mode{
let part_path = format!("{}.part{}.epsilon{}.seed-1.KaHyPar", graph_path, num_parts, self.imbalance_f32.to_string());
pub fn do_partition(&self, graph_path: &String, num_parts: &usize) -> HashMap<usize, usize> {
if self.hyper_mode {
let part_path = format!(
"{}.part{}.epsilon{}.seed-1.KaHyPar",
graph_path,
num_parts,
self.imbalance_f32.to_string()
);
self.call_hyper_graph_partitioner(graph_path, num_parts);
self.parse_partition(&part_path)
} else{
} else {
self.check_graph(graph_path);
let part_path = format!("{}.part", graph_path);
self.call_graph_partitioner(graph_path, &part_path, num_parts);
@@ -58,13 +73,13 @@ impl Partitioner {
Ok(io::BufReader::new(file).lines())
}
fn parse_partition(&self, part_path: &String) -> HashMap<usize, usize>{
fn parse_partition(&self, part_path: &String) -> HashMap<usize, usize> {
let mut part_map = HashMap::new();
if let Ok(lines) = self.read_lines(part_path) {
for line in lines.into_iter().enumerate() {
if let (i, Ok(part)) = line {
let part_num = part.parse::<usize>().unwrap();
part_map.insert(i+1, part_num);
part_map.insert(i + 1, part_num);
}
}
}
@@ -94,7 +109,7 @@ impl Partitioner {
}
// Call graph partitioning algorithm on input graph
fn call_graph_partitioner(&self, graph_path: &String, part_path: &String, num_parts: &usize) {
fn call_graph_partitioner(&self, graph_path: &String, part_path: &String, num_parts: &usize) {
//TODO: fix path
let output = Command::new("../KaHIP/deploy/kaffpa")
.arg(graph_path)
@@ -115,7 +130,12 @@ impl Partitioner {
}
// Call hyper graph partitioning algorithm on input hyper graph
fn call_hyper_graph_refiner(&self, graph_path: &String, input_path: &String, num_parts: &usize) {
fn call_hyper_graph_refiner(
&self,
graph_path: &String,
input_path: &String,
num_parts: &usize,
) {
//TODO: fix path
let input_part_arg = format!("--part-file={}", input_path);
let output = Command::new("../kahypar/build/kahypar/application/KaHyPar")
@@ -150,5 +170,4 @@ impl Partitioner {
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(stdout.contains("The graph format seems correct."));
}
}