does kmeans work?

This commit is contained in:
Edward Chen
2021-11-15 20:50:14 -05:00
parent 082dd79617
commit ec8bf83558
12 changed files with 84 additions and 58 deletions

View File

@@ -84,10 +84,10 @@ int main(__attribute__((private(0))) int a[200], __attribute__((private(1))) int
int y2 = dy;
dist[i_9] = (x1-x2) * (x1-x2) + (y1 - y2) * (y1 - y2);
}
// // hardcoded NC = 5;
// // stride = 1
// // stride = 2
// // stride = 4
// hardcoded NC = 5;
// stride = 1
// stride = 2
// stride = 4
int stride = 1;
for(int i_10 = 0; i_10 < NC - stride; i_10+=2) {
if(dist[i_10+stride] < dist[i_10]) {

View File

@@ -1,17 +1,34 @@
int main(__attribute__((private(0))) int a, __attribute__((private(1))) int b) {
int NC = 2;
int pos[NC] = {6,7};
int dist[NC] = {6,7};
int main(__attribute__((private(0))) int a[5], __attribute__((private(1))) int b[5]) {
// int NC = 2;
// int pos[NC] = {1,2,3,4};
// int dist[NC] = {5,6,7,8};
int stride = 1;
for(int i_10 = 0; i_10 < NC - stride; i_10+=2) {
if(dist[i_10+stride] < dist[i_10]) {
dist[i_10] = dist[i_10+stride];
pos[i_10] = pos[i_10+stride];
}
// int stride = 1;
// for(int i_10 = 0; i_10 < NC - stride; i_10+=2) {
// if (dist[i_10+stride] < dist[i_10]) {
// dist[i_10] = dist[i_10+stride];
// pos[i_10] = pos[i_10+stride];
// }
// }
// if 1 < 0: false
// if 2 < 3: false
// 5
// int c = 0;
// if (c == 1) {
// c = 1;
// }
// return c;
int c[1];
if (c[0] == 1) {
c[0] = 1;
}
return dist[0];
return c[0];
}

View File

@@ -62,6 +62,7 @@ fn main() {
let cs = match mode {
Mode::Mpc(_) => opt(
cs,
// vec![],
vec![Opt::Sha, Opt::ConstantFold, Opt::Mem, Opt::ConstantFold],
),
_ => unimplemented!(),

View File

@@ -16,8 +16,8 @@ if __name__ == "__main__":
# c_array_tests + \
# div_tests
tests = kmeans_tests + div_tests
# tests = kmeans_tests
# tests = kmeans_tests + div_tests
tests = kmeans_tests
# tests = div_tests
# TODO: add support for return value - int promotion

View File

@@ -76,7 +76,7 @@ function mpc_test {
# build div tests
# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c
mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div_2.c
# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div_2.c

View File

@@ -135,12 +135,14 @@ impl MemManager {
}
/// Write the value `val` to index `offset` in the allocation `id`.
pub fn store(&mut self, id: AllocId, offset: Term, val: Term) {
pub fn store(&mut self, id: AllocId, offset: Term, val: Term, cond: Term) {
let alloc = self.allocs.get_mut(&id).expect("Missing allocation");
assert_eq!(alloc.addr_width, check(&offset).as_bv());
assert_eq!(alloc.val_width, check(&val).as_bv());
let new = term![Op::Store; alloc.var().clone(), offset, val];
alloc.cur_term = new;
let old = alloc.cur_term.clone();
let new = term![Op::Store; alloc.var().clone(), offset.clone(), val];
let ite_store = term![Op::Ite; cond, new, old];
alloc.cur_term = ite_store;
// alloc.next_var();
// let v = alloc.var().clone();

View File

@@ -1,4 +1,5 @@
//! A library for building front-ends
use crate::circify::mem::AllocId;
use crate::ir::term::*;
use std::cell::RefCell;
@@ -787,6 +788,22 @@ impl<E: Embeddable> Circify<E> {
pub fn consume(self) -> Rc<RefCell<Computation>> {
self.cir_ctx.cs
}
/// Load from an AllocID
pub fn load(&self, id: AllocId, offset: Term) -> Term {
self.cir_ctx.mem.borrow_mut().load(id, offset)
}
/// Store to an AllocID
pub fn store(&mut self, id: AllocId, offset: Term, val: Term) {
let cond = self.condition();
self.cir_ctx.mem.borrow_mut().store(id, offset, val, cond);
}
/// Zero allocate an array
pub fn zero_allocate(&mut self, size: usize, addr_width: usize, val_width: usize) -> AllocId {
self.cir_ctx.mem.borrow_mut().zero_allocate(size, addr_width, val_width)
}
}
const RET_NAME: &str = "return";

View File

@@ -7,7 +7,6 @@ mod types;
use super::FrontEnd;
use crate::circify::{Circify, Loc, Val};
use crate::circify::mem::MemManager;
use crate::front::c::ast_utils::*;
use crate::front::c::term::*;
use crate::front::c::types::*;
@@ -19,7 +18,6 @@ use lang_c::span::Node;
use log::debug;
// use std::collections::HashMap;
use std::cell::RefMut;
use std::fmt::{self, Display, Formatter};
use std::path::PathBuf;
@@ -127,22 +125,17 @@ impl CGen {
r.unwrap_or_else(|e| self.err(e))
}
fn get_mem(&self) -> RefMut<MemManager> {
self.circ.cir_ctx().mem.borrow_mut()
}
fn array_select(&self, array: CTerm, idx: CTerm) -> Result<CTerm, String> {
let mem = self.get_mem();
match (array.clone().term, idx.term) {
(CTermData::CArray(ty, id), CTermData::CInt(_, _, idx)) => {
let i = id.unwrap_or_else(|| panic!("Unknown AllocID: {:#?}", array));
Ok(CTerm {
term: match ty {
Ty::Bool => {
CTermData::CBool(mem.load(i, idx))
CTermData::CBool(self.circ.load(i, idx))
}
Ty::Int(s,w) => {
CTermData::CInt(s, w, mem.load(i, idx))
CTermData::CInt(s, w, self.circ.load(i, idx))
}
// TODO: Flatten array so this case doesn't occur
// Ty::Array(_,t) => {
@@ -157,13 +150,12 @@ impl CGen {
}
}
pub fn array_store(&self, array: CTerm, idx: CTerm, val: CTerm) -> Result<CTerm, String> {
pub fn array_store(&mut self, array: CTerm, idx: CTerm, val: CTerm) -> Result<CTerm, String> {
match (array.clone().term, idx.term) {
(CTermData::CArray(_, id), CTermData::CInt(_, _, idx_term)) => {
let i = id.unwrap_or_else(|| panic!("Unknown AllocID: {:#?}", array.clone()));
let mut mem = self.get_mem();
let new_val = val.term.term(&mem);
mem.store(i, idx_term, new_val);
let new_val = val.term.term(&self.circ);
self.circ.store(i, idx_term, new_val);
Ok(val.clone())
}
(a, b) => Err(format!("[Array Store] cannot index {} by {}", b, a)),
@@ -212,8 +204,7 @@ impl CGen {
}
fn fold_(&mut self, expr: CTerm) -> i32 {
let mem = self.get_mem();
let term_ = fold(&expr.term.term(&mem));
let term_ = fold(&expr.term.term(&self.circ));
let cterm_ = CTerm {
term: CTermData::CInt(true, 32, term_),
udef: false,
@@ -378,14 +369,13 @@ impl CGen {
// TODO: fix hack, const int check for shifting
if f == shl || f == shr {
let mem = self.get_mem();
let a_t = fold(&a.term.term(&mem));
let a_t = fold(&a.term.term(&self.circ));
a = CTerm {
term: CTermData::CInt(true, 32, a_t),
udef: false,
};
let b_t = fold(&b.term.term(&mem));
let b_t = fold(&b.term.term(&self.circ));
b = CTerm {
term: CTermData::CInt(true, 32, b_t),
udef: false,
@@ -439,13 +429,12 @@ impl CGen {
let expr = self.gen_init(inner_type.clone(), li.node.initializer.node.clone());
values.push(expr)
}
let mut mem = self.get_mem();
let id = mem.zero_allocate(values.len(), 32, num_bits(inner_type.clone()));
let id = self.circ.zero_allocate(values.len(), 32, num_bits(inner_type.clone()));
for (i,v) in values.iter().enumerate() {
let offset = bv_lit(i, 32);
let v_ = v.term.term(&mem);
mem.store(id, offset, v_);
let v_ = v.term.term(&self.circ);
self.circ.store(id, offset, v_);
}
CTerm {
@@ -468,8 +457,7 @@ impl CGen {
} else {
expr = match derived_ty {
Ty::Array(size, ref ty) => {
let mut mem = self.get_mem();
let id = mem.zero_allocate(size.unwrap(), 32, num_bits(*ty.clone()));
let id = self.circ.zero_allocate(size.unwrap(), 32, num_bits(*ty.clone()));
CTerm {
term: CTermData::CArray(*ty.clone(), Some(id)),
udef: false,
@@ -634,15 +622,14 @@ impl CGen {
}
Statement::If(node) => {
let cond = self.gen_expr(node.node.condition.node);
// TODO Cast to boolean for condition ;
let t_term = cond.term.term(&self.get_mem());
let t_term = cond.term.term(&self.circ);
let t_res = self.circ.enter_condition(t_term);
self.unwrap(t_res);
self.gen_stmt(node.node.then_statement.node);
self.circ.exit_condition();
if let Some(f_cond) = node.node.else_statement {
let f_term = term!(Op::Not; cond.term.term(&self.get_mem()));
let f_term = term!(Op::Not; cond.term.term(&self.circ));
let f_res = self.circ.enter_condition(f_term);
self.unwrap(f_res);
self.gen_stmt(f_cond.node);
@@ -701,14 +688,15 @@ impl CGen {
let fn_info = ast_utils::get_fn_info(&fn_def.node);
self.circ.enter_fn(fn_info.name.to_owned(), fn_info.ret_ty);
for arg in fn_info.args.iter() {
// TODO: self.gen_decl(arg);
let p = &arg.specifiers[0];
let vis = self.interpret_visibility(&p.node);
let base_ty = d_type_(arg.specifiers[1..].to_vec());
let d = &arg.declarator.as_ref().unwrap().node;
let derived_ty = self.derived_type_(base_ty.unwrap(), d.derived.to_vec());
let name = name_from_decl(d);
let r = self.circ.declare(name.clone(), &derived_ty, true, vis);
self.unwrap(r);
let res = self.circ.declare(name.clone(), &derived_ty, true, vis);
self.unwrap(res);
}
self.gen_stmt(fn_info.body.clone());
if let Some(r) = self.circ.exit_fn() {

View File

@@ -1,11 +1,11 @@
//! C Terms
use crate::circify::{CirCtx, Embeddable};
use crate::circify::mem::{AllocId, MemManager};
use crate::circify::mem::AllocId;
use crate::front::c::is_signed_int;
use crate::front::c::Circify;
use crate::front::c::types::*;
use crate::ir::term::*;
use rug::Integer;
use std::cell::RefMut;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
@@ -39,14 +39,14 @@ impl CTermData {
terms_tail(self, &mut output);
output
}
pub fn term(&self, mem: &RefMut<MemManager>) -> Term {
pub fn term(&self, circ: &Circify<Ct>) -> Term {
match self {
CTermData::CBool(b) => b.clone(),
CTermData::CInt(_, _, b) => b.clone(),
CTermData::CArray(_,b) => {
// TODO: load all of the array
let i = b.unwrap_or_else(|| panic!("Unknown AllocID: {:#?}", self));
mem.load(i, bv_lit(0,32))
circ.load(i, bv_lit(0,32))
},
}
}
@@ -507,8 +507,9 @@ impl Embeddable for Ct {
udef: false,
};
for (i, t) in v.iter().enumerate() {
let val = t.term.term(&mem);
mem.store(id, bv_lit(i, 32), val);
let val = t.term.terms()[0].clone();
let t_term = leaf_term(Op::Const(Value::Bool(true)));
mem.store(id, bv_lit(i, 32), val, t_term);
}
arr
},

View File

@@ -609,7 +609,7 @@ pub fn to_aby(ir: Computation) -> ABY {
let mut converter = ToABY::new(md, s_map);
for t in terms {
println!("terms: {}", t);
// println!("terms: {}", t);
converter.lower(t.clone());
}