Resolve lints and add clippy to CI (#35)

`front::zokrates` is currently excluded
This commit is contained in:
Alex Ozdemir
2022-01-01 12:27:36 -08:00
committed by GitHub
parent f2744e0c06
commit a9fd7888c4
37 changed files with 351 additions and 360 deletions

View File

@@ -29,6 +29,8 @@ jobs:
run: cargo check --verbose
- name: Check format
run: cargo fmt -- --check
- name: Lint
run: cargo clippy
- name: Build
run: cargo build --verbose && make build
- name: Run tests

View File

@@ -21,3 +21,6 @@ clean:
format:
cargo fmt --all
lint:
cargo clippy

View File

@@ -98,7 +98,7 @@ impl MemManager {
let alloc = Alloc::new(id, *addr_width, *val_width, *size);
let v = alloc.var().clone();
if let Op::Var(n, _) = &v.op {
self.cs.borrow_mut().eval_and_save(&n, &array);
self.cs.borrow_mut().eval_and_save(n, &array);
} else {
unreachable!()
}
@@ -152,7 +152,7 @@ impl MemManager {
alloc.next_var();
let v = alloc.var().clone();
if let Op::Var(n, _) = &v.op {
self.cs.borrow_mut().eval_and_save(&n, &new);
self.cs.borrow_mut().eval_and_save(n, &new);
} else {
unreachable!()
}

View File

@@ -267,7 +267,8 @@ impl<Ty: Display> FnFrame<Ty> {
}
fn enter_condition(&mut self, condition: Term) -> Result<()> {
if check(&condition) == Sort::Bool {
Ok(self.stack.push(StateEntry::Cond(condition)))
self.stack.push(StateEntry::Cond(condition));
Ok(())
} else {
Err(CircError::NotBool(condition))
}
@@ -313,7 +314,7 @@ impl<Ty: Display> FnFrame<Ty> {
StateEntry::Cond(c) => break_if.push(c.clone()),
StateEntry::Break(name_, ref mut break_conds) => {
if name_ == name {
break_conds.push(if break_if.len() == 0 {
break_conds.push(if break_if.is_empty() {
leaf_term(Op::Const(Value::Bool(true)))
} else {
term(Op::BoolNaryOp(BoolNaryOp::And), break_if)
@@ -394,6 +395,8 @@ pub trait Embeddable {
/// Create a new term for the default return value of a function returning type `ty`.
/// The name `ssa_name` is globally unique, and can be used if needed.
// Because the type alias may change.
#[allow(clippy::ptr_arg)]
fn initialize_return(&self, ty: &Self::Ty, ssa_name: &SsaName) -> Self::T;
}
@@ -450,9 +453,9 @@ impl<E: Embeddable> Circify<E> {
/// Initialize environment entry binding `name` to `ty`.
fn declare_env_name(&mut self, name: VarName, ty: &E::Ty) -> Result<&SsaName> {
if let Some(back) = self.fn_stack.last_mut() {
back.declare(name.clone(), ty.clone())
back.declare(name, ty.clone())
} else {
self.globals.declare(name.clone(), ty.clone())
self.globals.declare(name, ty.clone())
}
}
@@ -484,19 +487,18 @@ impl<E: Embeddable> Circify<E> {
/// Returns `None` if it's in the global scope.
/// Returns `Some` if it's in a local scope.
/// Errors if the name cannot be found.
// Because the type alias may change.
#[allow(clippy::ptr_arg)]
fn mk_abs(&self, name: &VarName) -> Result<Option<ScopeIdx>> {
if let Some(fn_) = self.fn_stack.last() {
for (lex_i, e) in fn_.stack.iter().enumerate().rev() {
match e {
StateEntry::Lex(l) => {
if l.has_name(name) {
return Ok(Some(ScopeIdx {
lex: lex_i,
fn_: self.fn_stack.len() - 1,
}));
}
if let StateEntry::Lex(l) = e {
if l.has_name(name) {
return Ok(Some(ScopeIdx {
lex: lex_i,
fn_: self.fn_stack.len() - 1,
}));
}
_ => {}
}
}
}
@@ -551,7 +553,7 @@ impl<E: Embeddable> Circify<E> {
///
/// If `public`, then make it a public (fixed) rather than private (existential) circuit input.
pub fn declare_init(&mut self, name: VarName, ty: E::Ty, val: Val<E::T>) -> Result<Val<E::T>> {
let ssa_name = self.declare_env_name(name.clone(), &ty)?.clone();
let ssa_name = self.declare_env_name(name, &ty)?.clone();
// TODO: add language-specific coersion here if needed
assert!(self.vals.insert(ssa_name, val.clone()).is_none());
Ok(val)
@@ -566,7 +568,7 @@ impl<E: Embeddable> Circify<E> {
ty: &E::Ty,
visiblity: Option<PartyId>,
) {
self.e.assign(&mut self.cir_ctx, &ty, name, term, visiblity);
self.e.assign(&mut self.cir_ctx, ty, name, term, visiblity);
}
/// Assign `loc` in the current scope to `val`.
@@ -669,7 +671,7 @@ impl<E: Embeddable> Circify<E> {
pub fn condition(&self) -> Term {
// TODO: more precise conditions, depending on lex scopes.
let cs: Vec<_> = self.fn_stack.iter().flat_map(|f| f.conditions()).collect();
if cs.len() == 0 {
if cs.is_empty() {
leaf_term(Op::Const(Value::Bool(true)))
} else {
term(Op::BoolNaryOp(BoolNaryOp::And), cs)
@@ -748,7 +750,7 @@ impl<E: Embeddable> Circify<E> {
self.vals
.get(l.get_name(&loc.name)?)
.cloned()
.ok_or_else(|| CircError::InvalidLoc(loc))
.ok_or(CircError::InvalidLoc(loc))
}
/// Dereference a reference into a location.
@@ -762,6 +764,8 @@ impl<E: Embeddable> Circify<E> {
}
/// Create a reference to a variable.
// Because the type alias may change.
#[allow(clippy::ptr_arg)]
pub fn ref_(&self, name: &VarName) -> Result<Val<E::T>> {
let idx = self.mk_abs(name)?;
Ok(Val::Ref(Loc {

View File

@@ -56,7 +56,7 @@ impl<'ast> Gen<'ast> {
/// Attempt to enter a funciton.
/// Returns `false` if doing so would violate the recursion limit.
fn enter_function(&mut self, name: &'ast str, dec_value: Option<Integer>) -> bool {
let e = self.stack_by_fn.entry(name).or_insert_with(|| Vec::new());
let e = self.stack_by_fn.entry(name).or_insert_with(Vec::new);
//assert_eq!(e.last().and_then(|l| l.as_ref()).is_some(), dec_value.is_some());
let do_enter = if let (Some(last_val), Some(this_val)) =
(e.last().and_then(|l| l.as_ref()), dec_value.as_ref())
@@ -86,7 +86,7 @@ impl<'ast> Gen<'ast> {
fn register_rules(&mut self, pgm: &'ast ast::Program<'ast>) {
for r in &pgm.rules {
assert!(!self.rules.contains_key(&r.name.value));
self.rules.insert(&r.name.value, r);
self.rules.insert(r.name.value, r);
}
}
@@ -102,7 +102,7 @@ impl<'ast> Gen<'ast> {
}
},
|t, size| {
let size = usize::from_str(&size.value).expect("bad array size");
let size = usize::from_str(size.value).expect("bad array size");
ty::Ty::Array(size, Box::new(t))
},
),
@@ -127,7 +127,7 @@ impl<'ast> Gen<'ast> {
let vis = if public { PUBLIC_VIS } else { PROVER_VIS };
self.circ.declare(d.ident.value.into(), &ty, public, vis)?;
}
let r = self.rule_cases(&rule)?;
let r = self.rule_cases(rule)?;
self.exit_function(name);
self.circ.assert(r.as_bool());
Ok(())
@@ -165,19 +165,19 @@ impl<'ast> Gen<'ast> {
/// * `top_level` indicates whether this expression is a top-level expression in a condition.
fn expr(&mut self, e: &'ast ast::Expression, top_level: bool) -> Result<'ast, term::T> {
match e {
&ast::Expression::Binary(ref b) => self.bin_expr(b),
&ast::Expression::Unary(ref u) => self.un_expr(u),
&ast::Expression::Paren(ref i, _) => self.expr(i, top_level),
&ast::Expression::Identifier(ref i) => self.ident(i),
&ast::Expression::Literal(ref i) => self.literal(i),
&ast::Expression::Access(ref c) => {
ast::Expression::Binary(ref b) => self.bin_expr(b),
ast::Expression::Unary(ref u) => self.un_expr(u),
ast::Expression::Paren(ref i, _) => self.expr(i, top_level),
ast::Expression::Identifier(ref i) => self.ident(i),
ast::Expression::Literal(ref i) => self.literal(i),
ast::Expression::Access(ref c) => {
let arr = self.ident(&c.arr)?;
c.idxs.iter().try_fold(arr, |arr, idx| {
let idx_v = self.expr(idx, false)?;
term::array_idx(&arr, &idx_v).map_err(|err| Error::new(err, idx.span().clone()))
})
}
&ast::Expression::Call(ref c) => {
ast::Expression::Call(ref c) => {
let args = c
.args
.iter()
@@ -202,8 +202,7 @@ impl<'ast> Gen<'ast> {
.args
.iter()
.enumerate()
.filter(|&(_, arg)| arg.dec.is_some())
.next()
.find(|&(_, arg)| arg.dec.is_some())
{
let ir = &args[i].ir;
let reduced_ir = fold(ir);
@@ -225,7 +224,7 @@ impl<'ast> Gen<'ast> {
)
.unwrap();
}
let r = self.rule_cases(&rule)?;
let r = self.rule_cases(rule)?;
self.exit_function(name);
Ok(r)
} else {
@@ -238,22 +237,22 @@ impl<'ast> Gen<'ast> {
}
fn literal(&mut self, e: &ast::Literal) -> Result<'ast, term::T> {
match e {
&ast::Literal::BinLiteral(ref b) => {
ast::Literal::BinLiteral(ref b) => {
let len = b.value.len() as u8 - 2;
let val = u64::from_str_radix(&b.value[2..], 2).unwrap();
Ok(term::uint_lit(val, len))
}
&ast::Literal::HexLiteral(ref b) => {
ast::Literal::HexLiteral(ref b) => {
let len = (b.value.len() as u8 - 2) * 4;
let val = u64::from_str_radix(&b.value[2..], 16).unwrap();
Ok(term::uint_lit(val, len))
}
&ast::Literal::DecimalLiteral(ref b) => {
let val = Integer::from_str(&b.value).unwrap();
ast::Literal::DecimalLiteral(ref b) => {
let val = Integer::from_str(b.value).unwrap();
Ok(term::pf_lit(val))
}
&ast::Literal::BooleanLiteral(ref b) => {
let val = bool::from_str(&b.value).unwrap();
ast::Literal::BooleanLiteral(ref b) => {
let val = bool::from_str(b.value).unwrap();
Ok(term::bool_lit(val))
}
}
@@ -310,7 +309,7 @@ impl<'ast> Gen<'ast> {
.enumerate()
.find(|(_, arg)| arg.dec.is_some())
{
self.enter_function(&rule.name.value, None);
self.enter_function(rule.name.value, None);
for d in &rule.args {
let (ty, public) = self.ty(&d.ty);
let vis = if public { PUBLIC_VIS } else { PROVER_VIS };
@@ -330,7 +329,7 @@ impl<'ast> Gen<'ast> {
let mut bad_recursion = Vec::new();
for atom in &cond.exprs {
if let ast::Expression::Call(c) = &atom {
if &c.fn_name.value == &rule.name.value {
if c.fn_name.value == rule.name.value {
let formal_arg = self
.circ
.get_value(Loc::local(rule.args[arg_idx].ident.value.to_owned()))?
@@ -362,7 +361,7 @@ impl<'ast> Gen<'ast> {
)?);
self.circ.exit_scope();
}
self.exit_function(&rule.name.value);
self.exit_function(rule.name.value);
bug_in_rule_if_any
.into_iter()
.try_fold(term::bool_lit(false), |x, y| {

View File

@@ -229,7 +229,7 @@ pub mod ast {
match self {
Expression::Binary(b) => &b.span,
Expression::Identifier(i) => &i.span,
Expression::Literal(c) => &c.span(),
Expression::Literal(c) => c.span(),
Expression::Unary(u) => &u.span,
Expression::Call(u) => &u.span,
Expression::Access(u) => &u.span,
@@ -530,7 +530,7 @@ pub mod ast {
}
pub fn parse(file_string: &str) -> Result<ast::Program, Error<Rule>> {
let mut pest_pairs = MyParser::parse(Rule::program, &file_string)?;
let mut pest_pairs = MyParser::parse(Rule::program, file_string)?;
use from_pest::FromPest;
Ok(ast::Program::from_pest(&mut pest_pairs).expect("bug in AST construction"))
}

View File

@@ -25,7 +25,7 @@ impl T {
/// Create a new term, checking that the explicit type and IR type agree.
pub fn new(ir: Term, ty: Ty) -> Self {
let ir_ty = check(&ir);
let res = Self { ir, ty: ty.clone() };
let res = Self { ir, ty };
Self::check_ty(&ir_ty, &res.ty);
res
}
@@ -43,7 +43,7 @@ impl T {
#[track_caller]
pub fn as_bool(&self) -> Term {
match &self.ty {
&Ty::Bool => self.ir.clone(),
Ty::Bool => self.ir.clone(),
_ => panic!("{} is not a bool", self),
}
}
@@ -395,7 +395,7 @@ impl Embeddable for Datalog {
&*inner_ty,
idx_name(&raw_name, i),
user_name.as_ref().map(|u| idx_name(u, i)),
visibility.clone(),
visibility,
)
})
.enumerate()
@@ -409,10 +409,7 @@ impl Embeddable for Datalog {
}
fn ite(&self, _ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
if t.ty == f.ty {
T::new(
term![Op::Ite; cond.clone(), t.ir.clone(), f.ir.clone()],
t.ty.clone(),
)
T::new(term![Op::Ite; cond, t.ir, f.ir], t.ty)
} else {
panic!("Cannot ITE {} and {}", t, f)
}
@@ -444,6 +441,12 @@ impl Embeddable for Datalog {
}
}
impl Default for Datalog {
fn default() -> Self {
Self::new()
}
}
impl Datalog {
/// Initialize the Datalog lang def
pub fn new() -> Self {

View File

@@ -18,10 +18,10 @@ pub enum Ty {
impl Display for Ty {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match &self {
&Ty::Bool => write!(f, "bool"),
&Ty::Field => write!(f, "field"),
&Ty::Uint(w) => write!(f, "u{}", w),
&Ty::Array(l, t) => write!(f, "{}[{}]", t, l),
Ty::Bool => write!(f, "bool"),
Ty::Field => write!(f, "field"),
Ty::Uint(w) => write!(f, "u{}", w),
Ty::Array(l, t) => write!(f, "{}[{}]", t, l),
}
}
}

View File

@@ -1,6 +1,7 @@
//! Input language front-ends
pub mod datalog;
#[allow(clippy::all)]
pub mod zokrates;
use super::ir::term::Computation;

View File

@@ -42,7 +42,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
stack.extend(t.cs.iter().map(|c| (c.clone(), false)));
continue;
}
let c_get = |x: &Term| -> Term { cache.get(&x).expect("postorder cache").clone() };
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
let get = |i: usize| c_get(&t.cs[i]);
let new_t_opt = match &t.op {
&NOT => get(0).as_bool_opt().and_then(|c| cbool(!c)),
@@ -59,7 +59,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
Some(bv) => cbool(bv.bit(*i)),
_ => None,
},
Op::BoolNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
Op::BoolNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
Op::Eq => {
let c0 = get(0);
let c1 = get(1);
@@ -119,7 +119,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
_ => None,
}
}
Op::BvNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
Op::BvNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
Op::BvBinPred(p) => {
if let (Some(a), Some(b)) = (get(0).as_bv_opt(), get(1).as_bv_opt()) {
Some(leaf_term(Op::Const(Value::Bool(match p {
@@ -168,7 +168,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
},
}
}
Op::PfNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
Op::PfNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
Op::PfUnOp(o) => get(0).as_pf_opt().map(|pf| {
leaf_term(Op::Const(Value::Field(match o {
PfUnOp::Recip => pf.clone().recip(),
@@ -177,17 +177,17 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
}),
_ => None,
};
let c_get = |x: &Term| -> Term { cache.get(&x).expect("postorder cache").clone() };
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
let new_t = new_t_opt
.unwrap_or_else(|| term(t.op.clone(), t.cs.iter().map(|c| c_get(c)).collect()));
cache.insert(t, new_t);
}
cache.get(&node).expect("postorder cache").clone()
cache.get(node).expect("postorder cache").clone()
}
fn neg_bool(t: Term) -> Term {
match &t.op {
&NOT => t.cs[0].clone(),
match t.op {
NOT => t.cs[0].clone(),
_ => term![NOT; t],
}
}
@@ -220,7 +220,7 @@ impl NaryFlat<bool> for BoolNaryOp {
BoolNaryOp::Or => {
if consts.iter().any(|b| *b) {
leaf_term(Op::Const(Value::Bool(true)))
} else if children.len() == 0 {
} else if children.is_empty() {
leaf_term(Op::Const(Value::Bool(false)))
} else {
safe_nary(OR, children)
@@ -229,7 +229,7 @@ impl NaryFlat<bool> for BoolNaryOp {
BoolNaryOp::And => {
if consts.iter().any(|b| !*b) {
leaf_term(Op::Const(Value::Bool(false)))
} else if children.len() == 0 {
} else if children.is_empty() {
leaf_term(Op::Const(Value::Bool(true)))
} else {
safe_nary(AND, children)
@@ -237,7 +237,7 @@ impl NaryFlat<bool> for BoolNaryOp {
}
BoolNaryOp::Xor => {
let odd_trues = consts.into_iter().filter(|b| *b).count() % 2 == 1;
if children.len() == 0 {
if children.is_empty() {
leaf_term(Op::Const(Value::Bool(odd_trues)))
} else {
let t = safe_nary(XOR, children);
@@ -264,7 +264,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
BvNaryOp::Or => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::BitOr::bitor);
if children.len() == 0 {
if children.is_empty() {
leaf_term(Op::Const(Value::BitVector(c)))
} else if c.uint() == &Integer::from(0) {
safe_nary(BV_OR, children)
@@ -297,7 +297,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
BvNaryOp::And => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::BitAnd::bitand);
if children.len() == 0 {
if children.is_empty() {
leaf_term(Op::Const(Value::BitVector(c)))
} else {
safe_nary(
@@ -328,7 +328,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
BvNaryOp::Xor => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::BitXor::bitxor);
if children.len() == 0 {
if children.is_empty() {
leaf_term(Op::Const(Value::BitVector(c)))
} else {
safe_nary(
@@ -360,7 +360,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
BvNaryOp::Add => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::Add::add);
if c.uint() != &Integer::from(0) || children.len() == 0 {
if c.uint() != &Integer::from(0) || children.is_empty() {
children.push(leaf_term(Op::Const(Value::BitVector(c))));
}
}
@@ -372,7 +372,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
if c.uint() == &Integer::from(0) {
leaf_term(Op::Const(Value::BitVector(c)))
} else {
if c.uint() != &Integer::from(1) || children.len() == 0 {
if c.uint() != &Integer::from(1) || children.is_empty() {
children.push(leaf_term(Op::Const(Value::BitVector(c))));
}
safe_nary(BV_MUL, children)
@@ -397,7 +397,7 @@ impl NaryFlat<FieldElem> for PfNaryOp {
PfNaryOp::Add => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::Add::add);
if c.i() != &Integer::from(0) || children.len() == 0 {
if c.i() != &Integer::from(0) || children.is_empty() {
children.push(leaf_term(Op::Const(Value::Field(c))));
}
}
@@ -406,7 +406,7 @@ impl NaryFlat<FieldElem> for PfNaryOp {
PfNaryOp::Mul => {
if let Some(c) = consts.pop() {
let c = consts.into_iter().fold(c, std::ops::Mul::mul);
if c.i() == &Integer::from(0) || children.len() == 0 {
if c.i() == &Integer::from(0) || children.is_empty() {
leaf_term(Op::Const(Value::Field(c)))
} else {
if c.i() != &Integer::from(1) {

View File

@@ -23,7 +23,7 @@ impl<T: Clone> PersistentConcatList<T> {
match &*t {
PersistentConcatList::Leaf(t) => v.push((**t).clone()),
PersistentConcatList::Concat(ts) => {
ts.into_iter().for_each(|c| stack.push(c.clone()));
ts.iter().for_each(|c| stack.push(c.clone()));
}
}
}
@@ -32,7 +32,7 @@ impl<T: Clone> PersistentConcatList<T> {
}
impl Entry {
fn to_term(&mut self) -> Term {
fn as_term(&mut self) -> Term {
match self {
Entry::Term(t) => (**t).clone(),
Entry::NaryTerm(o, ts, maybe_term) => {
@@ -47,6 +47,7 @@ impl Entry {
}
/// Flattening cache.
#[derive(Default)]
pub struct Cache(TermMap<Entry>);
impl Cache {
@@ -99,7 +100,7 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache
}
e => {
children
.push(Rc::new(PersistentConcatList::Leaf(Rc::new(e.to_term()))));
.push(Rc::new(PersistentConcatList::Leaf(Rc::new(e.as_term()))));
}
}
}
@@ -112,13 +113,13 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache
_ => Entry::Term(Rc::new(term(
t.op.clone(),
t.cs.iter()
.map(|c| rewritten.get_mut(c).unwrap().to_term())
.map(|c| rewritten.get_mut(c).unwrap().as_term())
.collect(),
))),
};
rewritten.insert(t, entry);
}
rewritten.get_mut(&term_).unwrap().to_term()
rewritten.get_mut(&term_).unwrap().as_term()
}
#[cfg(test)]

View File

@@ -63,7 +63,7 @@ impl<'a> Inliner<'a> {
}
}
assert!(
self.stale_vars.contains(&key),
self.stale_vars.contains(key),
"Variable {}, which is being susbstituted",
key
);
@@ -107,7 +107,7 @@ impl<'a> Inliner<'a> {
///
/// Will not return `v` which are protected.
fn as_fresh_def(&self, t: &Term) -> Option<(Term, Term)> {
if &EQ == &t.op {
if EQ == t.op {
if let Op::Var(name, _) = &t.cs[0].op {
if !self.stale_vars.contains(&t.cs[0])
&& !self.protected.contains(name)
@@ -133,7 +133,7 @@ impl<'a> Inliner<'a> {
///
/// If `t` is not a substitution, then its (substituted variant) is returned.
fn ingest_term(&mut self, t: &Term) -> Option<Term> {
if let Some((var, val)) = self.as_fresh_def(&t) {
if let Some((var, val)) = self.as_fresh_def(t) {
//debug!(target: "circ::ir::opt::inline", "Inline: {} -> {}", var, val.clone());
// Rewrite the substitution
let subst_val = self.apply(&val);

View File

@@ -223,7 +223,7 @@ impl RewritePass for Replacer {
}
}
Op::Store => {
if self.should_replace(&orig) {
if self.should_replace(orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 3);
let k_const = get_const(&cs.remove(1));
@@ -233,7 +233,7 @@ impl RewritePass for Replacer {
}
}
Op::Ite => {
if self.should_replace(&orig) {
if self.should_replace(orig) {
Some(term(Op::Ite, get_cs()))
} else {
None

View File

@@ -64,7 +64,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
let mut new_outputs = Vec::new();
for a in std::mem::take(&mut cs.outputs) {
assert_eq!(check(&a), Sort::Bool, "Non-bool in {:?}", i);
if &a.op == &Op::BoolNaryOp(BoolNaryOp::And) {
if a.op == Op::BoolNaryOp(BoolNaryOp::And) {
new_outputs.extend(a.cs.iter().cloned());
} else {
new_outputs.push(a)

View File

@@ -95,7 +95,7 @@ impl RewritePass for Pass {
party_visibility,
&mut new_var_reqs,
);
if new_var_reqs.len() > 0 {
if !new_var_reqs.is_empty() {
computation.replace_input(orig.clone(), new_var_reqs);
}
Some(new)

View File

@@ -18,8 +18,8 @@ pub fn sha_rewrites(term_: &Term) -> Term {
if t.cs.len() == 2 {
let a = get(0);
let b = get(1);
if &a.op == &b.op
&& &a.op == &BV_AND
if a.op == b.op
&& a.op == BV_AND
&& b.cs[0].op == BV_NOT
&& b.cs[0].cs[0] == a.cs[0]
{
@@ -46,9 +46,9 @@ pub fn sha_rewrites(term_: &Term) -> Term {
let c0 = get(0);
let c1 = get(1);
let c2 = get(2);
if &c0.op == &c1.op
&& &c1.op == &c2.op
&& &c2.op == &BV_AND
if c0.op == c1.op
&& c1.op == c2.op
&& c2.op == BV_AND
&& c0.cs.len() == 2
&& c1.cs.len() == 2
&& c2.cs.len() == 2
@@ -62,7 +62,7 @@ pub fn sha_rewrites(term_: &Term) -> Term {
{
debug!("SHA MAJ");
let items = s0.union(&s1).collect::<Vec<_>>();
let w = check(&c0).as_bv();
let w = check(c0).as_bv();
Some(term(
BV_CONCAT,
(0..w)
@@ -95,7 +95,7 @@ pub fn sha_rewrites(term_: &Term) -> Term {
});
cache.insert(t, new_t);
}
cache.get(&term_).unwrap().clone()
cache.get(term_).unwrap().clone()
}
/// Eliminate the SHA majority operator, replacing it with ands and ors.
@@ -105,9 +105,9 @@ pub fn sha_maj_elim(term_: &Term) -> Term {
for t in PostOrderIter::new(term_.clone()) {
let c_get = |x: &Term| cache.get(x).unwrap();
let get = |i: usize| c_get(&t.cs[i]);
let new_t = match &t.op {
let new_t = match t.op {
// maj(a, b, c) = (a & b) | (b & c) | (c & a)
&Op::BoolMaj => {
Op::BoolMaj => {
let a = get(0);
let b = get(1);
let c = get(2);
@@ -126,7 +126,7 @@ pub fn sha_maj_elim(term_: &Term) -> Term {
});
cache.insert(t, new_t);
}
cache.get(&term_).unwrap().clone()
cache.get(term_).unwrap().clone()
}
#[cfg(test)]

View File

@@ -68,7 +68,7 @@ impl TupleTree {
fn flatten(&self) -> impl Iterator<Item = Term> {
let mut out = Vec::new();
fn rec_unroll_into(t: &Term, out: &mut Vec<Term>) {
if &t.op == &Op::Tuple {
if t.op == Op::Tuple {
for c in &t.cs {
rec_unroll_into(c, out);
}
@@ -81,7 +81,7 @@ impl TupleTree {
}
fn structure(&self, flattened: impl IntoIterator<Item = Term>) -> Self {
fn term_structure(t: &Term, iter: &mut impl Iterator<Item = Term>) -> Term {
if &t.op == &Op::Tuple {
if t.op == Op::Tuple {
term(
Op::Tuple,
t.cs.iter().map(|c| term_structure(c, iter)).collect(),
@@ -94,9 +94,9 @@ impl TupleTree {
}
fn well_formed(&self) -> bool {
for t in PostOrderIter::new(self.0.clone()) {
if &t.op != &Op::Tuple {
if t.op != Op::Tuple {
for c in &t.cs {
if &c.op == &Op::Tuple {
if c.op == Op::Tuple {
return false;
}
}
@@ -267,11 +267,9 @@ impl RewritePass for TupleLifter {
}
}
#[allow(dead_code)]
fn tuple_free(t: Term) -> bool {
!PostOrderIter::new(t).any(|c| match check(&c) {
Sort::Tuple(_) => true,
_ => false,
})
PostOrderIter::new(t).all(|c| !matches!(check(&c), Sort::Tuple(..)))
}
/// Run the tuple elimination pass.

View File

@@ -45,11 +45,8 @@ impl Constraints for Computation {
let mut set = FxHashSet::default();
for a in &assertions {
for t in PostOrderIter::new(a.clone()) {
match &t.op {
Op::Var(_, _) => {
set.insert(t.clone());
}
_ => {}
if let Op::Var(..) = t.op {
set.insert(t.clone());
}
}
}

View File

@@ -122,8 +122,8 @@ impl FixedSizeDist {
Op::Var(self.sample_ident(&format!("bv{}", w), rng), sort.clone()),
Op::BvUnOp(BvUnOp::Neg),
Op::BvUnOp(BvUnOp::Not),
Op::BvUext(rng.gen_range(0..w.clone())),
Op::BvSext(rng.gen_range(0..w.clone())),
Op::BvUext(rng.gen_range(0..*w)),
Op::BvSext(rng.gen_range(0..*w)),
Op::BvBinOp(BvBinOp::Sub),
Op::BvBinOp(BvBinOp::Udiv),
Op::BvBinOp(BvBinOp::Urem),
@@ -292,7 +292,7 @@ impl rand::distributions::Distribution<Term> for FixedSizeDist {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
let op = self.sample_op(&self.sort, rng);
let sorts = self.sample_child_sorts(&self.sort, &op, rng);
if sorts.len() == 0 {
if sorts.is_empty() {
leaf_term(op)
} else {
let mut dists: Vec<FixedSizeDist> = sorts

View File

@@ -1,17 +1,16 @@
//! Extra algorithms over terms (e.g. substitutions)
use super::*;
use std::cmp::Ordering;
use std::fmt::{self, Display, Formatter};
/// Convert `t` to width `w`, though unsigned extension or extraction
pub fn to_width(t: &Term, w: usize) -> Term {
let old_w = check(t).as_bv();
if old_w < w {
term(Op::BvUext(w - old_w), vec![t.clone()])
} else if old_w == w {
t.clone()
} else {
term(Op::BvExtract(w - 1, 0), vec![t.clone()])
match old_w.cmp(&w) {
Ordering::Less => term(Op::BvUext(w - old_w), vec![t.clone()]),
Ordering::Equal => t.clone(),
Ordering::Greater => term(Op::BvExtract(w - 1, 0), vec![t.clone()]),
}
}
@@ -45,7 +44,7 @@ impl Display for Letified {
writeln!(f, "(let (")?;
for t in PostOrderIter::new(self.0.clone()) {
if parent_counts.get(&t).unwrap_or(&0) > &1 && t.cs.len() > 0 {
if parent_counts.get(&t).unwrap_or(&0) > &1 && !t.cs.is_empty() {
let name = format!("let_{}", let_ct);
let_ct += 1;
let sort = check(&t);
@@ -131,7 +130,7 @@ pub fn free_in(v: &str, t: Term) -> bool {
_ => {}
}
}
return false;
false
}
/// If this term is a constant field or bit-vector, get the unsigned int value.

View File

@@ -37,7 +37,7 @@ impl FieldElem {
location
);
debug_assert!(
&self.i <= &*self.modulus,
self.i <= *self.modulus,
"Too small field elem: {}\n at {}",
self,
location

View File

@@ -708,11 +708,17 @@ impl Display for Array {
}
impl std::cmp::Eq for Value {}
// We walk in danger here, intentionally. One day we may fix it.
// FP is the heart of the problem.
#[allow(clippy::derive_ord_xor_partial_ord)]
impl std::cmp::Ord for Value {
fn cmp(&self, o: &Self) -> std::cmp::Ordering {
self.partial_cmp(o).expect("broken Value cmp")
}
}
// We walk in danger here, intentionally. One day we may fix it.
// FP is the heart of the problem.
#[allow(clippy::derive_hash_xor_eq)]
impl std::hash::Hash for Value {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
@@ -778,7 +784,7 @@ impl Sort {
/// Unwrap the constituent sorts of this tuple, panicking otherwise.
pub fn as_tuple(&self) -> &Vec<Sort> {
if let Sort::Tuple(w) = self {
&w
w
} else {
panic!("{} is not a tuple", self)
}
@@ -815,7 +821,7 @@ impl Sort {
Box::new(
std::iter::successors(Some(Integer::from(0)), move |p| {
let q = p.clone() + 1;
if &q < &*m {
if q < *m {
Some(q)
} else {
None
@@ -927,21 +933,18 @@ impl TermTable {
if val.elm.upgrade().is_some() {
true
} else {
to_check.extend(key.cs.iter().map(|i| i.clone()));
to_check.extend(key.cs.iter().cloned());
false
}
});
while let Some(t) = to_check.pop() {
let data: TermData = (*t).clone();
std::mem::drop(t);
match self.map.entry(data) {
std::collections::hash_map::Entry::Occupied(e) => {
if e.get().elm.upgrade().is_none() {
let (key, _val) = e.remove_entry();
to_check.extend(key.cs.iter().map(|i| i.clone()));
}
if let std::collections::hash_map::Entry::Occupied(e) = self.map.entry(data) {
if e.get().elm.upgrade().is_none() {
let (key, _val) = e.remove_entry();
to_check.extend(key.cs.iter().cloned());
}
_ => {}
}
}
let new_size = self.map.len();
@@ -1009,19 +1012,11 @@ impl TermData {
}
/// Is this a variable?
pub fn is_var(&self) -> bool {
if let Op::Var(..) = &self.op {
true
} else {
false
}
matches!(&self.op, Op::Var(..))
}
/// Is this a value
pub fn is_const(&self) -> bool {
if let Op::Const(..) = &self.op {
true
} else {
false
}
matches!(&self.op, Op::Const(..))
}
}
@@ -1085,7 +1080,7 @@ impl Value {
/// Unwrap the constituent value of this array, panicking otherwise.
pub fn as_array(&self) -> &Array {
if let Value::Array(w) = self {
&w
w
} else {
panic!("{} is not an aray", self)
}
@@ -1377,7 +1372,7 @@ impl std::iter::Iterator for PostOrderIter {
type Item = Term;
fn next(&mut self) -> Option<Term> {
while let Some((children_pushed, t)) = self.stack.last() {
if self.visited.contains(&t) {
if self.visited.contains(t) {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
@@ -1428,7 +1423,7 @@ impl ComputationMetadata {
party,
self.input_vis.get(&input_name).unwrap()
);
self.input_vis.insert(input_name.clone(), party);
self.input_vis.insert(input_name, party);
self.inputs.push(term);
}
/// Replace the `original` computation input with `new`, in the order given.
@@ -1473,10 +1468,10 @@ impl ComputationMetadata {
}
/// Returns None if the value is public. Otherwise, the unique party that knows it.
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
self.input_vis
*self
.input_vis
.get(input_name)
.unwrap_or_else(|| panic!("Missing input {} in inputs{:#?}", input_name, self.inputs))
.clone()
}
/// Is this input public?
pub fn is_input(&self, input_name: &str) -> bool {
@@ -1497,6 +1492,9 @@ impl ComputationMetadata {
})
}
/// Get all public inputs.
// I think the lint is just broken here.
// TODO: submit a patch
#[allow(clippy::needless_lifetimes)]
pub fn public_inputs<'a>(&'a self) -> impl Iterator<Item = Term> + 'a {
self.inputs.iter().filter_map(move |input| {
if let Op::Var(name, _) = &input.op {
@@ -1513,7 +1511,7 @@ impl ComputationMetadata {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
/// An IR computation.
pub struct Computation {
/// The outputs of the computation.
@@ -1526,16 +1524,6 @@ pub struct Computation {
pub metadata: ComputationMetadata,
}
impl std::default::Default for Computation {
fn default() -> Self {
Self {
outputs: Vec::new(),
metadata: ComputationMetadata::default(),
values: None,
}
}
}
impl Computation {
/// Create a new variable, `name: s`, where `val_fn` can be called to get the concrete value,
/// and `public` indicates whether this variable is public in the constraint system.

View File

@@ -41,7 +41,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
Op::BvConcat => t
.cs
.iter()
.map(|c| check_raw(c))
.map(check_raw)
.try_fold(
Ok(0),
|l: Result<usize, TypeErrorReason>,
@@ -88,9 +88,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
Op::Select => array_or(&check_raw(&t.cs[0])?, "select").map(|(_, v)| v.clone()),
Op::Store => Ok(check_raw(&t.cs[0])?),
Op::Tuple => Ok(Sort::Tuple(
t.cs.iter()
.map(|c| check_raw(c))
.collect::<Result<Vec<_>, _>>()?,
t.cs.iter().map(check_raw).collect::<Result<Vec<_>, _>>()?,
)),
Op::Field(i) => {
let sort = check_raw(&t.cs[0])?;
@@ -127,20 +125,17 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while to_check.len() > 0 {
while !to_check.is_empty() {
let back = to_check.last_mut().unwrap();
let weak = back.0.to_weak();
// The idea here is to check that
match term_tys.get_key_value(&weak) {
Some((p, _)) => {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
term_tys.remove(&weak);
}
if let Some((p, _)) = term_tys.get_key_value(&weak) {
if p.to_hconsed().is_some() {
to_check.pop();
continue;
} else {
term_tys.remove(&weak);
}
None => {}
}
if !back.1 {
back.1 = true;
@@ -173,7 +168,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
}
(Op::BvNaryOp(_), a) => {
let ctx = "bv nary op";
all_eq_or(a.into_iter().cloned(), ctx)
all_eq_or(a.iter().cloned(), ctx)
.and_then(|t| bv_or(t, ctx))
.map(|a| a.clone())
}
@@ -207,7 +202,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
}
(Op::BoolNaryOp(_), a) => {
let ctx = "bool nary op";
all_eq_or(a.into_iter().cloned(), ctx)
all_eq_or(a.iter().cloned(), ctx)
.and_then(|t| bool_or(t, ctx))
.map(|a| a.clone())
}
@@ -252,7 +247,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
(Op::FpToFp(32), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F32),
(Op::PfNaryOp(_), a) => {
let ctx = "pf nary op";
all_eq_or(a.into_iter().cloned(), ctx)
all_eq_or(a.iter().cloned(), ctx)
.and_then(|t| pf_or(t, ctx))
.map(|a| a.clone())
}
@@ -264,9 +259,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
(Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store")
.and_then(|_| eq_or(v, b, "store"))
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
(Op::Tuple, a) => {
Ok(Sort::Tuple(a.into_iter().map(|a| (*a).clone()).collect()))
}
(Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())),
(Op::Field(i), &[a]) => tuple_or(a, "tuple field access").and_then(|t| {
if i < &t.len() {
Ok(t[*i].clone())
@@ -288,7 +281,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
)))
}
}),
(_, _) => Err(TypeErrorReason::Custom(format!("other"))),
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
})
.map_err(|reason| TypeError {
op: back.0.op.clone(),
@@ -357,7 +350,7 @@ fn array_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<(&'a Sort, &'a Sort),
}
fn bool_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
if let &Sort::Bool = a {
if let Sort::Bool = a {
Ok(a)
} else {
Err(TypeErrorReason::ExpectedBool(a.clone(), ctx))

View File

@@ -3,6 +3,7 @@
//! A compiler infrastructure for compiling programs to circuits
#![warn(missing_docs)]
#![deny(warnings)]
#[macro_use]
pub mod ir;

View File

@@ -109,7 +109,7 @@ impl CostModel {
};
for (op_name, json) in obj {
// HACK: assumes the presence of 2 partitions names into conversion and otherwise.
if !op_name.contains("2") {
if !op_name.contains('2') {
for op in ops_from_name(op_name) {
let obj = json.as_object().unwrap();
for (share_type, share_name) in
@@ -117,7 +117,7 @@ impl CostModel {
{
if let Some(cost) = get_cost_opt(share_name, obj) {
ops.entry(op.clone())
.or_insert_with(|| FxHashMap::default())
.or_insert_with(FxHashMap::default)
.insert(*share_type, cost);
}
}
@@ -160,29 +160,27 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
// build variables for all term assignments
for (t, i) in terms.iter() {
let mut vars = vec![];
if let Op::Var(_, _) = &t.op {
for ty in &SHARE_TYPES {
let name = format!("t_{}_{}", i, ty.char());
let v = ilp.new_variable(variable().binary(), name.clone());
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
vars.push(v);
match &t.op {
Op::Var(..) | Op::Const(_) => {
for ty in &SHARE_TYPES {
let name = format!("t_{}_{}", i, ty.char());
let v = ilp.new_variable(variable().binary(), name.clone());
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
vars.push(v);
}
}
} else if let Op::Const(_) = &t.op {
for ty in &SHARE_TYPES {
let name = format!("t_{}_{}", i, ty.char());
let v = ilp.new_variable(variable().binary(), name.clone());
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
vars.push(v);
_ => {
if let Some(costs) = costs.ops.get(&t.op) {
for (ty, cost) in costs {
let name = format!("t_{}_{}", i, ty.char());
let v = ilp.new_variable(variable().binary(), name.clone());
term_vars.insert((t.clone(), *ty), (v, *cost, name));
vars.push(v);
}
} else {
panic!("No cost for op {}", &t.op)
}
}
} else if let Some(costs) = costs.ops.get(&t.op) {
for (ty, cost) in costs {
let name = format!("t_{}_{}", i, ty.char());
let v = ilp.new_variable(variable().binary(), name.clone());
term_vars.insert((t.clone(), *ty), (v, *cost, name));
vars.push(v);
}
} else {
panic!("No cost for op {}", &t.op)
}
// Sum of assignments is at least 1.
ilp.new_constraint(
@@ -218,7 +216,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
let def_uses: FxHashMap<Term, Vec<Term>> = {
let mut t = FxHashMap::default();
for (d, u) in def_uses {
t.entry(d).or_insert_with(|| Vec::new()).push(u);
t.entry(d).or_insert_with(Vec::new).push(u);
}
t
};
@@ -232,7 +230,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
// c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1
term_vars
.get(&(use_.clone(), *to_ty))
.map(|t_to| ilp.new_constraint(c.0 >> t_from.0 + t_to.0 - 1.0))
.map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)))
})
});
}
@@ -245,9 +243,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
.values()
.map(|(a, b)| (a, b))
.chain(term_vars.values().map(|(a, b, _)| (a, b)))
.fold(0.0.into(), |acc: Expression, (v, cost)| {
acc + v.clone() * *cost
}),
.fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost),
);
let (_opt, solution) = ilp.default_solve().unwrap();

View File

@@ -21,9 +21,9 @@ pub const SHARE_TYPES: [ShareType; 3] = [ShareType::Arithmetic, ShareType::Boole
impl ShareType {
fn char(&self) -> char {
match self {
&ShareType::Arithmetic => 'a',
&ShareType::Yao => 'y',
&ShareType::Boolean => 'b',
ShareType::Arithmetic => 'a',
ShareType::Yao => 'y',
ShareType::Boolean => 'b',
}
}
}
@@ -36,7 +36,7 @@ pub fn all_boolean_sharing(c: &Computation) -> SharingMap {
c.outputs
.iter()
.flat_map(|output| {
PostOrderIter::new(output.clone()).map(|term| (term.clone(), ShareType::Boolean))
PostOrderIter::new(output.clone()).map(|term| (term, ShareType::Boolean))
})
.collect()
}

View File

@@ -15,6 +15,12 @@ pub struct ABY {
output: Vec<String>,
}
impl Default for ABY {
fn default() -> Self {
Self::new()
}
}
impl ABY {
/// Initialize ABY circuit
pub fn new() -> Self {

View File

@@ -19,16 +19,16 @@ fn get_filename(path_buf: PathBuf) -> String {
/// In ABY examples, remove the existing directory and create a directory
/// in order to write the new test case
fn create_dir_in_aby(filename: &String) {
let path = format!("third_party/ABY/src/examples/{}", *filename);
let _ = fs::remove_dir_all(path.clone());
fs::create_dir_all(format!("{}/common", path.clone())).expect("Failed to create directory");
fn create_dir_in_aby(filename: &str) {
let path = format!("third_party/ABY/src/examples/{}", filename);
let _ = fs::remove_dir_all(&path);
fs::create_dir_all(format!("{}/common", path)).expect("Failed to create directory");
}
/// Update the CMake file in ABY
fn update_cmake_file(filename: &String) {
fn update_cmake_file(filename: &str) {
let cmake_filename = "third_party/ABY/src/examples/CMakeLists.txt";
let file = File::open(cmake_filename.clone()).expect("Failed to open cmake file");
let file = File::open(&cmake_filename).expect("Failed to open cmake file");
let reader = BufReader::new(file);
let mut flag = false;
@@ -46,70 +46,68 @@ fn update_cmake_file(filename: &String) {
.open(cmake_filename)
.unwrap();
writeln!(file, "{}", format!("add_subdirectory({})", *filename))
writeln!(file, "{}", format!("add_subdirectory({})", filename))
.expect("Failed to write to cmake file");
}
}
/// Create a CMake file for the corresponding filename (testcase)
/// in the ABY examples directory
fn write_test_cmake_file(filename: &String) {
let path = format!("third_party/ABY/src/examples/{}/CMakeLists.txt", *filename);
fn write_test_cmake_file(filename: &str) {
let path = format!("third_party/ABY/src/examples/{}/CMakeLists.txt", filename);
fs::write(
path.clone(),
&path,
format!(
concat!(
"add_executable({}_test {}_test.cpp common/{}.cpp)\n",
"target_link_libraries({}_test ABY::aby ENCRYPTO_utils::encrypto_utils)"
),
*filename, *filename, *filename, *filename
filename, filename, filename, filename
),
)
.expect("Failed to write to cmake file");
}
/// Write the testcase in the ABY examples directory
fn write_test_file(filename: &String) {
fn write_test_file(filename: &str) {
let template = fs::read_to_string("third_party/ABY_templates/test_template.txt")
.expect("Unable to read file");
let path = format!(
"third_party/ABY/src/examples/{}/{}_test.cpp",
*filename, *filename
filename, filename
);
fs::write(path.clone(), template.replace("{fn}", &*filename))
.expect("Failed to write to test file");
fs::write(&path, template.replace("{fn}", filename)).expect("Failed to write to test file");
}
/// Using the h_template.txt, write the .h file for the new test case
fn write_h_file(filename: &String) {
fn write_h_file(filename: &str) {
let template = fs::read_to_string("third_party/ABY_templates/h_template.txt")
.expect("Unable to read file");
let path = format!(
"third_party/ABY/src/examples/{}/common/{}.h",
*filename, *filename
filename, filename
);
fs::write(path.clone(), template.replace("{fn}", &*filename))
.expect("Failed to write to h file");
fs::write(&path, template.replace("{fn}", &*filename)).expect("Failed to write to h file");
}
/// Using the cpp_template.txt, write the .cpp file for the new test case
fn write_circ_file(filename: &String, circ: String, output: &String) {
fn write_circ_file(filename: &str, circ: &str, output: &str) {
let template = fs::read_to_string("third_party/ABY_templates/cpp_template.txt")
.expect("Unable to read file");
let path = format!(
"third_party/ABY/src/examples/{}/common/{}.cpp",
*filename, *filename
filename, filename
);
fs::write(
path.clone(),
&path,
template
.replace("{fn}", &*filename)
.replace("{circ}", &circ)
.replace("{output}", &output),
.replace("{fn}", filename)
.replace("{circ}", circ)
.replace("{output}", output),
)
.expect("Failed to write to cpp file");
}
@@ -123,5 +121,5 @@ pub fn write_aby_exec(aby: ABY, path_buf: PathBuf) {
write_test_file(&filename);
write_h_file(&filename);
let circ_str = aby.setup.join("\n\t") + &aby.circs.join("\n\t");
write_circ_file(&filename, circ_str, &aby.output.join("\n\t"));
write_circ_file(&filename, &circ_str, &aby.output.join("\n\t"));
}

View File

@@ -35,7 +35,7 @@ impl ToABY {
md: metadata,
inputs: TermMap::new(),
cache: TermMap::new(),
s_map: s_map,
s_map,
}
}
@@ -48,7 +48,7 @@ impl ToABY {
/// Parse variable name from IR representation of a variable
fn parse_var_name(&self, full_name: String) -> String {
let parsed: Vec<String> = full_name.split("_").map(str::to_string).collect();
let parsed: Vec<String> = full_name.split('_').map(str::to_string).collect();
if parsed.len() < 2 {
panic!("Invalid variable name: {}", full_name);
}
@@ -90,28 +90,26 @@ impl ToABY {
fn add_cons_gate(&self, t: Term) -> String {
let name = ToABY::get_var_name(t.clone());
let s_circ = self.get_sharetype_circ(t.clone());
let s_circ = self.get_sharetype_circ(t);
format!(
"s_{} = {}->PutCONSGate((uint64_t){}, bitlen);",
name, s_circ, name
)
.to_string()
}
fn add_in_gate(&self, t: Term, role: String) -> String {
let name = ToABY::get_var_name(t.clone());
let s_circ = self.get_sharetype_circ(t.clone());
let s_circ = self.get_sharetype_circ(t);
format!(
"\ts_{} = {}->PutINGate({}, bitlen, {});",
name, s_circ, name, role
)
.to_string()
}
fn add_dummy_gate(&self, t: Term) -> String {
let name = ToABY::get_var_name(t.clone());
let s_circ = self.get_sharetype_circ(t.clone());
format!("\ts_{} = {}->PutDummyINGate(bitlen);", name, s_circ).to_string()
let s_circ = self.get_sharetype_circ(t);
format!("\ts_{} = {}->PutDummyINGate(bitlen);", name, s_circ)
}
/// Initialize private and public inputs from each party
@@ -177,13 +175,13 @@ impl ToABY {
/// Return constant gate evaluating to 0
#[allow(dead_code)]
fn zero() -> String {
format!("bcirc->PutCONSGate((uint64_t)0, (uint32_t)1)")
fn zero() -> &'static str {
"bcirc->PutCONSGate((uint64_t)0, (uint32_t)1)"
}
/// Return constant gate evaluating to 1
fn one() -> String {
format!("bcirc->PutCONSGate((uint64_t)1, (uint32_t)1)")
fn one() -> &'static str {
"bcirc->PutCONSGate((uint64_t)1, (uint32_t)1)"
}
fn remove_cons_gate(&self, circ: String) -> String {
@@ -191,7 +189,7 @@ impl ToABY {
circ.split("PutCONSGate(")
.last()
.unwrap_or("")
.split(",")
.split(',')
.next()
.unwrap_or("")
.to_string()
@@ -204,8 +202,8 @@ impl ToABY {
let s_circ = self.get_sharetype_circ(t.clone());
match check(&a) {
Sort::Bool => {
let a_circ = self.get_bool(&a).clone();
let b_circ = self.get_bool(&b).clone();
let a_circ = self.get_bool(&a);
let b_circ = self.get_bool(&b);
let a_conv = self.add_conv_gate(t.clone(), a, a_circ);
let b_conv = self.add_conv_gate(t.clone(), b, b_circ);
@@ -218,12 +216,12 @@ impl ToABY {
b_conv,
ToABY::one()
);
self.cache.insert(t.clone(), EmbeddedTerm::Bool(s.clone()));
self.cache.insert(t, EmbeddedTerm::Bool(s.clone()));
s
}
Sort::BitVector(_) => {
let a_circ = self.get_bv(&a).clone();
let b_circ = self.get_bv(&b).clone();
let a_circ = self.get_bv(&a);
let b_circ = self.get_bv(&b);
let a_conv = self.add_conv_gate(t.clone(), a, a_circ);
let b_conv = self.add_conv_gate(t.clone(), b, b_circ);
@@ -232,7 +230,7 @@ impl ToABY {
"{}->PutXORGate({}->PutXORGate({}->PutGTGate({}, {}), {}->PutGTGate({}, {})), {})",
s_circ, s_circ, s_circ, a_conv, b_conv, s_circ, b_conv, a_conv, ToABY::one()
);
self.cache.insert(t.clone(), EmbeddedTerm::Bool(s.clone()));
self.cache.insert(t, EmbeddedTerm::Bool(s.clone()));
s
}
e => panic!("Unimplemented sort for Eq: {:?}", e),
@@ -278,9 +276,9 @@ impl ToABY {
let _s = self.embed_eq(t.clone(), t.cs[0].clone(), t.cs[1].clone());
}
Op::Ite => {
let sel_circ = self.get_bool(&t.cs[0]).clone();
let a_circ = self.get_bool(&t.cs[1]).clone();
let b_circ = self.get_bool(&t.cs[2]).clone();
let sel_circ = self.get_bool(&t.cs[0]);
let a_circ = self.get_bool(&t.cs[1]);
let b_circ = self.get_bool(&t.cs[2]);
let sel_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), sel_circ);
let a_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), a_circ);
@@ -306,8 +304,8 @@ impl ToABY {
);
}
Op::BoolNaryOp(o) => {
let a_circ = self.get_bool(&t.cs[0]).clone();
let b_circ = self.get_bool(&t.cs[1]).clone();
let a_circ = self.get_bool(&t.cs[0]);
let b_circ = self.get_bool(&t.cs[1]);
let a_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), a_circ);
let b_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), b_circ);
@@ -422,9 +420,9 @@ impl ToABY {
);
}
Op::Ite => {
let sel_circ = self.get_bool(&t.cs[0]).clone();
let a_circ = self.get_bv(&t.cs[1]).clone();
let b_circ = self.get_bv(&t.cs[2]).clone();
let sel_circ = self.get_bool(&t.cs[0]);
let a_circ = self.get_bv(&t.cs[1]);
let b_circ = self.get_bv(&t.cs[2]);
let sel_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), sel_circ);
let a_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), a_circ);
@@ -554,7 +552,7 @@ impl ToABY {
/// Given a term `t`, lower `t` to ABY Circuits
fn lower(&mut self, t: Term) {
let circ = self.embed(t.clone());
let circ = self.embed(t);
self.aby.circs.push(circ);
}
}

View File

@@ -32,6 +32,12 @@ impl Debug for Ilp {
}
}
impl Default for Ilp {
fn default() -> Self {
Self::new()
}
}
impl Ilp {
/// Create an empty ILP
pub fn new() -> Self {

View File

@@ -2,6 +2,9 @@
//!
//!
// Needed until https://github.com/rust-lang/rust-clippy/pull/8183 is resolved.
#![allow(clippy::identity_op)]
use crate::ir::term::extras::Letified;
use crate::ir::term::*;
use crate::target::ilp::Ilp;
@@ -115,8 +118,8 @@ impl ToMilp {
let mut bounds = Vec::new();
for x in xs {
n += 1;
sum = sum + x;
bounds.push(r.clone() - x << 0);
sum += x;
bounds.push((r.clone() - x) << 0);
}
assert!(n >= 1);
self.ilp.new_constraint(sum << (n as i32 - 1));
@@ -172,8 +175,8 @@ impl ToMilp {
self.bits_are_equal(&a, &b)
}
Sort::BitVector(n) => {
let a = self.get_bv_uint(a).clone();
let b = self.get_bv_uint(b).clone();
let a = self.get_bv_uint(a);
let b = self.get_bv_uint(b);
self.bv_cmp_eq(&a, &b, n)
}
s => panic!("Unimplemented sort for Eq: {:?}", s),
@@ -259,18 +262,18 @@ impl ToMilp {
}
Op::Ite => {
let c = self.get_bool(&bv.cs[0]).clone();
let t = self.get_bv_uint(&bv.cs[1]).clone();
let f = self.get_bv_uint(&bv.cs[2]).clone();
let t = self.get_bv_uint(&bv.cs[1]);
let f = self.get_bv_uint(&bv.cs[2]);
let ite = self.bv_ite(&c, &t, &f, n);
self.set_bv_uint(bv, ite, n);
}
Op::BvUnOp(BvUnOp::Not) => {
let bits = self.get_bv_bits(&bv.cs[0]).clone();
let bits = self.get_bv_bits(&bv.cs[0]);
let not_bits = bits.iter().map(|bit| self.bit_not(bit)).collect();
self.set_bv_bits(bv, not_bits);
}
Op::BvUnOp(BvUnOp::Neg) => {
let x = self.get_bv_uint(&bv.cs[0]).clone();
let x = self.get_bv_uint(&bv.cs[0]);
// Wrong for x == 0
let almost_neg_x = 2f64.powi(n as i32) - x.clone();
let is_zero = self.bv_cmp_eq(&x, &0.into(), n);
@@ -283,15 +286,14 @@ impl ToMilp {
let ext_bits = std::iter::repeat(Expression::from(0)).take(*extra_n);
self.set_bv_bits(bv, bits.into_iter().chain(ext_bits).collect());
} else {
let x = self.get_bv_uint(&bv.cs[0]).clone();
let x = self.get_bv_uint(&bv.cs[0]);
self.set_bv_uint(bv, x, n);
}
}
Op::BvSext(extra_n) => {
let mut bits = self.get_bv_bits(&bv.cs[0]).into_iter().rev();
let ext_bits =
std::iter::repeat(bits.next().expect("sign ext empty").clone())
.take(extra_n + 1);
let ext_bits = std::iter::repeat(bits.next().expect("sign ext empty"))
.take(extra_n + 1);
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
}
@@ -307,7 +309,7 @@ impl ToMilp {
.map(|c| self.get_bv_bits(c))
.collect::<Vec<_>>();
let mut bits_bv_idx: Vec<Vec<Expression>> = Vec::new();
while bits_by_bv[0].len() > 0 {
while !bits_by_bv[0].is_empty() {
bits_bv_idx.push(
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
);
@@ -327,7 +329,7 @@ impl ToMilp {
let values = bv
.cs
.iter()
.map(|c| self.get_bv_uint(c).clone())
.map(|c| self.get_bv_uint(c))
.collect::<Vec<_>>();
let r = match o {
BvNaryOp::Add => self.bv_add(&values, n),
@@ -463,18 +465,18 @@ impl ToMilp {
let r = self.fresh_bv("bv_ite", n_bits);
let m = bv_modulus(n_bits);
self.ilp
.new_constraint(r.clone() - a.clone() - m * (1 - s.clone()) << 0);
.new_constraint((r.clone() - a.clone() - m * (1 - s.clone())) << 0);
self.ilp
.new_constraint(a.clone() - r.clone() - m * (1 - s.clone()) << 0);
.new_constraint((a.clone() - r.clone() - m * (1 - s.clone())) << 0);
self.ilp
.new_constraint(r.clone() - b.clone() - m * s.clone() << 0);
.new_constraint((r.clone() - b.clone() - m * s.clone()) << 0);
self.ilp
.new_constraint(b.clone() - r.clone() - m * s.clone() << 0);
.new_constraint((b.clone() - r.clone() - m * s.clone()) << 0);
r
}
/// [Equations 7](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=915055).
fn bv_bin_mul<'a>(&mut self, a: &Expression, b: &Expression, n_bits: usize) -> Expression {
fn bv_bin_mul(&mut self, a: &Expression, b: &Expression, n_bits: usize) -> Expression {
debug!("({:?}) * ({:?})", a, b);
let a_bits = self.bit_decomp(a, n_bits);
let bit_prods: Vec<_> = a_bits
@@ -512,9 +514,9 @@ impl ToMilp {
let s = self.fresh_bit("bv_le");
let m = bv_modulus(n_bits);
self.ilp
.new_constraint(a.clone() - b.clone() - m * (1 - s.clone()) << -1);
.new_constraint((a.clone() - b.clone() - m * (1 - s.clone())) << -1);
self.ilp
.new_constraint(a.clone() - b.clone() + m * s.clone() >> 0);
.new_constraint((a.clone() - b.clone() + m * s.clone()) >> 0);
s
}
@@ -531,12 +533,12 @@ impl ToMilp {
let a = if signed {
self.get_bv_signed_int(a)
} else {
self.get_bv_uint(a).clone()
self.get_bv_uint(a)
};
let b = if signed {
self.get_bv_signed_int(b)
} else {
self.get_bv_uint(b).clone()
self.get_bv_uint(b)
};
if strict {
self.bv_cmp_lt(&b, &a, w)
@@ -571,7 +573,7 @@ impl ToMilp {
.get(t)
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
{
EmbeddedTerm::Bool(b) => &b,
EmbeddedTerm::Bool(b) => b,
_ => panic!("Non-bool for {:?}", t),
}
}
@@ -614,7 +616,7 @@ impl ToMilp {
}
fn bv_has_bits(&self, t: &Term) -> bool {
self.get_bv(t).borrow().bits.len() > 0
!self.get_bv(t).borrow().bits.is_empty()
}
fn get_bv_uint(&self, t: &Term) -> Expression {
@@ -622,14 +624,14 @@ impl ToMilp {
}
fn get_bv_signed_int(&mut self, t: &Term) -> Expression {
let bits = self.get_bv_bits(t).clone();
let bits = self.get_bv_bits(t);
self.debitify(bits.into_iter(), true)
}
fn get_bv_bits(&mut self, t: &Term) -> Vec<Expression> {
let entry_rc = self.get_bv(t);
let mut entry = entry_rc.borrow_mut();
if entry.bits.len() == 0 {
if entry.bits.is_empty() {
entry.bits = self.bit_decomp(&entry.uint, entry.width);
}
entry.bits.clone()
@@ -644,7 +646,7 @@ impl ToMilp {
}
fn bv_modulus(n_bits: usize) -> f64 {
2.0f64.powi(n_bits.try_into().unwrap()).into()
2.0f64.powi(n_bits.try_into().unwrap())
}
/// Convert this (IR) constraint system `cs` to an MILP.
@@ -663,7 +665,7 @@ pub fn to_ilp(cs: Computation) -> Ilp {
converter.ilp.maximize(converter.get_bool(&opt).clone());
}
Sort::BitVector(_) => {
converter.ilp.maximize(converter.get_bv_uint(&opt).clone());
converter.ilp.maximize(converter.get_bv_uint(&opt));
}
s => panic!("Cannot optimize term of sort {}", s),
};

View File

@@ -39,7 +39,7 @@ fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
for (v, c) in &lc.monomials {
// ditto
if c != &0 {
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
lc_bellman = lc_bellman + (int_to_ff(c), *vars.get(v).unwrap());
}
}
lc_bellman
@@ -49,7 +49,7 @@ fn modulus_as_int<F: PrimeFieldBits>() -> Integer {
let mut bits = F::char_le_bits().to_bitvec();
let mut acc = Integer::from(0);
while let Some(b) = bits.pop() {
acc = acc << 1;
acc <<= 1;
acc += b as u8;
}
acc
@@ -133,7 +133,7 @@ pub fn parse_instance<P: AsRef<Path>, F: PrimeField>(path: P) -> Vec<F> {
f.lines()
.map(|line| {
let s = line.unwrap();
let i = Integer::from_str(&s.trim()).unwrap();
let i = Integer::from_str(s.trim()).unwrap();
int_to_ff(&i)
})
.collect()

View File

@@ -36,7 +36,7 @@ pub struct Lc {
impl Lc {
/// Is this the zero combination?
pub fn is_zero(&self) -> bool {
self.monomials.len() == 0 && &self.constant == &0
self.monomials.is_empty() && self.constant == 0
}
/// Make this the zero combination.
pub fn clear(&mut self) {
@@ -55,7 +55,7 @@ impl Lc {
}
/// Is this a constant? If so, return that constant.
pub fn as_const(&self) -> Option<&Integer> {
(self.monomials.len() == 0).then(|| &self.constant)
self.monomials.is_empty().then(|| &self.constant)
}
}
@@ -187,7 +187,7 @@ impl std::ops::Neg for Lc {
fn neg(mut self) -> Lc {
self.constant = -self.constant;
self.constant.rem_floor_assign(&*self.modulus);
for (_, v) in &mut self.monomials {
for v in &mut self.monomials.values_mut() {
*v *= Integer::from(-1);
v.rem_floor_assign(&*self.modulus);
}
@@ -210,7 +210,7 @@ impl std::ops::MulAssign<&Integer> for Lc {
if other == &Integer::from(0) {
self.monomials.clear();
} else {
for (_, v) in &mut self.monomials {
for v in &mut self.monomials.values_mut() {
*v *= other;
v.rem_floor_assign(&*self.modulus);
}
@@ -233,7 +233,7 @@ impl std::ops::MulAssign<isize> for Lc {
if other == 0 {
self.monomials.clear();
} else {
for (_, v) in &mut self.monomials {
for v in &mut self.monomials.values_mut() {
*v *= Integer::from(other);
v.rem_floor_assign(&*self.modulus);
}
@@ -283,7 +283,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
let n = self.next_idx;
self.next_idx += 1;
self.signal_idxs.insert(s.clone(), n);
self.idxs_signals.insert(n, s.clone());
self.idxs_signals.insert(n, s);
match (self.values.as_mut(), v) {
(Some(vs), Some(v)) => {
//println!("{} -> {}", &s, &v);
@@ -331,7 +331,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
let sign = |i: &Integer| if i < &half_m { "+" } else { "-" };
let format_i = |i: &Integer| format!("{}{}", sign(i), abs(i));
s.extend(format_i(&Integer::from(&a.constant)).chars());
s.push_str(&format_i(&Integer::from(&a.constant)));
for (idx, coeff) in &a.monomials {
s.extend(
format!(
@@ -361,7 +361,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
let av = self.eval(a).unwrap();
let bv = self.eval(b).unwrap();
let cv = self.eval(c).unwrap();
if &((av.clone() * &bv).rem_floor(&*self.modulus)) != &cv {
if (av.clone() * &bv).rem_floor(&*self.modulus) != cv {
panic!(
"Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})",
self.format_lc(a),

View File

@@ -112,12 +112,11 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
);
self.clear_constraint(con_id);
for use_id in self.uses[&var].clone() {
if self.sub_in(var, &lc, use_id) {
if self.r1cs.constraints[use_id].0.is_zero()
|| self.r1cs.constraints[use_id].1.is_zero()
{
self.queue.push(use_id);
}
if self.sub_in(var, &lc, use_id)
&& (self.r1cs.constraints[use_id].0.is_zero()
|| self.r1cs.constraints[use_id].1.is_zero())
{
self.queue.push(use_id);
}
}
debug_assert_eq!(0, self.uses[&var].len());

View File

@@ -13,10 +13,9 @@ use rug::ops::Pow;
use rug::Integer;
use std::cell::RefCell;
use std::rc::Rc;
use std::fmt::Display;
use std::iter::ExactSizeIterator;
use std::rc::Rc;
struct BvEntry {
width: usize,
@@ -97,7 +96,7 @@ impl ToR1cs {
if x == 0 {
Integer::from(0)
} else {
Integer::from(x.invert(&self.r1cs.modulus()).unwrap())
x.invert(self.r1cs.modulus()).unwrap()
}
}),
false,
@@ -179,7 +178,7 @@ impl ToR1cs {
/// Constrains `x` to fit in `n` (`signed`) bits.
/// The LSB is at index 0.
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Vec<Lc> {
debug!("Bitify({}): {}", n, self.r1cs.format_lc(&x));
debug!("Bitify({}): {}", n, self.r1cs.format_lc(x));
let bits = self.decomp(d, x, n);
let sum = self.debitify(bits.iter().cloned(), signed);
self.assert_zero(sum - x);
@@ -246,9 +245,11 @@ impl ToR1cs {
fn nary_and<I: ExactSizeIterator<Item = Lc>>(&mut self, mut xs: I) -> Lc {
let n = xs.len();
if n <= 3 {
let first = xs.next().expect("empty AND").clone();
let first = xs.next().expect("empty AND");
xs.fold(first, |a, x| self.mul(a, x))
} else {
// Needed to end the closures borrow of self, before the next line.
#[allow(clippy::needless_collect)]
let negs: Vec<Lc> = xs.map(|x| self.bool_not(&x)).collect();
let a = self.nary_or(negs.into_iter());
self.bool_not(&a)
@@ -259,6 +260,8 @@ impl ToR1cs {
fn nary_or<I: ExactSizeIterator<Item = Lc>>(&mut self, xs: I) -> Lc {
let n = xs.len();
if n <= 3 {
// Needed to end the closures borrow of self, before the next line.
#[allow(clippy::needless_collect)]
let negs: Vec<Lc> = xs.map(|x| self.bool_not(&x)).collect();
let a = self.nary_and(negs.into_iter());
self.bool_not(&a)
@@ -281,6 +284,8 @@ impl ToR1cs {
debug!("Embed op: {}", c.op);
// Handle field access once and for all
if let Op::Field(i) = &c.op {
// Need to borrow self in between search and insert. Could refactor.
#[allow(clippy::map_entry)]
if !self.cache.contains_key(&c) {
let t = self.get_field(&c.cs[0], *i);
self.cache.insert(c, t);
@@ -297,7 +302,8 @@ impl ToR1cs {
self.embed_pf(c);
}
Sort::Tuple(_) => {
self.embed_tuple(c);
// custom ops?
panic!("Cannot embed tuple term: {}", c)
}
s => panic!("Unsupported sort in embed: {:?}", s),
}
@@ -305,18 +311,6 @@ impl ToR1cs {
}
}
#[allow(unreachable_code)]
#[allow(unused_variables)]
fn embed_tuple(&mut self, a: Term) {
if !self.cache.contains_key(&a) {
let t = match &a.op {
// May want to support cunstor operators here...
_ => panic!("Cannot embed tuple term: {}", a),
};
self.cache.insert(a, t);
}
}
fn get_field(&self, tuple_term: &Term, field: usize) -> EmbeddedTerm {
match self.cache.get(tuple_term) {
Some(EmbeddedTerm::Tuple(v)) => v[field].clone(),
@@ -332,8 +326,8 @@ impl ToR1cs {
self.bits_are_equal(&a, &b)
}
Sort::BitVector(_) => {
let a = self.get_bv_uint(a).clone();
let b = self.get_bv_uint(b).clone();
let a = self.get_bv_uint(a);
let b = self.get_bv_uint(b);
self.are_equal(a, &b)
}
Sort::Field(_) => {
@@ -344,8 +338,7 @@ impl ToR1cs {
Sort::Tuple(sorts) => {
let n = sorts.len();
let eqs: Vec<Term> = (0..n).map(|i| {
let t = term![Op::Eq; term![Op::Field(i); a.clone()], term![Op::Field(i); b.clone()]];
t
term![Op::Eq; term![Op::Field(i); a.clone()], term![Op::Field(i); b.clone()]]
}).collect();
let conj = term(Op::BoolNaryOp(BoolNaryOp::And), eqs);
self.embed(conj.clone());
@@ -463,16 +456,16 @@ impl ToR1cs {
fn assert_bool(&mut self, t: &Term) {
//println!("Embed: {}", c);
// TODO: skip if already embedded
if &t.op == &Op::Eq {
if t.op == Op::Eq {
t.cs.iter().for_each(|c| self.embed(c.clone()));
self.assert_eq(&t.cs[0], &t.cs[1]);
} else if &t.op == &AND {
} else if t.op == AND {
for c in &t.cs {
self.assert_bool(c);
}
} else {
self.embed(t.clone());
let lc = self.get_bool(&t).clone();
let lc = self.get_bool(t).clone();
self.assert_zero(lc - 1);
}
}
@@ -489,12 +482,12 @@ impl ToR1cs {
let a = if signed {
self.get_bv_signed_int(a)
} else {
self.get_bv_uint(a).clone()
self.get_bv_uint(a)
};
let b = if signed {
self.get_bv_signed_int(b)
} else {
self.get_bv_uint(b).clone()
self.get_bv_uint(b)
};
// Use the fact: a > b <=> a - 1 >= b
self.bv_ge(if strict { a - 1 } else { a }, &b, w)
@@ -551,18 +544,18 @@ impl ToR1cs {
}
Op::Ite => {
let c = self.get_bool(&bv.cs[0]).clone();
let t = self.get_bv_uint(&bv.cs[1]).clone();
let f = self.get_bv_uint(&bv.cs[2]).clone();
let t = self.get_bv_uint(&bv.cs[1]);
let f = self.get_bv_uint(&bv.cs[2]);
let ite = self.ite(c, t, &f);
self.set_bv_uint(bv, ite, n);
}
Op::BvUnOp(BvUnOp::Not) => {
let bits = self.get_bv_bits(&bv.cs[0]).clone();
let bits = self.get_bv_bits(&bv.cs[0]);
let not_bits = bits.iter().map(|bit| self.bool_not(bit)).collect();
self.set_bv_bits(bv, not_bits);
}
Op::BvUnOp(BvUnOp::Neg) => {
let x = self.get_bv_uint(&bv.cs[0]).clone();
let x = self.get_bv_uint(&bv.cs[0]);
// Wrong for x == 0
let almost_neg_x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - &x;
let is_zero = self.is_zero(x);
@@ -575,15 +568,14 @@ impl ToR1cs {
let ext_bits = std::iter::repeat(self.r1cs.zero()).take(*extra_n);
self.set_bv_bits(bv, bits.into_iter().chain(ext_bits).collect());
} else {
let x = self.get_bv_uint(&bv.cs[0]).clone();
let x = self.get_bv_uint(&bv.cs[0]);
self.set_bv_uint(bv, x, n);
}
}
Op::BvSext(extra_n) => {
let mut bits = self.get_bv_bits(&bv.cs[0]).into_iter().rev();
let ext_bits =
std::iter::repeat(bits.next().expect("sign ext empty").clone())
.take(extra_n + 1);
let ext_bits = std::iter::repeat(bits.next().expect("sign ext empty"))
.take(extra_n + 1);
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
}
@@ -604,7 +596,7 @@ impl ToR1cs {
.map(|c| self.get_bv_bits(c))
.collect::<Vec<_>>();
let mut bits_bv_idx: Vec<Vec<Lc>> = Vec::new();
while bits_by_bv[0].len() > 0 {
while !bits_by_bv[0].is_empty() {
bits_bv_idx.push(
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
);
@@ -624,7 +616,7 @@ impl ToR1cs {
let values = bv
.cs
.iter()
.map(|c| self.get_bv_uint(c).clone())
.map(|c| self.get_bv_uint(c))
.collect::<Vec<_>>();
let (res, width) = match o {
BvNaryOp::Add => {
@@ -663,14 +655,12 @@ impl ToR1cs {
let b = self.get_bv_uint(&bv.cs[1]);
match o {
BvBinOp::Sub => {
let sum = a.clone() + &(Integer::from(1) << n as u32) - &b;
let sum = a + &(Integer::from(1) << n as u32) - &b;
let mut bits = self.bitify("sub", &sum, n + 1, false);
bits.truncate(n);
self.set_bv_bits(bv, bits);
}
BvBinOp::Udiv | BvBinOp::Urem => {
let b = b.clone();
let a = a.clone();
let is_zero = self.is_zero(b.clone());
let (q_v, r_v) = self
.r1cs
@@ -705,8 +695,7 @@ impl ToR1cs {
}
// Shift cases
_ => {
let r = b.clone();
let a = a.clone();
let r = b;
let b = bitsize(n - 1);
assert!(1 << b == n);
let mut rb = self.get_bv_bits(&bv.cs[1]);
@@ -776,7 +765,7 @@ impl ToR1cs {
.get(t)
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
{
EmbeddedTerm::Bool(b) => &b,
EmbeddedTerm::Bool(b) => b,
_ => panic!("Non-boolean for {:?}", t),
}
}
@@ -818,7 +807,7 @@ impl ToR1cs {
}
fn bv_has_bits(&self, t: &Term) -> bool {
self.get_bv(t).borrow().bits.len() > 0
!self.get_bv(t).borrow().bits.is_empty()
}
fn get_bv_uint(&self, t: &Term) -> Lc {
@@ -826,14 +815,14 @@ impl ToR1cs {
}
fn get_bv_signed_int(&mut self, t: &Term) -> Lc {
let bits = self.get_bv_bits(t).clone();
let bits = self.get_bv_bits(t);
self.debitify(bits.into_iter(), true)
}
fn get_bv_bits(&mut self, t: &Term) -> Vec<Lc> {
let entry_rc = self.get_bv(t);
let mut entry = entry_rc.borrow_mut();
if entry.bits.len() == 0 {
if entry.bits.is_empty() {
entry.bits = self.bitify("getbits", &entry.uint, entry.width, false);
}
entry.bits.clone()
@@ -871,14 +860,16 @@ impl ToR1cs {
match o {
PfNaryOp::Add => args.fold(self.r1cs.zero(), std::ops::Add::add),
PfNaryOp::Mul => {
// Needed to end the above closures borrow of self, before the mul call
#[allow(clippy::needless_collect)]
let args = args.cloned().collect::<Vec<_>>();
let mut args_iter = args.into_iter();
let first = args_iter.next().unwrap();
args_iter.fold(first, |a, b| self.mul(a, b.clone()))
args_iter.fold(first, |a, b| self.mul(a, b))
}
}
}
Op::UbvToPf(_) => self.get_bv_uint(&c.cs[0]).clone(),
Op::UbvToPf(_) => self.get_bv_uint(&c.cs[0]),
Op::PfUnOp(PfUnOp::Neg) => -self.get_pf(&c.cs[0]).clone(),
Op::PfUnOp(PfUnOp::Recip) => {
let x = self.get_pf(&c.cs[0]).clone();

View File

@@ -24,7 +24,7 @@ struct SmtDisp<'a, T>(pub &'a T);
impl<'a, T: Expr2Smt<()> + 'a> Display for SmtDisp<'a, T> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let mut s = Vec::new();
<T as Expr2Smt<()>>::expr_to_smt2(&self.0, &mut s, ()).unwrap();
<T as Expr2Smt<()>>::expr_to_smt2(self.0, &mut s, ()).unwrap();
write!(f, "{}", std::str::from_utf8(&s).unwrap())?;
Ok(())
}
@@ -34,7 +34,7 @@ struct SmtSortDisp<'a, T>(pub &'a T);
impl<'a, T: Sort2Smt + 'a> Display for SmtSortDisp<'a, T> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let mut s = Vec::new();
<T as Sort2Smt>::sort_to_smt2(&self.0, &mut s).unwrap();
<T as Sort2Smt>::sort_to_smt2(self.0, &mut s).unwrap();
write!(f, "{}", std::str::from_utf8(&s).unwrap())?;
Ok(())
}

View File

@@ -14,6 +14,12 @@ pub struct OnceQueue<T> {
set: FxHashSet<T>,
}
impl<T: Eq + Hash + Clone> Default for OnceQueue<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Eq + Hash + Clone> OnceQueue<T> {
/// Add to the queue. If `t` is already present, it is dropped.
pub fn push(&mut self, t: T) {