merge master

This commit is contained in:
Edward Chen
2022-05-16 11:34:14 -04:00
9 changed files with 129 additions and 128 deletions

View File

@@ -94,6 +94,8 @@ def test(features):
if cargo_features:
test_cmd = test_cmd + ["--features"] + cargo_features
subprocess.run(test_cmd, check=True)
if load_mode() == "release":
subprocess.run(test_cmd + ["--release"], check=True)
if "r1cs" in features and "smt" in features:
subprocess.run(["./scripts/test_datalog.zsh"], check=True)

View File

@@ -24,8 +24,8 @@ if __name__ == "__main__":
c_misc_tests + \
biomatch_tests + \
kmeans_tests + \
kmeans_tests_2 + \
db_tests
kmeans_tests_2
# db_tests
# gauss_tests + \
# TODO: add support for return value - int promotion

View File

@@ -113,7 +113,7 @@ mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c
mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c
mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c
mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c
mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss.c
# mpc_test_2 2 ./examples/C/mpc/benchmarks/mnist/2pc_mnist.c

View File

@@ -325,7 +325,7 @@ impl CGen {
FnInfo {
name,
ret_ty,
args: args,
args,
body,
}
}

View File

@@ -39,6 +39,8 @@ fn cbv(b: BitVector) -> Option<Term> {
/// Fold away operators over constants.
pub fn fold(node: &Term, ignore: &[Op]) -> Term {
// lock the collector before locking FOLDS (and, inside fold_cache, TERMS)
let _lock = super::super::term::COLLECT.read().unwrap();
let mut cache_handle = FOLDS.write().unwrap();
let cache = cache_handle.deref_mut();

View File

@@ -47,6 +47,8 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
scalarize_vars::scalarize_inputs(&mut cs);
}
Opt::ConstantFold(ignore) => {
// lock the collector because fold_cache locks TERMS
let _lock = super::term::COLLECT.read().unwrap();
let mut cache = TermCache::new(TERM_CACHE_LIMIT);
for a in &mut cs.outputs {
// allow unbounded size during a single fold_cache call

View File

@@ -1012,13 +1012,10 @@ impl TermTable {
}
debug!(target: "ir::term::gc", "{} of {} terms collected", old_size - new_size, old_size);
self.last_len = new_size;
super::opt::cfold::collect();
}
}
struct TypeTable {
map: FxHashMap<TTerm, Sort>,
last_len: usize,
}
impl std::ops::Deref for TypeTable {
type Target = FxHashMap<TTerm, Sort>;
@@ -1032,20 +1029,11 @@ impl std::ops::DerefMut for TypeTable {
}
}
impl TypeTable {
fn should_collect(&mut self) -> bool {
let ret = LEN_THRESH_DEN * self.map.len() > LEN_THRESH_NUM * self.last_len;
if self.last_len > TERM_CACHE_LIMIT {
// when last_len is big, force a garbage collect every once in a while
self.last_len = (self.last_len * LEN_DECAY_NUM) / LEN_DECAY_DEN;
}
ret
}
fn collect(&mut self) {
let old_size = self.map.len();
self.map.retain(|term, _| term.elm.strong_count() > 1);
let new_size = self.map.len();
debug!(target: "ir::term::gc", "{} of {} types collected", old_size - new_size, old_size);
self.last_len = new_size;
}
}
@@ -1057,6 +1045,23 @@ lazy_static! {
});
}
// Tests are executed concurrently, meaning that terms might be collected
// in one thread, breaking constant folding or type checking running in a
// different thread. To fix this, we add a lock that the collector takes
// read-write, and cfolding / type-checking takes read-only.
//
// Deadlock analysis:
// cfold takes FOLD_CACHE(w) -> TERMS(w)
// type checking takes TERM_TYPES(w)
// garbage collector takes one lock at a time
//
// The following locking priority MUST be observed:
//
// COLLECT -> FOLD_CACHE -> TERMS -> TERM_TYPES
lazy_static! {
pub(super) static ref COLLECT: RwLock<()> = RwLock::new(());
}
fn mk(elm: TermData) -> Term {
let mut slf = TERMS.write().unwrap();
slf.mk(elm)
@@ -1068,12 +1073,16 @@ pub fn garbage_collect() {
// this function may be called from Drop implementations, which are called
// when a thread is unwinding due to a panic. When that happens, RwLocks are
// poisoned, which would cause a panic-in-panic, no bueno.
if !std::thread::panicking() {
collect_terms();
collect_types();
} else {
if std::thread::panicking() {
log::warn!("Not garbage collecting because we are currently panicking.");
return;
}
// lock the collector before locking anything else
let _lock = COLLECT.write().unwrap();
collect_terms();
collect_types();
super::opt::cfold::collect();
}
const LEN_THRESH_NUM: usize = 8;
@@ -1085,23 +1094,23 @@ pub fn maybe_garbage_collect() -> bool {
// Don't garbage collect while panicking.
// NOTE This function probably shouldn't be called from Drop impls, but let's be safe anyway.
if std::thread::panicking() {
log::warn!("Not garbage collecting because we are currently panicking.");
return false;
}
let mut ran = {
// lock the collector before locking anything else
let _lock = COLLECT.write().unwrap();
let mut ran = false;
{
let mut term_table = TERMS.write().unwrap();
if term_table.should_collect() {
term_table.collect();
true
} else {
false
}
};
{
let mut type_table = ty::TERM_TYPES.write().unwrap();
if type_table.should_collect() {
type_table.collect();
ran = true;
}
} // TERMS lock goes out of scope here
if ran {
collect_types();
super::opt::cfold::collect();
}
ran
}

View File

@@ -6,7 +6,6 @@ lazy_static! {
/// Cache of all types
pub(super) static ref TERM_TYPES: RwLock<TypeTable> = RwLock::new(TypeTable {
map: FxHashMap::default(),
last_len: 0,
});
}
@@ -141,7 +140,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
Op::Update(_i) => Ok(get_ty(&t.cs[0]).clone()),
Op::Map(op) => {
let arg_cnt = t.cs.len();
let mut dterm_cs = Vec::new();
let mut arg_sorts_to_inner_op = Vec::new();
let mut key_sort = Sort::Bool;
let mut size = 0;
@@ -160,7 +159,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
for i in 0..arg_cnt {
match array_or(get_ty(&t.cs[i]), "map inputs") {
Ok((_, v)) => {
dterm_cs.push(v.default_term());
arg_sorts_to_inner_op.push(v);
}
Err(e) => {
error = Some(e);
@@ -170,12 +169,8 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
match error {
Some(e) => Err(e),
None => {
let term_ = term((**op).clone(), dterm_cs);
Ok(Sort::Array(
Box::new(key_sort),
Box::new(get_ty(&term_).clone()),
size,
))
let value_sort = rec_check_raw_helper(&**op, &arg_sorts_to_inner_op)?;
Ok(Sort::Array(Box::new(key_sort), Box::new(value_sort), size))
}
}
}
@@ -189,43 +184,39 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
if let Some(s) = TERM_TYPES.read().unwrap().get(&t.to_weak()) {
return Ok(s.clone());
}
{
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while !to_check.is_empty() {
let back = to_check.last_mut().unwrap();
let weak = back.0.to_weak();
// The idea here is to check that
if let Some((p, _)) = term_tys.get_key_value(&weak) {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
term_tys.remove(&weak);
}
}
if !back.1 {
back.1 = true;
for c in check_dependencies(&back.0) {
to_check.push((c, false));
}
// lock the collector before locking TERM_TYPES
let _lock = COLLECT.read().unwrap();
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while !to_check.is_empty() {
let back = to_check.last_mut().unwrap();
let weak = back.0.to_weak();
// The idea here is to check that
if let Some((p, _)) = term_tys.get_key_value(&weak) {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
let ty = check_raw_step(&back.0, &*term_tys).map_err(|reason| TypeError {
op: back.0.op.clone(),
args: vec![], // not quite right
reason,
})?;
term_tys.insert(back.0.to_weak(), ty);
term_tys.remove(&weak);
}
}
if !back.1 {
back.1 = true;
for c in check_dependencies(&back.0) {
to_check.push((c, false));
}
} else {
let ty = check_raw_step(&back.0, &*term_tys).map_err(|reason| TypeError {
op: back.0.op.clone(),
args: vec![], // not quite right
reason,
})?;
term_tys.insert(back.0.to_weak(), ty);
}
}
Ok(TERM_TYPES
.read()
.unwrap()
.get(&t.to_weak())
.unwrap()
.clone())
Ok(term_tys.get(&t.to_weak()).unwrap().clone())
}
/// Helper function for rec_check_raw
@@ -408,50 +399,45 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
if let Some(s) = TERM_TYPES.read().unwrap().get(&t.to_weak()) {
return Ok(s.clone());
}
{
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while !to_check.is_empty() {
let back = to_check.last_mut().unwrap();
let weak = back.0.to_weak();
// The idea here is to check that
if let Some((p, _)) = term_tys.get_key_value(&weak) {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
term_tys.remove(&weak);
}
}
if !back.1 {
back.1 = true;
for c in back.0.cs.clone() {
to_check.push((c, false));
}
// lock the collector before locking TERM_TYPES
let _lock = COLLECT.read().unwrap();
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while !to_check.is_empty() {
let back = to_check.last_mut().unwrap();
let weak = back.0.to_weak();
// The idea here is to check that
if let Some((p, _)) = term_tys.get_key_value(&weak) {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
let tys = back
.0
.cs
.iter()
.map(|c| term_tys.get(&c.to_weak()).unwrap())
.collect::<Vec<_>>();
let ty =
rec_check_raw_helper(&back.0.op, &tys[..]).map_err(|reason| TypeError {
op: back.0.op.clone(),
args: tys.into_iter().cloned().collect(),
reason,
})?;
term_tys.insert(back.0.to_weak(), ty);
term_tys.remove(&weak);
}
}
if !back.1 {
back.1 = true;
for c in back.0.cs.clone() {
to_check.push((c, false));
}
} else {
let tys = back
.0
.cs
.iter()
.map(|c| term_tys.get(&c.to_weak()).unwrap())
.collect::<Vec<_>>();
let ty = rec_check_raw_helper(&back.0.op, &tys[..]).map_err(|reason| TypeError {
op: back.0.op.clone(),
args: tys.into_iter().cloned().collect(),
reason,
})?;
term_tys.insert(back.0.to_weak(), ty);
}
}
Ok(TERM_TYPES
.read()
.unwrap()
.get(&t.to_weak())
.unwrap()
.clone())
Ok(term_tys.get(&t.to_weak()).unwrap().clone())
}
#[derive(Debug, PartialEq, Eq)]

View File

@@ -8,22 +8,22 @@ use zokrates_parser::Rule;
extern crate lazy_static;
pub use ast::{
Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement,
Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator,
BooleanLiteralExpression, BooleanType, CallAccess, ConstantDefinition, ConstantGenericValue,
Curve, DecimalLiteralExpression, DecimalNumber, DecimalSuffix, DefinitionStatement,
ExplicitGenerics, Expression, FieldSuffix, FieldType, File, FromExpression,
FromImportDirective, FunctionDefinition, HexLiteralExpression, HexNumberExpression,
IdentifierExpression, ImportDirective, AnyString, ImportSymbol, InlineArrayExpression,
InlineStructExpression, InlineStructMember, IterationStatement, LiteralExpression,
MainImportDirective, MemberAccess, NegOperator, NotOperator, Parameter, PosOperator,
PostfixExpression, Pragma, PrivateNumber, PrivateVisibility, PublicVisibility, Range,
RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StrOperator,
StructDefinition, StructField, StructType, SymbolDeclaration, TernaryExpression, ToExpression,
Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee, U16NumberExpression, U16Suffix, U16Type,
U32NumberExpression, U32Suffix, U32Type, U64NumberExpression, U64Suffix, U64Type,
U8NumberExpression, U8Suffix, U8Type, UnaryExpression, UnaryOperator, Underscore, Visibility,
EOI,
Access, AnyString, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType,
AssertionStatement, Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression,
BinaryOperator, BooleanLiteralExpression, BooleanType, CallAccess, ConstantDefinition,
ConstantGenericValue, Curve, DecimalLiteralExpression, DecimalNumber, DecimalSuffix,
DefinitionStatement, ExplicitGenerics, Expression, FieldSuffix, FieldType, File,
FromExpression, FromImportDirective, FunctionDefinition, HexLiteralExpression,
HexNumberExpression, IdentifierExpression, ImportDirective, ImportSymbol,
InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement,
LiteralExpression, MainImportDirective, MemberAccess, NegOperator, NotOperator, Parameter,
PosOperator, PostfixExpression, Pragma, PrivateNumber, PrivateVisibility, PublicVisibility,
Range, RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement,
StrOperator, StructDefinition, StructField, StructType, SymbolDeclaration, TernaryExpression,
ToExpression, Type, TypeDefinition, TypedIdentifier, TypedIdentifierOrAssignee,
U16NumberExpression, U16Suffix, U16Type, U32NumberExpression, U32Suffix, U32Type,
U64NumberExpression, U64Suffix, U64Type, U8NumberExpression, U8Suffix, U8Type, UnaryExpression,
UnaryOperator, Underscore, Visibility, EOI,
};
mod ast {