From da7aa75cfbb06534deafc297684c71865afaa265 Mon Sep 17 00:00:00 2001 From: Edward Chen Date: Tue, 21 Jun 2022 18:24:02 -0400 Subject: [PATCH] added struct semantics --- examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c | 267 +++++++++--------- scripts/build_mpc_c_test.zsh | 2 +- src/ir/opt/mem/obliv.rs | 3 +- src/ir/opt/visit.rs | 4 +- src/target/aby/assignment/mod.rs | 4 +- src/target/aby/trans.rs | 81 +++++- 6 files changed, 209 insertions(+), 152 deletions(-) diff --git a/examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c index 648dddf8..a000925c 100644 --- a/examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c +++ b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c @@ -33,154 +33,157 @@ int main(__attribute__((private(0))) int a[20], __attribute__((private(1))) int cluster[i_2 * D + 1] = data[((i_2 + 3) % LEN) * D + 1]; } - // for (int i_3 = 0; i_3 < PRECISION; i_3++) - // { - // int new_cluster[D * NC]; + for (int i_3 = 0; i_3 < PRECISION; i_3++) + { + int new_cluster[D * NC]; - // // ======================= iteration_unrolled_outer - // int count[NC]; + // ======================= iteration_unrolled_outer + int count[NC]; - // // Set Outer result - // for (int i_4 = 0; i_4 < NC; i_4++) - // { - // new_cluste> + // Set Outer result + for (int i_4 = 0; i_4 < NC; i_4++) + { + new_cluster[i_4 * D] = 0; + new_cluster[i_4 * D + 1] = 0; + count[i_4] = 0; + } - // int loop_clusterD1[NC * LEN_OUTER]; - // int loop_clusterD2[NC * LEN_OUTER]; - // int loop_count[NC * LEN_OUTER]; + int loop_clusterD1[NC * LEN_OUTER]; + int loop_clusterD2[NC * LEN_OUTER]; + int loop_count[NC * LEN_OUTER]; - // // Compute decomposition - // for (int i_5 = 0; i_5 < LEN_OUTER; i_5++) - // { - // // Copy data, fasthack for scalability - // int data_offset = i_5 * LEN_INNER * D; - // int data_inner[LEN_INNER * D]; + // Compute decomposition + for (int i_5 = 0; i_5 < LEN_OUTER; i_5++) + { + // Copy data, fasthack for scalability + int data_offset = i_5 * LEN_INNER * D; + int data_inner[LEN_INNER * D]; - // // memcpy(data_inner, data+data_offset, LEN_INNER*D*sizeof(coord_t)); - // for (int i_6 = 0; i_6 < LEN_INNER * D; i_6++) - // { - // data_inner[i_6] = data[i_6 + data_offset]; - // } + // memcpy(data_inner, data+data_offset, LEN_INNER*D*sizeof(coord_t)); + for (int i_6 = 0; i_6 < LEN_INNER * D; i_6++) + { + data_inner[i_6] = data[i_6 + data_offset]; + } - // int cluster_inner[NC * D]; - // int count_inner[NC]; + int cluster_inner[NC * D]; + int count_inner[NC]; - // // ======================= iteration_unrolled_inner_depth(data_inner, cluster, cluster_inner, count_inner, LEN_INNER, NC); - // int dist[NC]; - // int pos[NC]; - // int bestMap_inner[LEN_INNER]; + // ======================= iteration_unrolled_inner_depth(data_inner, cluster, cluster_inner, count_inner, LEN_INNER, NC); + int dist[NC]; + int pos[NC]; + int bestMap_inner[LEN_INNER]; - // for (int i_7 = 0; i_7 < NC; i_7++) - // { - // cluster_inner[i_7 * D] = 0; - // cluster_inner[i_7 * D + 1] = 0; - // count_inner[i_7] = 0; - // } + for (int i_7 = 0; i_7 < NC; i_7++) + { + cluster_inner[i_7 * D] = 0; + cluster_inner[i_7 * D + 1] = 0; + count_inner[i_7] = 0; + } - // // Compute nearest clusters for Data item i - // for (int i_8 = 0; i_8 < LEN_INNER; i_8++) - // { - // int dx = data_inner[i_8 * D]; - // int dy = data_inner[i_8 * D + 1]; + // Compute nearest clusters for Data item i + for (int i_8 = 0; i_8 < LEN_INNER; i_8++) + { + int dx = data_inner[i_8 * D]; + int dy = data_inner[i_8 * D + 1]; - // for (int i_9 = 0; i_9 < NC; i_9++) - // { - // pos[i_9] = i_9; - // int x1 = cluster[D * i_9]; - // int y1 = cluster[D * i_9 + 1]; - // int x2 = dx; - // int y2 = dy; - // dist[i_9] = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2); - // } - // // hardcoded NC = 5; - // // stride = 1 - // // stride = 2 - // // stride = 4 - // int stride = 1; - // for (int i_10 = 0; i_10 < NC - stride; i_10 += 2) - // { - // if (dist[i_10 + stride] < dist[i_10]) - // { - // dist[i_10] = dist[i_10 + stride]; - // pos[i_10] = pos[i_10 + stride]; - // } - // } - // stride = 2; - // for (int i_11 = 0; i_11 < NC - stride; i_11 += 4) - // { - // if (dist[i_11 + stride] < dist[i_11]) - // { - // dist[i_11] = dist[i_11 + stride]; - // pos[i_11] = pos[i_11 + stride]; - // } - // } - // stride = 4; - // for (int i_12 = 0; i_12 < NC - stride; i_12 += 8) - // { - // if (dist[i_12 + stride] < dist[i_12]) - // { - // dist[i_12] = dist[i_12 + stride]; - // pos[i_12] = pos[i_12 + stride]; - // } - // } - // bestMap_inner[i_8] = pos[0]; - // int cc = bestMap_inner[i_8]; - // cluster_inner[cc * D] += data_inner[i_8 * D]; - // cluster_inner[cc * D + 1] += data_inner[i_8 * D + 1]; - // count_inner[cc] += 1; - // } - // // // ======================= iteration_unrolled_inner_depth(data_inner, cluster, cluster_inner, count_inner, LEN_INNER, NC); + for (int i_9 = 0; i_9 < NC; i_9++) + { + pos[i_9] = i_9; + int x1 = cluster[D * i_9]; + int y1 = cluster[D * i_9 + 1]; + int x2 = dx; + int y2 = dy; + dist[i_9] = (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2); + } + // hardcoded NC = 5; + // stride = 1 + // stride = 2 + // stride = 4 + int stride = 1; + for (int i_10 = 0; i_10 < NC - stride; i_10 += 2) + { + if (dist[i_10 + stride] < dist[i_10]) + { + dist[i_10] = dist[i_10 + stride]; + pos[i_10] = pos[i_10 + stride]; + } + } + stride = 2; + for (int i_11 = 0; i_11 < NC - stride; i_11 += 4) + { + if (dist[i_11 + stride] < dist[i_11]) + { + dist[i_11] = dist[i_11 + stride]; + pos[i_11] = pos[i_11 + stride]; + } + } + stride = 4; + for (int i_12 = 0; i_12 < NC - stride; i_12 += 8) + { + if (dist[i_12 + stride] < dist[i_12]) + { + dist[i_12] = dist[i_12 + stride]; + pos[i_12] = pos[i_12 + stride]; + } + } + bestMap_inner[i_8] = pos[0]; + int cc = bestMap_inner[i_8]; + cluster_inner[cc * D] += data_inner[i_8 * D]; + cluster_inner[cc * D + 1] += data_inner[i_8 * D + 1]; + count_inner[cc] += 1; + } + // ======================= iteration_unrolled_inner_depth(data_inner, cluster, cluster_inner, count_inner, LEN_INNER, NC); - // for (int i_13 = 0; i_13 < NC; i_13++) - // { - // loop_clusterD1[i_13 * LEN_OUTER + i_5] = cluster_inner[i_13 * D]; - // loop_clusterD2[i_13 * LEN_OUTER + i_5] = cluster_inner[i_13 * D + 1]; - // loop_count[i_13 * LEN_OUTER + i_5] = count_inner[i_13]; - // } - // } + for (int i_13 = 0; i_13 < NC; i_13++) + { + loop_clusterD1[i_13 * LEN_OUTER + i_5] = cluster_inner[i_13 * D]; + loop_clusterD2[i_13 * LEN_OUTER + i_5] = cluster_inner[i_13 * D + 1]; + loop_count[i_13 * LEN_OUTER + i_5] = count_inner[i_13]; + } + } - // for (int i_14 = 0; i_14 < NC; i_14++) - // { - // new_cluster[i_14 * D] = - // loop_clusterD1[i_14 * LEN_OUTER + 0] + loop_clusterD1[i_14 * LEN_OUTER + 1] + - // loop_clusterD1[i_14 * LEN_OUTER + 2] + loop_clusterD1[i_14 * LEN_OUTER + 3] + - // loop_clusterD1[i_14 * LEN_OUTER + 4] + loop_clusterD1[i_14 * LEN_OUTER + 5] + - // loop_clusterD1[i_14 * LEN_OUTER + 6] + loop_clusterD1[i_14 * LEN_OUTER + 7] + - // loop_clusterD1[i_14 * LEN_OUTER + 8] + loop_clusterD1[i_14 * LEN_OUTER + 9]; + for (int i_14 = 0; i_14 < NC; i_14++) + { + new_cluster[i_14 * D] = + loop_clusterD1[i_14 * LEN_OUTER + 0] + loop_clusterD1[i_14 * LEN_OUTER + 1] + + loop_clusterD1[i_14 * LEN_OUTER + 2] + loop_clusterD1[i_14 * LEN_OUTER + 3] + + loop_clusterD1[i_14 * LEN_OUTER + 4] + loop_clusterD1[i_14 * LEN_OUTER + 5] + + loop_clusterD1[i_14 * LEN_OUTER + 6] + loop_clusterD1[i_14 * LEN_OUTER + 7] + + loop_clusterD1[i_14 * LEN_OUTER + 8] + loop_clusterD1[i_14 * LEN_OUTER + 9]; - // new_cluster[i_14 * D + 1] = - // loop_clusterD2[i_14 * LEN_OUTER + 0] + loop_clusterD2[i_14 * LEN_OUTER + 1] + - // loop_clusterD2[i_14 * LEN_OUTER + 2] + loop_clusterD2[i_14 * LEN_OUTER + 3] + - // loop_clusterD2[i_14 * LEN_OUTER + 4] + loop_clusterD2[i_14 * LEN_OUTER + 5] + - // loop_clusterD2[i_14 * LEN_OUTER + 6] + loop_clusterD2[i_14 * LEN_OUTER + 7] + - // loop_clusterD2[i_14 * LEN_OUTER + 8] + loop_clusterD2[i_14 * LEN_OUTER + 9]; + // new_cluster[i_14 * D + 1] = + // loop_clusterD2[i_14 * LEN_OUTER + 0] + loop_clusterD2[i_14 * LEN_OUTER + 1] + + // loop_clusterD2[i_14 * LEN_OUTER + 2] + loop_clusterD2[i_14 * LEN_OUTER + 3] + + // loop_clusterD2[i_14 * LEN_OUTER + 4] + loop_clusterD2[i_14 * LEN_OUTER + 5] + + // loop_clusterD2[i_14 * LEN_OUTER + 6] + loop_clusterD2[i_14 * LEN_OUTER + 7] + + // loop_clusterD2[i_14 * LEN_OUTER + 8] + loop_clusterD2[i_14 * LEN_OUTER + 9]; - // count[i_14] = - // loop_count[i_14 * LEN_OUTER + 0] + loop_count[i_14 * LEN_OUTER + 1] + - // loop_count[i_14 * LEN_OUTER + 2] + loop_count[i_14 * LEN_OUTER + 3] + - // loop_count[i_14 * LEN_OUTER + 4] + loop_count[i_14 * LEN_OUTER + 5] + - // loop_count[i_14 * LEN_OUTER + 6] + loop_count[i_14 * LEN_OUTER + 7] + - // loop_count[i_14 * LEN_OUTER + 8] + loop_count[i_14 * LEN_OUTER + 9]; - // } + // count[i_14] = + // loop_count[i_14 * LEN_OUTER + 0] + loop_count[i_14 * LEN_OUTER + 1] + + // loop_count[i_14 * LEN_OUTER + 2] + loop_count[i_14 * LEN_OUTER + 3] + + // loop_count[i_14 * LEN_OUTER + 4] + loop_count[i_14 * LEN_OUTER + 5] + + // loop_count[i_14 * LEN_OUTER + 6] + loop_count[i_14 * LEN_OUTER + 7] + + // loop_count[i_14 * LEN_OUTER + 8] + loop_count[i_14 * LEN_OUTER + 9]; + } - // // Recompute cluster Pos - // // Compute mean - // for (int i_15 = 0; i_15 < NC; i_15++) - // { - // if (count[i_15] > 0) - // { - // new_cluster[i_15 * D] /= count[i_15]; - // new_cluster[i_15 * D + 1] /= count[i_15]; - // } - // } - // // ======================= iteration_unrolled_outer + // Recompute cluster Pos + // Compute mean + for (int i_15 = 0; i_15 < NC; i_15++) + { + if (count[i_15] > 0) + { + new_cluster[i_15 * D] /= count[i_15]; + new_cluster[i_15 * D + 1] /= count[i_15]; + } + } + // ======================= iteration_unrolled_outer - // // We need to copy inputs to outputs - // for (int i_16 = 0; i_16 < NC * D; i_16++) - // { - // cluster[i_16] = new_cluster[i_16]; - // } - // } + // We need to copy inputs to outputs + for (int i_16 = 0; i_16 < NC * D; i_16++) + { + cluster[i_16] = new_cluster[i_16]; + } + } for (int i_17 = 0; i_17 < NC; i_17++) { output[i_17 * D] = cluster[i_17 * D]; diff --git a/scripts/build_mpc_c_test.zsh b/scripts/build_mpc_c_test.zsh index 14cc1d58..ecd7221d 100755 --- a/scripts/build_mpc_c_test.zsh +++ b/scripts/build_mpc_c_test.zsh @@ -30,7 +30,7 @@ function mpc_test_2 { RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+b" } -mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c +mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c # mpc_test_2 2 ./examples/C/mpc/playground.c # mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c diff --git a/src/ir/opt/mem/obliv.rs b/src/ir/opt/mem/obliv.rs index 8ba216d4..d0b4b04e 100644 --- a/src/ir/opt/mem/obliv.rs +++ b/src/ir/opt/mem/obliv.rs @@ -161,6 +161,7 @@ struct Replacer { impl Replacer { fn is_obliv(&self, a: &Term) -> bool { + println!("is obliv?"); !self.not_obliv.contains(a) } } @@ -168,7 +169,7 @@ impl Replacer { #[track_caller] fn get_const(t: &Term) -> usize { as_uint_constant(t) - .unwrap_or_else(|| panic!("non-const {}", t)) + .unwrap_or_else(|| panic!("non-const: {}", t)) .to_usize() .expect("oversize") } diff --git a/src/ir/opt/visit.rs b/src/ir/opt/visit.rs index 3031dd66..a25973a5 100644 --- a/src/ir/opt/visit.rs +++ b/src/ir/opt/visit.rs @@ -3,7 +3,7 @@ use crate::ir::term::*; /// A rewriting pass. pub trait RewritePass { /// Visit (and possibly rewrite) a term. - /// Given the original term and a function to get its rewritten childen. + /// Given the original term and a function to get its rewritten children. /// Returns a term if a rewrite happens. fn visit Vec>( &mut self, @@ -44,7 +44,7 @@ pub trait RewritePass { } } -/// An analysis pass that repeated sweeps all terms, visiting them, untill a pass makes no more +/// An analysis pass that repeated sweeps all terms, visiting them, until a pass makes no more /// progress. pub trait ProgressAnalysisPass { /// The visit function. Returns whether progress was made. diff --git a/src/target/aby/assignment/mod.rs b/src/target/aby/assignment/mod.rs index 682776da..3b2a8819 100644 --- a/src/target/aby/assignment/mod.rs +++ b/src/target/aby/assignment/mod.rs @@ -120,7 +120,9 @@ impl CostModel { fn get(&self, op: &Op) -> Option<&FxHashMap> { match op { - Op::Field(_) | Op::Call(..) | Op::Update(..) => Some(&self.zero), + Op::Field(_) | Op::Update(..) | Op::Select | Op::Store | Op::Call(..) => { + Some(&self.zero) + } _ => { let op_name = match op.clone() { // assume comparisions are unsigned diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index 1fb4a444..76c2a9d5 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -7,6 +7,7 @@ use rug::Integer; use crate::ir::opt::cfold::fold; +use crate::ir::opt::tuple; use crate::ir::term::*; #[cfg(feature = "lp")] use crate::target::aby::assignment::ilp::assign; @@ -506,15 +507,14 @@ impl<'a> ToABY<'a> { assert!(t.cs.len() == 2); let shares = self.get_shares(&t.cs[0]); - let idx: usize = if let Op::Const(v) = &t.cs[1].op { - match v { - Value::BitVector(b) => b.uint().to_usize().unwrap(), - _ => todo!(), - } - } else { - todo!("Non-constant index into an array"); + // Assume constant indexing + let idx = match &t.cs[1].op { + Op::Const(Value::BitVector(bv)) => bv.uint().to_usize().unwrap().clone(), + _ => panic!("non-const"), }; + assert!(idx < shares.len(), "idx: {}, shares: {}", idx, shares.len()); + self.term_to_shares.insert(t.clone(), vec![shares[idx]]); self.cache.insert(t.clone(), EmbeddedTerm::Bv); } @@ -525,14 +525,16 @@ impl<'a> ToABY<'a> { fn embed_scalar(&mut self, t: Term) { match &t.op { Op::Const(Value::Array(arr)) => { - for i in 0..arr.size { + let shares = self.get_shares(&t); + assert!(shares.len() == arr.size); + + for (i, s) in shares.iter().enumerate() { // TODO: sort of index might not be a 32-bit bitvector let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); let v = match arr.map.get(&idx) { Some(c) => c, None => &*arr.default, }; - let s = self.get_share(&t); match v { Value::BitVector(b) => { @@ -560,10 +562,58 @@ impl<'a> ToABY<'a> { } } } + Op::Store => { + assert!(t.cs.len() == 3); + let mut array_shares = self.get_shares(&t.cs[0]); + let value_share = self.get_share(&t.cs[2]); + + // Assume constant indexing + let idx = match &t.cs[1].op { + Op::Const(Value::BitVector(bv)) => bv.uint().to_usize().unwrap().clone(), + _ => panic!("non-const"), + }; + + array_shares[idx] = value_share; + self.term_to_shares.insert(t.clone(), array_shares); + self.cache.insert(t.clone(), EmbeddedTerm::Array); + } Op::Field(i) => { assert!(t.cs.len() == 1); let shares = self.get_shares(&t.cs[0]); - self.term_to_shares.insert(t.clone(), vec![shares[*i]]); + + let tuple_sort = check(&t.cs[0]); + let (offset, len) = match tuple_sort { + Sort::Tuple(t) => { + assert!(*i < t.len()); + + // find offset + let mut offset = 0; + for j in 0..*i { + match t[j] { + Sort::BitVector(_) => offset += 1, + Sort::Bool => offset += 1, + Sort::Array(_, _, size) => offset += size, + _ => todo!(), + } + } + + // find len + let len = match t[*i] { + Sort::BitVector(_) => 1, + Sort::Bool => 1, + Sort::Array(_, _, size) => size, + _ => todo!(), + }; + + (offset, len) + } + _ => panic!("Field op on non-tuple"), + }; + + // get ret slice + let field_shares = &shares[offset..offset + len]; + + self.term_to_shares.insert(t.clone(), field_shares.to_vec()); self.cache.insert(t.clone(), EmbeddedTerm::Array); } Op::Update(i) => { @@ -590,7 +640,7 @@ impl<'a> ToABY<'a> { // map argument shares let mut arg_shares: Vec = Vec::new(); for c in t.cs.iter() { - arg_shares.push(self.get_share(c)); + arg_shares.extend(self.get_shares(c)); } let arg_shares_str: String = arg_shares.iter().map(|&s| s.to_string() + " ").collect(); @@ -652,10 +702,11 @@ impl<'a> ToABY<'a> { } let op = "OUT"; - let s = self.get_share(&t); - let line = format!("1 0 {} {}\n", s, op); - - outputs.push(line); + let shares = self.get_shares(&t); + for s in shares { + let line = format!("1 0 {} {}\n", s, op); + outputs.push(line); + } } // write bytecode per function