mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
Resolve lints and add clippy to CI (#35)
`front::zokrates` is currently excluded
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -29,6 +29,8 @@ jobs:
|
||||
run: cargo check --verbose
|
||||
- name: Check format
|
||||
run: cargo fmt -- --check
|
||||
- name: Lint
|
||||
run: cargo clippy
|
||||
- name: Build
|
||||
run: cargo build --verbose && make build
|
||||
- name: Run tests
|
||||
|
||||
@@ -98,7 +98,7 @@ impl MemManager {
|
||||
let alloc = Alloc::new(id, *addr_width, *val_width, *size);
|
||||
let v = alloc.var().clone();
|
||||
if let Op::Var(n, _) = &v.op {
|
||||
self.cs.borrow_mut().eval_and_save(&n, &array);
|
||||
self.cs.borrow_mut().eval_and_save(n, &array);
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
@@ -152,7 +152,7 @@ impl MemManager {
|
||||
alloc.next_var();
|
||||
let v = alloc.var().clone();
|
||||
if let Op::Var(n, _) = &v.op {
|
||||
self.cs.borrow_mut().eval_and_save(&n, &new);
|
||||
self.cs.borrow_mut().eval_and_save(n, &new);
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
@@ -267,7 +267,8 @@ impl<Ty: Display> FnFrame<Ty> {
|
||||
}
|
||||
fn enter_condition(&mut self, condition: Term) -> Result<()> {
|
||||
if check(&condition) == Sort::Bool {
|
||||
Ok(self.stack.push(StateEntry::Cond(condition)))
|
||||
self.stack.push(StateEntry::Cond(condition));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CircError::NotBool(condition))
|
||||
}
|
||||
@@ -313,7 +314,7 @@ impl<Ty: Display> FnFrame<Ty> {
|
||||
StateEntry::Cond(c) => break_if.push(c.clone()),
|
||||
StateEntry::Break(name_, ref mut break_conds) => {
|
||||
if name_ == name {
|
||||
break_conds.push(if break_if.len() == 0 {
|
||||
break_conds.push(if break_if.is_empty() {
|
||||
leaf_term(Op::Const(Value::Bool(true)))
|
||||
} else {
|
||||
term(Op::BoolNaryOp(BoolNaryOp::And), break_if)
|
||||
@@ -394,6 +395,8 @@ pub trait Embeddable {
|
||||
|
||||
/// Create a new term for the default return value of a function returning type `ty`.
|
||||
/// The name `ssa_name` is globally unique, and can be used if needed.
|
||||
// Because the type alias may change.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
fn initialize_return(&self, ty: &Self::Ty, ssa_name: &SsaName) -> Self::T;
|
||||
}
|
||||
|
||||
@@ -450,9 +453,9 @@ impl<E: Embeddable> Circify<E> {
|
||||
/// Initialize environment entry binding `name` to `ty`.
|
||||
fn declare_env_name(&mut self, name: VarName, ty: &E::Ty) -> Result<&SsaName> {
|
||||
if let Some(back) = self.fn_stack.last_mut() {
|
||||
back.declare(name.clone(), ty.clone())
|
||||
back.declare(name, ty.clone())
|
||||
} else {
|
||||
self.globals.declare(name.clone(), ty.clone())
|
||||
self.globals.declare(name, ty.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -484,19 +487,18 @@ impl<E: Embeddable> Circify<E> {
|
||||
/// Returns `None` if it's in the global scope.
|
||||
/// Returns `Some` if it's in a local scope.
|
||||
/// Errors if the name cannot be found.
|
||||
// Because the type alias may change.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
fn mk_abs(&self, name: &VarName) -> Result<Option<ScopeIdx>> {
|
||||
if let Some(fn_) = self.fn_stack.last() {
|
||||
for (lex_i, e) in fn_.stack.iter().enumerate().rev() {
|
||||
match e {
|
||||
StateEntry::Lex(l) => {
|
||||
if l.has_name(name) {
|
||||
return Ok(Some(ScopeIdx {
|
||||
lex: lex_i,
|
||||
fn_: self.fn_stack.len() - 1,
|
||||
}));
|
||||
}
|
||||
if let StateEntry::Lex(l) = e {
|
||||
if l.has_name(name) {
|
||||
return Ok(Some(ScopeIdx {
|
||||
lex: lex_i,
|
||||
fn_: self.fn_stack.len() - 1,
|
||||
}));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -551,7 +553,7 @@ impl<E: Embeddable> Circify<E> {
|
||||
///
|
||||
/// If `public`, then make it a public (fixed) rather than private (existential) circuit input.
|
||||
pub fn declare_init(&mut self, name: VarName, ty: E::Ty, val: Val<E::T>) -> Result<Val<E::T>> {
|
||||
let ssa_name = self.declare_env_name(name.clone(), &ty)?.clone();
|
||||
let ssa_name = self.declare_env_name(name, &ty)?.clone();
|
||||
// TODO: add language-specific coersion here if needed
|
||||
assert!(self.vals.insert(ssa_name, val.clone()).is_none());
|
||||
Ok(val)
|
||||
@@ -566,7 +568,7 @@ impl<E: Embeddable> Circify<E> {
|
||||
ty: &E::Ty,
|
||||
visiblity: Option<PartyId>,
|
||||
) {
|
||||
self.e.assign(&mut self.cir_ctx, &ty, name, term, visiblity);
|
||||
self.e.assign(&mut self.cir_ctx, ty, name, term, visiblity);
|
||||
}
|
||||
|
||||
/// Assign `loc` in the current scope to `val`.
|
||||
@@ -669,7 +671,7 @@ impl<E: Embeddable> Circify<E> {
|
||||
pub fn condition(&self) -> Term {
|
||||
// TODO: more precise conditions, depending on lex scopes.
|
||||
let cs: Vec<_> = self.fn_stack.iter().flat_map(|f| f.conditions()).collect();
|
||||
if cs.len() == 0 {
|
||||
if cs.is_empty() {
|
||||
leaf_term(Op::Const(Value::Bool(true)))
|
||||
} else {
|
||||
term(Op::BoolNaryOp(BoolNaryOp::And), cs)
|
||||
@@ -748,7 +750,7 @@ impl<E: Embeddable> Circify<E> {
|
||||
self.vals
|
||||
.get(l.get_name(&loc.name)?)
|
||||
.cloned()
|
||||
.ok_or_else(|| CircError::InvalidLoc(loc))
|
||||
.ok_or(CircError::InvalidLoc(loc))
|
||||
}
|
||||
|
||||
/// Dereference a reference into a location.
|
||||
@@ -762,6 +764,8 @@ impl<E: Embeddable> Circify<E> {
|
||||
}
|
||||
|
||||
/// Create a reference to a variable.
|
||||
// Because the type alias may change.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
pub fn ref_(&self, name: &VarName) -> Result<Val<E::T>> {
|
||||
let idx = self.mk_abs(name)?;
|
||||
Ok(Val::Ref(Loc {
|
||||
|
||||
@@ -56,7 +56,7 @@ impl<'ast> Gen<'ast> {
|
||||
/// Attempt to enter a funciton.
|
||||
/// Returns `false` if doing so would violate the recursion limit.
|
||||
fn enter_function(&mut self, name: &'ast str, dec_value: Option<Integer>) -> bool {
|
||||
let e = self.stack_by_fn.entry(name).or_insert_with(|| Vec::new());
|
||||
let e = self.stack_by_fn.entry(name).or_insert_with(Vec::new);
|
||||
//assert_eq!(e.last().and_then(|l| l.as_ref()).is_some(), dec_value.is_some());
|
||||
let do_enter = if let (Some(last_val), Some(this_val)) =
|
||||
(e.last().and_then(|l| l.as_ref()), dec_value.as_ref())
|
||||
@@ -86,7 +86,7 @@ impl<'ast> Gen<'ast> {
|
||||
fn register_rules(&mut self, pgm: &'ast ast::Program<'ast>) {
|
||||
for r in &pgm.rules {
|
||||
assert!(!self.rules.contains_key(&r.name.value));
|
||||
self.rules.insert(&r.name.value, r);
|
||||
self.rules.insert(r.name.value, r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ impl<'ast> Gen<'ast> {
|
||||
}
|
||||
},
|
||||
|t, size| {
|
||||
let size = usize::from_str(&size.value).expect("bad array size");
|
||||
let size = usize::from_str(size.value).expect("bad array size");
|
||||
ty::Ty::Array(size, Box::new(t))
|
||||
},
|
||||
),
|
||||
@@ -127,7 +127,7 @@ impl<'ast> Gen<'ast> {
|
||||
let vis = if public { PUBLIC_VIS } else { PROVER_VIS };
|
||||
self.circ.declare(d.ident.value.into(), &ty, public, vis)?;
|
||||
}
|
||||
let r = self.rule_cases(&rule)?;
|
||||
let r = self.rule_cases(rule)?;
|
||||
self.exit_function(name);
|
||||
self.circ.assert(r.as_bool());
|
||||
Ok(())
|
||||
@@ -165,19 +165,19 @@ impl<'ast> Gen<'ast> {
|
||||
/// * `top_level` indicates whether this expression is a top-level expression in a condition.
|
||||
fn expr(&mut self, e: &'ast ast::Expression, top_level: bool) -> Result<'ast, term::T> {
|
||||
match e {
|
||||
&ast::Expression::Binary(ref b) => self.bin_expr(b),
|
||||
&ast::Expression::Unary(ref u) => self.un_expr(u),
|
||||
&ast::Expression::Paren(ref i, _) => self.expr(i, top_level),
|
||||
&ast::Expression::Identifier(ref i) => self.ident(i),
|
||||
&ast::Expression::Literal(ref i) => self.literal(i),
|
||||
&ast::Expression::Access(ref c) => {
|
||||
ast::Expression::Binary(ref b) => self.bin_expr(b),
|
||||
ast::Expression::Unary(ref u) => self.un_expr(u),
|
||||
ast::Expression::Paren(ref i, _) => self.expr(i, top_level),
|
||||
ast::Expression::Identifier(ref i) => self.ident(i),
|
||||
ast::Expression::Literal(ref i) => self.literal(i),
|
||||
ast::Expression::Access(ref c) => {
|
||||
let arr = self.ident(&c.arr)?;
|
||||
c.idxs.iter().try_fold(arr, |arr, idx| {
|
||||
let idx_v = self.expr(idx, false)?;
|
||||
term::array_idx(&arr, &idx_v).map_err(|err| Error::new(err, idx.span().clone()))
|
||||
})
|
||||
}
|
||||
&ast::Expression::Call(ref c) => {
|
||||
ast::Expression::Call(ref c) => {
|
||||
let args = c
|
||||
.args
|
||||
.iter()
|
||||
@@ -202,8 +202,7 @@ impl<'ast> Gen<'ast> {
|
||||
.args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, arg)| arg.dec.is_some())
|
||||
.next()
|
||||
.find(|&(_, arg)| arg.dec.is_some())
|
||||
{
|
||||
let ir = &args[i].ir;
|
||||
let reduced_ir = fold(ir);
|
||||
@@ -225,7 +224,7 @@ impl<'ast> Gen<'ast> {
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
let r = self.rule_cases(&rule)?;
|
||||
let r = self.rule_cases(rule)?;
|
||||
self.exit_function(name);
|
||||
Ok(r)
|
||||
} else {
|
||||
@@ -238,22 +237,22 @@ impl<'ast> Gen<'ast> {
|
||||
}
|
||||
fn literal(&mut self, e: &ast::Literal) -> Result<'ast, term::T> {
|
||||
match e {
|
||||
&ast::Literal::BinLiteral(ref b) => {
|
||||
ast::Literal::BinLiteral(ref b) => {
|
||||
let len = b.value.len() as u8 - 2;
|
||||
let val = u64::from_str_radix(&b.value[2..], 2).unwrap();
|
||||
Ok(term::uint_lit(val, len))
|
||||
}
|
||||
&ast::Literal::HexLiteral(ref b) => {
|
||||
ast::Literal::HexLiteral(ref b) => {
|
||||
let len = (b.value.len() as u8 - 2) * 4;
|
||||
let val = u64::from_str_radix(&b.value[2..], 16).unwrap();
|
||||
Ok(term::uint_lit(val, len))
|
||||
}
|
||||
&ast::Literal::DecimalLiteral(ref b) => {
|
||||
let val = Integer::from_str(&b.value).unwrap();
|
||||
ast::Literal::DecimalLiteral(ref b) => {
|
||||
let val = Integer::from_str(b.value).unwrap();
|
||||
Ok(term::pf_lit(val))
|
||||
}
|
||||
&ast::Literal::BooleanLiteral(ref b) => {
|
||||
let val = bool::from_str(&b.value).unwrap();
|
||||
ast::Literal::BooleanLiteral(ref b) => {
|
||||
let val = bool::from_str(b.value).unwrap();
|
||||
Ok(term::bool_lit(val))
|
||||
}
|
||||
}
|
||||
@@ -310,7 +309,7 @@ impl<'ast> Gen<'ast> {
|
||||
.enumerate()
|
||||
.find(|(_, arg)| arg.dec.is_some())
|
||||
{
|
||||
self.enter_function(&rule.name.value, None);
|
||||
self.enter_function(rule.name.value, None);
|
||||
for d in &rule.args {
|
||||
let (ty, public) = self.ty(&d.ty);
|
||||
let vis = if public { PUBLIC_VIS } else { PROVER_VIS };
|
||||
@@ -330,7 +329,7 @@ impl<'ast> Gen<'ast> {
|
||||
let mut bad_recursion = Vec::new();
|
||||
for atom in &cond.exprs {
|
||||
if let ast::Expression::Call(c) = &atom {
|
||||
if &c.fn_name.value == &rule.name.value {
|
||||
if c.fn_name.value == rule.name.value {
|
||||
let formal_arg = self
|
||||
.circ
|
||||
.get_value(Loc::local(rule.args[arg_idx].ident.value.to_owned()))?
|
||||
@@ -362,7 +361,7 @@ impl<'ast> Gen<'ast> {
|
||||
)?);
|
||||
self.circ.exit_scope();
|
||||
}
|
||||
self.exit_function(&rule.name.value);
|
||||
self.exit_function(rule.name.value);
|
||||
bug_in_rule_if_any
|
||||
.into_iter()
|
||||
.try_fold(term::bool_lit(false), |x, y| {
|
||||
|
||||
@@ -229,7 +229,7 @@ pub mod ast {
|
||||
match self {
|
||||
Expression::Binary(b) => &b.span,
|
||||
Expression::Identifier(i) => &i.span,
|
||||
Expression::Literal(c) => &c.span(),
|
||||
Expression::Literal(c) => c.span(),
|
||||
Expression::Unary(u) => &u.span,
|
||||
Expression::Call(u) => &u.span,
|
||||
Expression::Access(u) => &u.span,
|
||||
@@ -530,7 +530,7 @@ pub mod ast {
|
||||
}
|
||||
|
||||
pub fn parse(file_string: &str) -> Result<ast::Program, Error<Rule>> {
|
||||
let mut pest_pairs = MyParser::parse(Rule::program, &file_string)?;
|
||||
let mut pest_pairs = MyParser::parse(Rule::program, file_string)?;
|
||||
use from_pest::FromPest;
|
||||
Ok(ast::Program::from_pest(&mut pest_pairs).expect("bug in AST construction"))
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ impl T {
|
||||
/// Create a new term, checking that the explicit type and IR type agree.
|
||||
pub fn new(ir: Term, ty: Ty) -> Self {
|
||||
let ir_ty = check(&ir);
|
||||
let res = Self { ir, ty: ty.clone() };
|
||||
let res = Self { ir, ty };
|
||||
Self::check_ty(&ir_ty, &res.ty);
|
||||
res
|
||||
}
|
||||
@@ -43,7 +43,7 @@ impl T {
|
||||
#[track_caller]
|
||||
pub fn as_bool(&self) -> Term {
|
||||
match &self.ty {
|
||||
&Ty::Bool => self.ir.clone(),
|
||||
Ty::Bool => self.ir.clone(),
|
||||
_ => panic!("{} is not a bool", self),
|
||||
}
|
||||
}
|
||||
@@ -395,7 +395,7 @@ impl Embeddable for Datalog {
|
||||
&*inner_ty,
|
||||
idx_name(&raw_name, i),
|
||||
user_name.as_ref().map(|u| idx_name(u, i)),
|
||||
visibility.clone(),
|
||||
visibility,
|
||||
)
|
||||
})
|
||||
.enumerate()
|
||||
@@ -409,10 +409,7 @@ impl Embeddable for Datalog {
|
||||
}
|
||||
fn ite(&self, _ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
|
||||
if t.ty == f.ty {
|
||||
T::new(
|
||||
term![Op::Ite; cond.clone(), t.ir.clone(), f.ir.clone()],
|
||||
t.ty.clone(),
|
||||
)
|
||||
T::new(term![Op::Ite; cond, t.ir, f.ir], t.ty)
|
||||
} else {
|
||||
panic!("Cannot ITE {} and {}", t, f)
|
||||
}
|
||||
@@ -444,6 +441,12 @@ impl Embeddable for Datalog {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Datalog {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Datalog {
|
||||
/// Initialize the Datalog lang def
|
||||
pub fn new() -> Self {
|
||||
|
||||
@@ -18,10 +18,10 @@ pub enum Ty {
|
||||
impl Display for Ty {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match &self {
|
||||
&Ty::Bool => write!(f, "bool"),
|
||||
&Ty::Field => write!(f, "field"),
|
||||
&Ty::Uint(w) => write!(f, "u{}", w),
|
||||
&Ty::Array(l, t) => write!(f, "{}[{}]", t, l),
|
||||
Ty::Bool => write!(f, "bool"),
|
||||
Ty::Field => write!(f, "field"),
|
||||
Ty::Uint(w) => write!(f, "u{}", w),
|
||||
Ty::Array(l, t) => write!(f, "{}[{}]", t, l),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Input language front-ends
|
||||
|
||||
pub mod datalog;
|
||||
#[allow(clippy::all)]
|
||||
pub mod zokrates;
|
||||
|
||||
use super::ir::term::Computation;
|
||||
|
||||
@@ -42,7 +42,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
|
||||
stack.extend(t.cs.iter().map(|c| (c.clone(), false)));
|
||||
continue;
|
||||
}
|
||||
let c_get = |x: &Term| -> Term { cache.get(&x).expect("postorder cache").clone() };
|
||||
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
|
||||
let get = |i: usize| c_get(&t.cs[i]);
|
||||
let new_t_opt = match &t.op {
|
||||
&NOT => get(0).as_bool_opt().and_then(|c| cbool(!c)),
|
||||
@@ -59,7 +59,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
|
||||
Some(bv) => cbool(bv.bit(*i)),
|
||||
_ => None,
|
||||
},
|
||||
Op::BoolNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
|
||||
Op::BoolNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
|
||||
Op::Eq => {
|
||||
let c0 = get(0);
|
||||
let c1 = get(1);
|
||||
@@ -119,7 +119,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
Op::BvNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
|
||||
Op::BvNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
|
||||
Op::BvBinPred(p) => {
|
||||
if let (Some(a), Some(b)) = (get(0).as_bv_opt(), get(1).as_bv_opt()) {
|
||||
Some(leaf_term(Op::Const(Value::Bool(match p {
|
||||
@@ -168,7 +168,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
|
||||
},
|
||||
}
|
||||
}
|
||||
Op::PfNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c).clone()))),
|
||||
Op::PfNaryOp(o) => Some(o.clone().flatten(t.cs.iter().map(|c| c_get(c)))),
|
||||
Op::PfUnOp(o) => get(0).as_pf_opt().map(|pf| {
|
||||
leaf_term(Op::Const(Value::Field(match o {
|
||||
PfUnOp::Recip => pf.clone().recip(),
|
||||
@@ -177,17 +177,17 @@ pub fn fold_cache(node: &Term, cache: &mut TermMap<Term>) -> Term {
|
||||
}),
|
||||
_ => None,
|
||||
};
|
||||
let c_get = |x: &Term| -> Term { cache.get(&x).expect("postorder cache").clone() };
|
||||
let c_get = |x: &Term| -> Term { cache.get(x).expect("postorder cache").clone() };
|
||||
let new_t = new_t_opt
|
||||
.unwrap_or_else(|| term(t.op.clone(), t.cs.iter().map(|c| c_get(c)).collect()));
|
||||
cache.insert(t, new_t);
|
||||
}
|
||||
cache.get(&node).expect("postorder cache").clone()
|
||||
cache.get(node).expect("postorder cache").clone()
|
||||
}
|
||||
|
||||
fn neg_bool(t: Term) -> Term {
|
||||
match &t.op {
|
||||
&NOT => t.cs[0].clone(),
|
||||
match t.op {
|
||||
NOT => t.cs[0].clone(),
|
||||
_ => term![NOT; t],
|
||||
}
|
||||
}
|
||||
@@ -220,7 +220,7 @@ impl NaryFlat<bool> for BoolNaryOp {
|
||||
BoolNaryOp::Or => {
|
||||
if consts.iter().any(|b| *b) {
|
||||
leaf_term(Op::Const(Value::Bool(true)))
|
||||
} else if children.len() == 0 {
|
||||
} else if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::Bool(false)))
|
||||
} else {
|
||||
safe_nary(OR, children)
|
||||
@@ -229,7 +229,7 @@ impl NaryFlat<bool> for BoolNaryOp {
|
||||
BoolNaryOp::And => {
|
||||
if consts.iter().any(|b| !*b) {
|
||||
leaf_term(Op::Const(Value::Bool(false)))
|
||||
} else if children.len() == 0 {
|
||||
} else if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::Bool(true)))
|
||||
} else {
|
||||
safe_nary(AND, children)
|
||||
@@ -237,7 +237,7 @@ impl NaryFlat<bool> for BoolNaryOp {
|
||||
}
|
||||
BoolNaryOp::Xor => {
|
||||
let odd_trues = consts.into_iter().filter(|b| *b).count() % 2 == 1;
|
||||
if children.len() == 0 {
|
||||
if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::Bool(odd_trues)))
|
||||
} else {
|
||||
let t = safe_nary(XOR, children);
|
||||
@@ -264,7 +264,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
|
||||
BvNaryOp::Or => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::BitOr::bitor);
|
||||
if children.len() == 0 {
|
||||
if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::BitVector(c)))
|
||||
} else if c.uint() == &Integer::from(0) {
|
||||
safe_nary(BV_OR, children)
|
||||
@@ -297,7 +297,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
|
||||
BvNaryOp::And => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::BitAnd::bitand);
|
||||
if children.len() == 0 {
|
||||
if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::BitVector(c)))
|
||||
} else {
|
||||
safe_nary(
|
||||
@@ -328,7 +328,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
|
||||
BvNaryOp::Xor => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::BitXor::bitxor);
|
||||
if children.len() == 0 {
|
||||
if children.is_empty() {
|
||||
leaf_term(Op::Const(Value::BitVector(c)))
|
||||
} else {
|
||||
safe_nary(
|
||||
@@ -360,7 +360,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
|
||||
BvNaryOp::Add => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::Add::add);
|
||||
if c.uint() != &Integer::from(0) || children.len() == 0 {
|
||||
if c.uint() != &Integer::from(0) || children.is_empty() {
|
||||
children.push(leaf_term(Op::Const(Value::BitVector(c))));
|
||||
}
|
||||
}
|
||||
@@ -372,7 +372,7 @@ impl NaryFlat<BitVector> for BvNaryOp {
|
||||
if c.uint() == &Integer::from(0) {
|
||||
leaf_term(Op::Const(Value::BitVector(c)))
|
||||
} else {
|
||||
if c.uint() != &Integer::from(1) || children.len() == 0 {
|
||||
if c.uint() != &Integer::from(1) || children.is_empty() {
|
||||
children.push(leaf_term(Op::Const(Value::BitVector(c))));
|
||||
}
|
||||
safe_nary(BV_MUL, children)
|
||||
@@ -397,7 +397,7 @@ impl NaryFlat<FieldElem> for PfNaryOp {
|
||||
PfNaryOp::Add => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::Add::add);
|
||||
if c.i() != &Integer::from(0) || children.len() == 0 {
|
||||
if c.i() != &Integer::from(0) || children.is_empty() {
|
||||
children.push(leaf_term(Op::Const(Value::Field(c))));
|
||||
}
|
||||
}
|
||||
@@ -406,7 +406,7 @@ impl NaryFlat<FieldElem> for PfNaryOp {
|
||||
PfNaryOp::Mul => {
|
||||
if let Some(c) = consts.pop() {
|
||||
let c = consts.into_iter().fold(c, std::ops::Mul::mul);
|
||||
if c.i() == &Integer::from(0) || children.len() == 0 {
|
||||
if c.i() == &Integer::from(0) || children.is_empty() {
|
||||
leaf_term(Op::Const(Value::Field(c)))
|
||||
} else {
|
||||
if c.i() != &Integer::from(1) {
|
||||
|
||||
@@ -23,7 +23,7 @@ impl<T: Clone> PersistentConcatList<T> {
|
||||
match &*t {
|
||||
PersistentConcatList::Leaf(t) => v.push((**t).clone()),
|
||||
PersistentConcatList::Concat(ts) => {
|
||||
ts.into_iter().for_each(|c| stack.push(c.clone()));
|
||||
ts.iter().for_each(|c| stack.push(c.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,7 @@ impl<T: Clone> PersistentConcatList<T> {
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
fn to_term(&mut self) -> Term {
|
||||
fn as_term(&mut self) -> Term {
|
||||
match self {
|
||||
Entry::Term(t) => (**t).clone(),
|
||||
Entry::NaryTerm(o, ts, maybe_term) => {
|
||||
@@ -47,6 +47,7 @@ impl Entry {
|
||||
}
|
||||
|
||||
/// Flattening cache.
|
||||
#[derive(Default)]
|
||||
pub struct Cache(TermMap<Entry>);
|
||||
|
||||
impl Cache {
|
||||
@@ -99,7 +100,7 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache
|
||||
}
|
||||
e => {
|
||||
children
|
||||
.push(Rc::new(PersistentConcatList::Leaf(Rc::new(e.to_term()))));
|
||||
.push(Rc::new(PersistentConcatList::Leaf(Rc::new(e.as_term()))));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,13 +113,13 @@ pub fn flatten_nary_ops_cached(term_: Term, Cache(ref mut rewritten): &mut Cache
|
||||
_ => Entry::Term(Rc::new(term(
|
||||
t.op.clone(),
|
||||
t.cs.iter()
|
||||
.map(|c| rewritten.get_mut(c).unwrap().to_term())
|
||||
.map(|c| rewritten.get_mut(c).unwrap().as_term())
|
||||
.collect(),
|
||||
))),
|
||||
};
|
||||
rewritten.insert(t, entry);
|
||||
}
|
||||
rewritten.get_mut(&term_).unwrap().to_term()
|
||||
rewritten.get_mut(&term_).unwrap().as_term()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -63,7 +63,7 @@ impl<'a> Inliner<'a> {
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
self.stale_vars.contains(&key),
|
||||
self.stale_vars.contains(key),
|
||||
"Variable {}, which is being susbstituted",
|
||||
key
|
||||
);
|
||||
@@ -107,7 +107,7 @@ impl<'a> Inliner<'a> {
|
||||
///
|
||||
/// Will not return `v` which are protected.
|
||||
fn as_fresh_def(&self, t: &Term) -> Option<(Term, Term)> {
|
||||
if &EQ == &t.op {
|
||||
if EQ == t.op {
|
||||
if let Op::Var(name, _) = &t.cs[0].op {
|
||||
if !self.stale_vars.contains(&t.cs[0])
|
||||
&& !self.protected.contains(name)
|
||||
@@ -133,7 +133,7 @@ impl<'a> Inliner<'a> {
|
||||
///
|
||||
/// If `t` is not a substitution, then its (substituted variant) is returned.
|
||||
fn ingest_term(&mut self, t: &Term) -> Option<Term> {
|
||||
if let Some((var, val)) = self.as_fresh_def(&t) {
|
||||
if let Some((var, val)) = self.as_fresh_def(t) {
|
||||
//debug!(target: "circ::ir::opt::inline", "Inline: {} -> {}", var, val.clone());
|
||||
// Rewrite the substitution
|
||||
let subst_val = self.apply(&val);
|
||||
|
||||
@@ -223,7 +223,7 @@ impl RewritePass for Replacer {
|
||||
}
|
||||
}
|
||||
Op::Store => {
|
||||
if self.should_replace(&orig) {
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
@@ -233,7 +233,7 @@ impl RewritePass for Replacer {
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
if self.should_replace(&orig) {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -64,7 +64,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
|
||||
let mut new_outputs = Vec::new();
|
||||
for a in std::mem::take(&mut cs.outputs) {
|
||||
assert_eq!(check(&a), Sort::Bool, "Non-bool in {:?}", i);
|
||||
if &a.op == &Op::BoolNaryOp(BoolNaryOp::And) {
|
||||
if a.op == Op::BoolNaryOp(BoolNaryOp::And) {
|
||||
new_outputs.extend(a.cs.iter().cloned());
|
||||
} else {
|
||||
new_outputs.push(a)
|
||||
|
||||
@@ -95,7 +95,7 @@ impl RewritePass for Pass {
|
||||
party_visibility,
|
||||
&mut new_var_reqs,
|
||||
);
|
||||
if new_var_reqs.len() > 0 {
|
||||
if !new_var_reqs.is_empty() {
|
||||
computation.replace_input(orig.clone(), new_var_reqs);
|
||||
}
|
||||
Some(new)
|
||||
|
||||
@@ -18,8 +18,8 @@ pub fn sha_rewrites(term_: &Term) -> Term {
|
||||
if t.cs.len() == 2 {
|
||||
let a = get(0);
|
||||
let b = get(1);
|
||||
if &a.op == &b.op
|
||||
&& &a.op == &BV_AND
|
||||
if a.op == b.op
|
||||
&& a.op == BV_AND
|
||||
&& b.cs[0].op == BV_NOT
|
||||
&& b.cs[0].cs[0] == a.cs[0]
|
||||
{
|
||||
@@ -46,9 +46,9 @@ pub fn sha_rewrites(term_: &Term) -> Term {
|
||||
let c0 = get(0);
|
||||
let c1 = get(1);
|
||||
let c2 = get(2);
|
||||
if &c0.op == &c1.op
|
||||
&& &c1.op == &c2.op
|
||||
&& &c2.op == &BV_AND
|
||||
if c0.op == c1.op
|
||||
&& c1.op == c2.op
|
||||
&& c2.op == BV_AND
|
||||
&& c0.cs.len() == 2
|
||||
&& c1.cs.len() == 2
|
||||
&& c2.cs.len() == 2
|
||||
@@ -62,7 +62,7 @@ pub fn sha_rewrites(term_: &Term) -> Term {
|
||||
{
|
||||
debug!("SHA MAJ");
|
||||
let items = s0.union(&s1).collect::<Vec<_>>();
|
||||
let w = check(&c0).as_bv();
|
||||
let w = check(c0).as_bv();
|
||||
Some(term(
|
||||
BV_CONCAT,
|
||||
(0..w)
|
||||
@@ -95,7 +95,7 @@ pub fn sha_rewrites(term_: &Term) -> Term {
|
||||
});
|
||||
cache.insert(t, new_t);
|
||||
}
|
||||
cache.get(&term_).unwrap().clone()
|
||||
cache.get(term_).unwrap().clone()
|
||||
}
|
||||
|
||||
/// Eliminate the SHA majority operator, replacing it with ands and ors.
|
||||
@@ -105,9 +105,9 @@ pub fn sha_maj_elim(term_: &Term) -> Term {
|
||||
for t in PostOrderIter::new(term_.clone()) {
|
||||
let c_get = |x: &Term| cache.get(x).unwrap();
|
||||
let get = |i: usize| c_get(&t.cs[i]);
|
||||
let new_t = match &t.op {
|
||||
let new_t = match t.op {
|
||||
// maj(a, b, c) = (a & b) | (b & c) | (c & a)
|
||||
&Op::BoolMaj => {
|
||||
Op::BoolMaj => {
|
||||
let a = get(0);
|
||||
let b = get(1);
|
||||
let c = get(2);
|
||||
@@ -126,7 +126,7 @@ pub fn sha_maj_elim(term_: &Term) -> Term {
|
||||
});
|
||||
cache.insert(t, new_t);
|
||||
}
|
||||
cache.get(&term_).unwrap().clone()
|
||||
cache.get(term_).unwrap().clone()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -68,7 +68,7 @@ impl TupleTree {
|
||||
fn flatten(&self) -> impl Iterator<Item = Term> {
|
||||
let mut out = Vec::new();
|
||||
fn rec_unroll_into(t: &Term, out: &mut Vec<Term>) {
|
||||
if &t.op == &Op::Tuple {
|
||||
if t.op == Op::Tuple {
|
||||
for c in &t.cs {
|
||||
rec_unroll_into(c, out);
|
||||
}
|
||||
@@ -81,7 +81,7 @@ impl TupleTree {
|
||||
}
|
||||
fn structure(&self, flattened: impl IntoIterator<Item = Term>) -> Self {
|
||||
fn term_structure(t: &Term, iter: &mut impl Iterator<Item = Term>) -> Term {
|
||||
if &t.op == &Op::Tuple {
|
||||
if t.op == Op::Tuple {
|
||||
term(
|
||||
Op::Tuple,
|
||||
t.cs.iter().map(|c| term_structure(c, iter)).collect(),
|
||||
@@ -94,9 +94,9 @@ impl TupleTree {
|
||||
}
|
||||
fn well_formed(&self) -> bool {
|
||||
for t in PostOrderIter::new(self.0.clone()) {
|
||||
if &t.op != &Op::Tuple {
|
||||
if t.op != Op::Tuple {
|
||||
for c in &t.cs {
|
||||
if &c.op == &Op::Tuple {
|
||||
if c.op == Op::Tuple {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -267,11 +267,9 @@ impl RewritePass for TupleLifter {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn tuple_free(t: Term) -> bool {
|
||||
!PostOrderIter::new(t).any(|c| match check(&c) {
|
||||
Sort::Tuple(_) => true,
|
||||
_ => false,
|
||||
})
|
||||
PostOrderIter::new(t).all(|c| !matches!(check(&c), Sort::Tuple(..)))
|
||||
}
|
||||
|
||||
/// Run the tuple elimination pass.
|
||||
|
||||
@@ -45,11 +45,8 @@ impl Constraints for Computation {
|
||||
let mut set = FxHashSet::default();
|
||||
for a in &assertions {
|
||||
for t in PostOrderIter::new(a.clone()) {
|
||||
match &t.op {
|
||||
Op::Var(_, _) => {
|
||||
set.insert(t.clone());
|
||||
}
|
||||
_ => {}
|
||||
if let Op::Var(..) = t.op {
|
||||
set.insert(t.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,8 +122,8 @@ impl FixedSizeDist {
|
||||
Op::Var(self.sample_ident(&format!("bv{}", w), rng), sort.clone()),
|
||||
Op::BvUnOp(BvUnOp::Neg),
|
||||
Op::BvUnOp(BvUnOp::Not),
|
||||
Op::BvUext(rng.gen_range(0..w.clone())),
|
||||
Op::BvSext(rng.gen_range(0..w.clone())),
|
||||
Op::BvUext(rng.gen_range(0..*w)),
|
||||
Op::BvSext(rng.gen_range(0..*w)),
|
||||
Op::BvBinOp(BvBinOp::Sub),
|
||||
Op::BvBinOp(BvBinOp::Udiv),
|
||||
Op::BvBinOp(BvBinOp::Urem),
|
||||
@@ -292,7 +292,7 @@ impl rand::distributions::Distribution<Term> for FixedSizeDist {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
|
||||
let op = self.sample_op(&self.sort, rng);
|
||||
let sorts = self.sample_child_sorts(&self.sort, &op, rng);
|
||||
if sorts.len() == 0 {
|
||||
if sorts.is_empty() {
|
||||
leaf_term(op)
|
||||
} else {
|
||||
let mut dists: Vec<FixedSizeDist> = sorts
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
//! Extra algorithms over terms (e.g. substitutions)
|
||||
|
||||
use super::*;
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
/// Convert `t` to width `w`, though unsigned extension or extraction
|
||||
pub fn to_width(t: &Term, w: usize) -> Term {
|
||||
let old_w = check(t).as_bv();
|
||||
if old_w < w {
|
||||
term(Op::BvUext(w - old_w), vec![t.clone()])
|
||||
} else if old_w == w {
|
||||
t.clone()
|
||||
} else {
|
||||
term(Op::BvExtract(w - 1, 0), vec![t.clone()])
|
||||
match old_w.cmp(&w) {
|
||||
Ordering::Less => term(Op::BvUext(w - old_w), vec![t.clone()]),
|
||||
Ordering::Equal => t.clone(),
|
||||
Ordering::Greater => term(Op::BvExtract(w - 1, 0), vec![t.clone()]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +44,7 @@ impl Display for Letified {
|
||||
|
||||
writeln!(f, "(let (")?;
|
||||
for t in PostOrderIter::new(self.0.clone()) {
|
||||
if parent_counts.get(&t).unwrap_or(&0) > &1 && t.cs.len() > 0 {
|
||||
if parent_counts.get(&t).unwrap_or(&0) > &1 && !t.cs.is_empty() {
|
||||
let name = format!("let_{}", let_ct);
|
||||
let_ct += 1;
|
||||
let sort = check(&t);
|
||||
@@ -131,7 +130,7 @@ pub fn free_in(v: &str, t: Term) -> bool {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
false
|
||||
}
|
||||
|
||||
/// If this term is a constant field or bit-vector, get the unsigned int value.
|
||||
|
||||
@@ -37,7 +37,7 @@ impl FieldElem {
|
||||
location
|
||||
);
|
||||
debug_assert!(
|
||||
&self.i <= &*self.modulus,
|
||||
self.i <= *self.modulus,
|
||||
"Too small field elem: {}\n at {}",
|
||||
self,
|
||||
location
|
||||
|
||||
@@ -708,11 +708,17 @@ impl Display for Array {
|
||||
}
|
||||
|
||||
impl std::cmp::Eq for Value {}
|
||||
// We walk in danger here, intentionally. One day we may fix it.
|
||||
// FP is the heart of the problem.
|
||||
#[allow(clippy::derive_ord_xor_partial_ord)]
|
||||
impl std::cmp::Ord for Value {
|
||||
fn cmp(&self, o: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(o).expect("broken Value cmp")
|
||||
}
|
||||
}
|
||||
// We walk in danger here, intentionally. One day we may fix it.
|
||||
// FP is the heart of the problem.
|
||||
#[allow(clippy::derive_hash_xor_eq)]
|
||||
impl std::hash::Hash for Value {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
@@ -778,7 +784,7 @@ impl Sort {
|
||||
/// Unwrap the constituent sorts of this tuple, panicking otherwise.
|
||||
pub fn as_tuple(&self) -> &Vec<Sort> {
|
||||
if let Sort::Tuple(w) = self {
|
||||
&w
|
||||
w
|
||||
} else {
|
||||
panic!("{} is not a tuple", self)
|
||||
}
|
||||
@@ -815,7 +821,7 @@ impl Sort {
|
||||
Box::new(
|
||||
std::iter::successors(Some(Integer::from(0)), move |p| {
|
||||
let q = p.clone() + 1;
|
||||
if &q < &*m {
|
||||
if q < *m {
|
||||
Some(q)
|
||||
} else {
|
||||
None
|
||||
@@ -927,21 +933,18 @@ impl TermTable {
|
||||
if val.elm.upgrade().is_some() {
|
||||
true
|
||||
} else {
|
||||
to_check.extend(key.cs.iter().map(|i| i.clone()));
|
||||
to_check.extend(key.cs.iter().cloned());
|
||||
false
|
||||
}
|
||||
});
|
||||
while let Some(t) = to_check.pop() {
|
||||
let data: TermData = (*t).clone();
|
||||
std::mem::drop(t);
|
||||
match self.map.entry(data) {
|
||||
std::collections::hash_map::Entry::Occupied(e) => {
|
||||
if e.get().elm.upgrade().is_none() {
|
||||
let (key, _val) = e.remove_entry();
|
||||
to_check.extend(key.cs.iter().map(|i| i.clone()));
|
||||
}
|
||||
if let std::collections::hash_map::Entry::Occupied(e) = self.map.entry(data) {
|
||||
if e.get().elm.upgrade().is_none() {
|
||||
let (key, _val) = e.remove_entry();
|
||||
to_check.extend(key.cs.iter().cloned());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let new_size = self.map.len();
|
||||
@@ -1009,19 +1012,11 @@ impl TermData {
|
||||
}
|
||||
/// Is this a variable?
|
||||
pub fn is_var(&self) -> bool {
|
||||
if let Op::Var(..) = &self.op {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
matches!(&self.op, Op::Var(..))
|
||||
}
|
||||
/// Is this a value
|
||||
pub fn is_const(&self) -> bool {
|
||||
if let Op::Const(..) = &self.op {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
matches!(&self.op, Op::Const(..))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1085,7 +1080,7 @@ impl Value {
|
||||
/// Unwrap the constituent value of this array, panicking otherwise.
|
||||
pub fn as_array(&self) -> &Array {
|
||||
if let Value::Array(w) = self {
|
||||
&w
|
||||
w
|
||||
} else {
|
||||
panic!("{} is not an aray", self)
|
||||
}
|
||||
@@ -1377,7 +1372,7 @@ impl std::iter::Iterator for PostOrderIter {
|
||||
type Item = Term;
|
||||
fn next(&mut self) -> Option<Term> {
|
||||
while let Some((children_pushed, t)) = self.stack.last() {
|
||||
if self.visited.contains(&t) {
|
||||
if self.visited.contains(t) {
|
||||
self.stack.pop();
|
||||
} else if !children_pushed {
|
||||
self.stack.last_mut().unwrap().0 = true;
|
||||
@@ -1428,7 +1423,7 @@ impl ComputationMetadata {
|
||||
party,
|
||||
self.input_vis.get(&input_name).unwrap()
|
||||
);
|
||||
self.input_vis.insert(input_name.clone(), party);
|
||||
self.input_vis.insert(input_name, party);
|
||||
self.inputs.push(term);
|
||||
}
|
||||
/// Replace the `original` computation input with `new`, in the order given.
|
||||
@@ -1473,10 +1468,10 @@ impl ComputationMetadata {
|
||||
}
|
||||
/// Returns None if the value is public. Otherwise, the unique party that knows it.
|
||||
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
|
||||
self.input_vis
|
||||
*self
|
||||
.input_vis
|
||||
.get(input_name)
|
||||
.unwrap_or_else(|| panic!("Missing input {} in inputs{:#?}", input_name, self.inputs))
|
||||
.clone()
|
||||
}
|
||||
/// Is this input public?
|
||||
pub fn is_input(&self, input_name: &str) -> bool {
|
||||
@@ -1497,6 +1492,9 @@ impl ComputationMetadata {
|
||||
})
|
||||
}
|
||||
/// Get all public inputs.
|
||||
// I think the lint is just broken here.
|
||||
// TODO: submit a patch
|
||||
#[allow(clippy::needless_lifetimes)]
|
||||
pub fn public_inputs<'a>(&'a self) -> impl Iterator<Item = Term> + 'a {
|
||||
self.inputs.iter().filter_map(move |input| {
|
||||
if let Op::Var(name, _) = &input.op {
|
||||
@@ -1513,7 +1511,7 @@ impl ComputationMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
/// An IR computation.
|
||||
pub struct Computation {
|
||||
/// The outputs of the computation.
|
||||
@@ -1526,16 +1524,6 @@ pub struct Computation {
|
||||
pub metadata: ComputationMetadata,
|
||||
}
|
||||
|
||||
impl std::default::Default for Computation {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
outputs: Vec::new(),
|
||||
metadata: ComputationMetadata::default(),
|
||||
values: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Computation {
|
||||
/// Create a new variable, `name: s`, where `val_fn` can be called to get the concrete value,
|
||||
/// and `public` indicates whether this variable is public in the constraint system.
|
||||
|
||||
@@ -41,7 +41,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
Op::BvConcat => t
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| check_raw(c))
|
||||
.map(check_raw)
|
||||
.try_fold(
|
||||
Ok(0),
|
||||
|l: Result<usize, TypeErrorReason>,
|
||||
@@ -88,9 +88,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
Op::Select => array_or(&check_raw(&t.cs[0])?, "select").map(|(_, v)| v.clone()),
|
||||
Op::Store => Ok(check_raw(&t.cs[0])?),
|
||||
Op::Tuple => Ok(Sort::Tuple(
|
||||
t.cs.iter()
|
||||
.map(|c| check_raw(c))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
t.cs.iter().map(check_raw).collect::<Result<Vec<_>, _>>()?,
|
||||
)),
|
||||
Op::Field(i) => {
|
||||
let sort = check_raw(&t.cs[0])?;
|
||||
@@ -127,20 +125,17 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
let mut term_tys = TERM_TYPES.write().unwrap();
|
||||
// to_check is a stack of (node, cs checked) pairs.
|
||||
let mut to_check = vec![(t.clone(), false)];
|
||||
while to_check.len() > 0 {
|
||||
while !to_check.is_empty() {
|
||||
let back = to_check.last_mut().unwrap();
|
||||
let weak = back.0.to_weak();
|
||||
// The idea here is to check that
|
||||
match term_tys.get_key_value(&weak) {
|
||||
Some((p, _)) => {
|
||||
if p.to_hconsed().is_some() {
|
||||
to_check.pop();
|
||||
continue;
|
||||
} else {
|
||||
term_tys.remove(&weak);
|
||||
}
|
||||
if let Some((p, _)) = term_tys.get_key_value(&weak) {
|
||||
if p.to_hconsed().is_some() {
|
||||
to_check.pop();
|
||||
continue;
|
||||
} else {
|
||||
term_tys.remove(&weak);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
if !back.1 {
|
||||
back.1 = true;
|
||||
@@ -173,7 +168,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
}
|
||||
(Op::BvNaryOp(_), a) => {
|
||||
let ctx = "bv nary op";
|
||||
all_eq_or(a.into_iter().cloned(), ctx)
|
||||
all_eq_or(a.iter().cloned(), ctx)
|
||||
.and_then(|t| bv_or(t, ctx))
|
||||
.map(|a| a.clone())
|
||||
}
|
||||
@@ -207,7 +202,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
}
|
||||
(Op::BoolNaryOp(_), a) => {
|
||||
let ctx = "bool nary op";
|
||||
all_eq_or(a.into_iter().cloned(), ctx)
|
||||
all_eq_or(a.iter().cloned(), ctx)
|
||||
.and_then(|t| bool_or(t, ctx))
|
||||
.map(|a| a.clone())
|
||||
}
|
||||
@@ -252,7 +247,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
(Op::FpToFp(32), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F32),
|
||||
(Op::PfNaryOp(_), a) => {
|
||||
let ctx = "pf nary op";
|
||||
all_eq_or(a.into_iter().cloned(), ctx)
|
||||
all_eq_or(a.iter().cloned(), ctx)
|
||||
.and_then(|t| pf_or(t, ctx))
|
||||
.map(|a| a.clone())
|
||||
}
|
||||
@@ -264,9 +259,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
(Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store")
|
||||
.and_then(|_| eq_or(v, b, "store"))
|
||||
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
|
||||
(Op::Tuple, a) => {
|
||||
Ok(Sort::Tuple(a.into_iter().map(|a| (*a).clone()).collect()))
|
||||
}
|
||||
(Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())),
|
||||
(Op::Field(i), &[a]) => tuple_or(a, "tuple field access").and_then(|t| {
|
||||
if i < &t.len() {
|
||||
Ok(t[*i].clone())
|
||||
@@ -288,7 +281,7 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
)))
|
||||
}
|
||||
}),
|
||||
(_, _) => Err(TypeErrorReason::Custom(format!("other"))),
|
||||
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
|
||||
})
|
||||
.map_err(|reason| TypeError {
|
||||
op: back.0.op.clone(),
|
||||
@@ -357,7 +350,7 @@ fn array_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<(&'a Sort, &'a Sort),
|
||||
}
|
||||
|
||||
fn bool_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
|
||||
if let &Sort::Bool = a {
|
||||
if let Sort::Bool = a {
|
||||
Ok(a)
|
||||
} else {
|
||||
Err(TypeErrorReason::ExpectedBool(a.clone(), ctx))
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
//! A compiler infrastructure for compiling programs to circuits
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![deny(warnings)]
|
||||
|
||||
#[macro_use]
|
||||
pub mod ir;
|
||||
|
||||
@@ -109,7 +109,7 @@ impl CostModel {
|
||||
};
|
||||
for (op_name, json) in obj {
|
||||
// HACK: assumes the presence of 2 partitions names into conversion and otherwise.
|
||||
if !op_name.contains("2") {
|
||||
if !op_name.contains('2') {
|
||||
for op in ops_from_name(op_name) {
|
||||
let obj = json.as_object().unwrap();
|
||||
for (share_type, share_name) in
|
||||
@@ -117,7 +117,7 @@ impl CostModel {
|
||||
{
|
||||
if let Some(cost) = get_cost_opt(share_name, obj) {
|
||||
ops.entry(op.clone())
|
||||
.or_insert_with(|| FxHashMap::default())
|
||||
.or_insert_with(FxHashMap::default)
|
||||
.insert(*share_type, cost);
|
||||
}
|
||||
}
|
||||
@@ -160,29 +160,27 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
|
||||
// build variables for all term assignments
|
||||
for (t, i) in terms.iter() {
|
||||
let mut vars = vec![];
|
||||
if let Op::Var(_, _) = &t.op {
|
||||
for ty in &SHARE_TYPES {
|
||||
let name = format!("t_{}_{}", i, ty.char());
|
||||
let v = ilp.new_variable(variable().binary(), name.clone());
|
||||
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
|
||||
vars.push(v);
|
||||
match &t.op {
|
||||
Op::Var(..) | Op::Const(_) => {
|
||||
for ty in &SHARE_TYPES {
|
||||
let name = format!("t_{}_{}", i, ty.char());
|
||||
let v = ilp.new_variable(variable().binary(), name.clone());
|
||||
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
|
||||
vars.push(v);
|
||||
}
|
||||
}
|
||||
} else if let Op::Const(_) = &t.op {
|
||||
for ty in &SHARE_TYPES {
|
||||
let name = format!("t_{}_{}", i, ty.char());
|
||||
let v = ilp.new_variable(variable().binary(), name.clone());
|
||||
term_vars.insert((t.clone(), *ty), (v, 0.0, name));
|
||||
vars.push(v);
|
||||
_ => {
|
||||
if let Some(costs) = costs.ops.get(&t.op) {
|
||||
for (ty, cost) in costs {
|
||||
let name = format!("t_{}_{}", i, ty.char());
|
||||
let v = ilp.new_variable(variable().binary(), name.clone());
|
||||
term_vars.insert((t.clone(), *ty), (v, *cost, name));
|
||||
vars.push(v);
|
||||
}
|
||||
} else {
|
||||
panic!("No cost for op {}", &t.op)
|
||||
}
|
||||
}
|
||||
} else if let Some(costs) = costs.ops.get(&t.op) {
|
||||
for (ty, cost) in costs {
|
||||
let name = format!("t_{}_{}", i, ty.char());
|
||||
let v = ilp.new_variable(variable().binary(), name.clone());
|
||||
term_vars.insert((t.clone(), *ty), (v, *cost, name));
|
||||
vars.push(v);
|
||||
}
|
||||
} else {
|
||||
panic!("No cost for op {}", &t.op)
|
||||
}
|
||||
// Sum of assignments is at least 1.
|
||||
ilp.new_constraint(
|
||||
@@ -218,7 +216,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
|
||||
let def_uses: FxHashMap<Term, Vec<Term>> = {
|
||||
let mut t = FxHashMap::default();
|
||||
for (d, u) in def_uses {
|
||||
t.entry(d).or_insert_with(|| Vec::new()).push(u);
|
||||
t.entry(d).or_insert_with(Vec::new).push(u);
|
||||
}
|
||||
t
|
||||
};
|
||||
@@ -232,7 +230,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
|
||||
// c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1
|
||||
term_vars
|
||||
.get(&(use_.clone(), *to_ty))
|
||||
.map(|t_to| ilp.new_constraint(c.0 >> t_from.0 + t_to.0 - 1.0))
|
||||
.map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)))
|
||||
})
|
||||
});
|
||||
}
|
||||
@@ -245,9 +243,7 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
|
||||
.values()
|
||||
.map(|(a, b)| (a, b))
|
||||
.chain(term_vars.values().map(|(a, b, _)| (a, b)))
|
||||
.fold(0.0.into(), |acc: Expression, (v, cost)| {
|
||||
acc + v.clone() * *cost
|
||||
}),
|
||||
.fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost),
|
||||
);
|
||||
|
||||
let (_opt, solution) = ilp.default_solve().unwrap();
|
||||
|
||||
@@ -21,9 +21,9 @@ pub const SHARE_TYPES: [ShareType; 3] = [ShareType::Arithmetic, ShareType::Boole
|
||||
impl ShareType {
|
||||
fn char(&self) -> char {
|
||||
match self {
|
||||
&ShareType::Arithmetic => 'a',
|
||||
&ShareType::Yao => 'y',
|
||||
&ShareType::Boolean => 'b',
|
||||
ShareType::Arithmetic => 'a',
|
||||
ShareType::Yao => 'y',
|
||||
ShareType::Boolean => 'b',
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,7 @@ pub fn all_boolean_sharing(c: &Computation) -> SharingMap {
|
||||
c.outputs
|
||||
.iter()
|
||||
.flat_map(|output| {
|
||||
PostOrderIter::new(output.clone()).map(|term| (term.clone(), ShareType::Boolean))
|
||||
PostOrderIter::new(output.clone()).map(|term| (term, ShareType::Boolean))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@ pub struct ABY {
|
||||
output: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for ABY {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ABY {
|
||||
/// Initialize ABY circuit
|
||||
pub fn new() -> Self {
|
||||
|
||||
@@ -19,16 +19,16 @@ fn get_filename(path_buf: PathBuf) -> String {
|
||||
|
||||
/// In ABY examples, remove the existing directory and create a directory
|
||||
/// in order to write the new test case
|
||||
fn create_dir_in_aby(filename: &String) {
|
||||
let path = format!("third_party/ABY/src/examples/{}", *filename);
|
||||
let _ = fs::remove_dir_all(path.clone());
|
||||
fs::create_dir_all(format!("{}/common", path.clone())).expect("Failed to create directory");
|
||||
fn create_dir_in_aby(filename: &str) {
|
||||
let path = format!("third_party/ABY/src/examples/{}", filename);
|
||||
let _ = fs::remove_dir_all(&path);
|
||||
fs::create_dir_all(format!("{}/common", path)).expect("Failed to create directory");
|
||||
}
|
||||
|
||||
/// Update the CMake file in ABY
|
||||
fn update_cmake_file(filename: &String) {
|
||||
fn update_cmake_file(filename: &str) {
|
||||
let cmake_filename = "third_party/ABY/src/examples/CMakeLists.txt";
|
||||
let file = File::open(cmake_filename.clone()).expect("Failed to open cmake file");
|
||||
let file = File::open(&cmake_filename).expect("Failed to open cmake file");
|
||||
let reader = BufReader::new(file);
|
||||
let mut flag = false;
|
||||
|
||||
@@ -46,70 +46,68 @@ fn update_cmake_file(filename: &String) {
|
||||
.open(cmake_filename)
|
||||
.unwrap();
|
||||
|
||||
writeln!(file, "{}", format!("add_subdirectory({})", *filename))
|
||||
writeln!(file, "{}", format!("add_subdirectory({})", filename))
|
||||
.expect("Failed to write to cmake file");
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a CMake file for the corresponding filename (testcase)
|
||||
/// in the ABY examples directory
|
||||
fn write_test_cmake_file(filename: &String) {
|
||||
let path = format!("third_party/ABY/src/examples/{}/CMakeLists.txt", *filename);
|
||||
fn write_test_cmake_file(filename: &str) {
|
||||
let path = format!("third_party/ABY/src/examples/{}/CMakeLists.txt", filename);
|
||||
|
||||
fs::write(
|
||||
path.clone(),
|
||||
&path,
|
||||
format!(
|
||||
concat!(
|
||||
"add_executable({}_test {}_test.cpp common/{}.cpp)\n",
|
||||
"target_link_libraries({}_test ABY::aby ENCRYPTO_utils::encrypto_utils)"
|
||||
),
|
||||
*filename, *filename, *filename, *filename
|
||||
filename, filename, filename, filename
|
||||
),
|
||||
)
|
||||
.expect("Failed to write to cmake file");
|
||||
}
|
||||
|
||||
/// Write the testcase in the ABY examples directory
|
||||
fn write_test_file(filename: &String) {
|
||||
fn write_test_file(filename: &str) {
|
||||
let template = fs::read_to_string("third_party/ABY_templates/test_template.txt")
|
||||
.expect("Unable to read file");
|
||||
let path = format!(
|
||||
"third_party/ABY/src/examples/{}/{}_test.cpp",
|
||||
*filename, *filename
|
||||
filename, filename
|
||||
);
|
||||
|
||||
fs::write(path.clone(), template.replace("{fn}", &*filename))
|
||||
.expect("Failed to write to test file");
|
||||
fs::write(&path, template.replace("{fn}", filename)).expect("Failed to write to test file");
|
||||
}
|
||||
|
||||
/// Using the h_template.txt, write the .h file for the new test case
|
||||
fn write_h_file(filename: &String) {
|
||||
fn write_h_file(filename: &str) {
|
||||
let template = fs::read_to_string("third_party/ABY_templates/h_template.txt")
|
||||
.expect("Unable to read file");
|
||||
let path = format!(
|
||||
"third_party/ABY/src/examples/{}/common/{}.h",
|
||||
*filename, *filename
|
||||
filename, filename
|
||||
);
|
||||
|
||||
fs::write(path.clone(), template.replace("{fn}", &*filename))
|
||||
.expect("Failed to write to h file");
|
||||
fs::write(&path, template.replace("{fn}", &*filename)).expect("Failed to write to h file");
|
||||
}
|
||||
|
||||
/// Using the cpp_template.txt, write the .cpp file for the new test case
|
||||
fn write_circ_file(filename: &String, circ: String, output: &String) {
|
||||
fn write_circ_file(filename: &str, circ: &str, output: &str) {
|
||||
let template = fs::read_to_string("third_party/ABY_templates/cpp_template.txt")
|
||||
.expect("Unable to read file");
|
||||
let path = format!(
|
||||
"third_party/ABY/src/examples/{}/common/{}.cpp",
|
||||
*filename, *filename
|
||||
filename, filename
|
||||
);
|
||||
|
||||
fs::write(
|
||||
path.clone(),
|
||||
&path,
|
||||
template
|
||||
.replace("{fn}", &*filename)
|
||||
.replace("{circ}", &circ)
|
||||
.replace("{output}", &output),
|
||||
.replace("{fn}", filename)
|
||||
.replace("{circ}", circ)
|
||||
.replace("{output}", output),
|
||||
)
|
||||
.expect("Failed to write to cpp file");
|
||||
}
|
||||
@@ -123,5 +121,5 @@ pub fn write_aby_exec(aby: ABY, path_buf: PathBuf) {
|
||||
write_test_file(&filename);
|
||||
write_h_file(&filename);
|
||||
let circ_str = aby.setup.join("\n\t") + &aby.circs.join("\n\t");
|
||||
write_circ_file(&filename, circ_str, &aby.output.join("\n\t"));
|
||||
write_circ_file(&filename, &circ_str, &aby.output.join("\n\t"));
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ impl ToABY {
|
||||
md: metadata,
|
||||
inputs: TermMap::new(),
|
||||
cache: TermMap::new(),
|
||||
s_map: s_map,
|
||||
s_map,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ impl ToABY {
|
||||
|
||||
/// Parse variable name from IR representation of a variable
|
||||
fn parse_var_name(&self, full_name: String) -> String {
|
||||
let parsed: Vec<String> = full_name.split("_").map(str::to_string).collect();
|
||||
let parsed: Vec<String> = full_name.split('_').map(str::to_string).collect();
|
||||
if parsed.len() < 2 {
|
||||
panic!("Invalid variable name: {}", full_name);
|
||||
}
|
||||
@@ -90,28 +90,26 @@ impl ToABY {
|
||||
|
||||
fn add_cons_gate(&self, t: Term) -> String {
|
||||
let name = ToABY::get_var_name(t.clone());
|
||||
let s_circ = self.get_sharetype_circ(t.clone());
|
||||
let s_circ = self.get_sharetype_circ(t);
|
||||
format!(
|
||||
"s_{} = {}->PutCONSGate((uint64_t){}, bitlen);",
|
||||
name, s_circ, name
|
||||
)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn add_in_gate(&self, t: Term, role: String) -> String {
|
||||
let name = ToABY::get_var_name(t.clone());
|
||||
let s_circ = self.get_sharetype_circ(t.clone());
|
||||
let s_circ = self.get_sharetype_circ(t);
|
||||
format!(
|
||||
"\ts_{} = {}->PutINGate({}, bitlen, {});",
|
||||
name, s_circ, name, role
|
||||
)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn add_dummy_gate(&self, t: Term) -> String {
|
||||
let name = ToABY::get_var_name(t.clone());
|
||||
let s_circ = self.get_sharetype_circ(t.clone());
|
||||
format!("\ts_{} = {}->PutDummyINGate(bitlen);", name, s_circ).to_string()
|
||||
let s_circ = self.get_sharetype_circ(t);
|
||||
format!("\ts_{} = {}->PutDummyINGate(bitlen);", name, s_circ)
|
||||
}
|
||||
|
||||
/// Initialize private and public inputs from each party
|
||||
@@ -177,13 +175,13 @@ impl ToABY {
|
||||
|
||||
/// Return constant gate evaluating to 0
|
||||
#[allow(dead_code)]
|
||||
fn zero() -> String {
|
||||
format!("bcirc->PutCONSGate((uint64_t)0, (uint32_t)1)")
|
||||
fn zero() -> &'static str {
|
||||
"bcirc->PutCONSGate((uint64_t)0, (uint32_t)1)"
|
||||
}
|
||||
|
||||
/// Return constant gate evaluating to 1
|
||||
fn one() -> String {
|
||||
format!("bcirc->PutCONSGate((uint64_t)1, (uint32_t)1)")
|
||||
fn one() -> &'static str {
|
||||
"bcirc->PutCONSGate((uint64_t)1, (uint32_t)1)"
|
||||
}
|
||||
|
||||
fn remove_cons_gate(&self, circ: String) -> String {
|
||||
@@ -191,7 +189,7 @@ impl ToABY {
|
||||
circ.split("PutCONSGate(")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
.split(",")
|
||||
.split(',')
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
@@ -204,8 +202,8 @@ impl ToABY {
|
||||
let s_circ = self.get_sharetype_circ(t.clone());
|
||||
match check(&a) {
|
||||
Sort::Bool => {
|
||||
let a_circ = self.get_bool(&a).clone();
|
||||
let b_circ = self.get_bool(&b).clone();
|
||||
let a_circ = self.get_bool(&a);
|
||||
let b_circ = self.get_bool(&b);
|
||||
|
||||
let a_conv = self.add_conv_gate(t.clone(), a, a_circ);
|
||||
let b_conv = self.add_conv_gate(t.clone(), b, b_circ);
|
||||
@@ -218,12 +216,12 @@ impl ToABY {
|
||||
b_conv,
|
||||
ToABY::one()
|
||||
);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool(s.clone()));
|
||||
self.cache.insert(t, EmbeddedTerm::Bool(s.clone()));
|
||||
s
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
let a_circ = self.get_bv(&a).clone();
|
||||
let b_circ = self.get_bv(&b).clone();
|
||||
let a_circ = self.get_bv(&a);
|
||||
let b_circ = self.get_bv(&b);
|
||||
|
||||
let a_conv = self.add_conv_gate(t.clone(), a, a_circ);
|
||||
let b_conv = self.add_conv_gate(t.clone(), b, b_circ);
|
||||
@@ -232,7 +230,7 @@ impl ToABY {
|
||||
"{}->PutXORGate({}->PutXORGate({}->PutGTGate({}, {}), {}->PutGTGate({}, {})), {})",
|
||||
s_circ, s_circ, s_circ, a_conv, b_conv, s_circ, b_conv, a_conv, ToABY::one()
|
||||
);
|
||||
self.cache.insert(t.clone(), EmbeddedTerm::Bool(s.clone()));
|
||||
self.cache.insert(t, EmbeddedTerm::Bool(s.clone()));
|
||||
s
|
||||
}
|
||||
e => panic!("Unimplemented sort for Eq: {:?}", e),
|
||||
@@ -278,9 +276,9 @@ impl ToABY {
|
||||
let _s = self.embed_eq(t.clone(), t.cs[0].clone(), t.cs[1].clone());
|
||||
}
|
||||
Op::Ite => {
|
||||
let sel_circ = self.get_bool(&t.cs[0]).clone();
|
||||
let a_circ = self.get_bool(&t.cs[1]).clone();
|
||||
let b_circ = self.get_bool(&t.cs[2]).clone();
|
||||
let sel_circ = self.get_bool(&t.cs[0]);
|
||||
let a_circ = self.get_bool(&t.cs[1]);
|
||||
let b_circ = self.get_bool(&t.cs[2]);
|
||||
|
||||
let sel_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), sel_circ);
|
||||
let a_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), a_circ);
|
||||
@@ -306,8 +304,8 @@ impl ToABY {
|
||||
);
|
||||
}
|
||||
Op::BoolNaryOp(o) => {
|
||||
let a_circ = self.get_bool(&t.cs[0]).clone();
|
||||
let b_circ = self.get_bool(&t.cs[1]).clone();
|
||||
let a_circ = self.get_bool(&t.cs[0]);
|
||||
let b_circ = self.get_bool(&t.cs[1]);
|
||||
|
||||
let a_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), a_circ);
|
||||
let b_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), b_circ);
|
||||
@@ -422,9 +420,9 @@ impl ToABY {
|
||||
);
|
||||
}
|
||||
Op::Ite => {
|
||||
let sel_circ = self.get_bool(&t.cs[0]).clone();
|
||||
let a_circ = self.get_bv(&t.cs[1]).clone();
|
||||
let b_circ = self.get_bv(&t.cs[2]).clone();
|
||||
let sel_circ = self.get_bool(&t.cs[0]);
|
||||
let a_circ = self.get_bv(&t.cs[1]);
|
||||
let b_circ = self.get_bv(&t.cs[2]);
|
||||
|
||||
let sel_conv = self.add_conv_gate(t.clone(), t.cs[0].clone(), sel_circ);
|
||||
let a_conv = self.add_conv_gate(t.clone(), t.cs[1].clone(), a_circ);
|
||||
@@ -554,7 +552,7 @@ impl ToABY {
|
||||
|
||||
/// Given a term `t`, lower `t` to ABY Circuits
|
||||
fn lower(&mut self, t: Term) {
|
||||
let circ = self.embed(t.clone());
|
||||
let circ = self.embed(t);
|
||||
self.aby.circs.push(circ);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,12 @@ impl Debug for Ilp {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Ilp {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Ilp {
|
||||
/// Create an empty ILP
|
||||
pub fn new() -> Self {
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
// Needed until https://github.com/rust-lang/rust-clippy/pull/8183 is resolved.
|
||||
#![allow(clippy::identity_op)]
|
||||
|
||||
use crate::ir::term::extras::Letified;
|
||||
use crate::ir::term::*;
|
||||
use crate::target::ilp::Ilp;
|
||||
@@ -115,8 +118,8 @@ impl ToMilp {
|
||||
let mut bounds = Vec::new();
|
||||
for x in xs {
|
||||
n += 1;
|
||||
sum = sum + x;
|
||||
bounds.push(r.clone() - x << 0);
|
||||
sum += x;
|
||||
bounds.push((r.clone() - x) << 0);
|
||||
}
|
||||
assert!(n >= 1);
|
||||
self.ilp.new_constraint(sum << (n as i32 - 1));
|
||||
@@ -172,8 +175,8 @@ impl ToMilp {
|
||||
self.bits_are_equal(&a, &b)
|
||||
}
|
||||
Sort::BitVector(n) => {
|
||||
let a = self.get_bv_uint(a).clone();
|
||||
let b = self.get_bv_uint(b).clone();
|
||||
let a = self.get_bv_uint(a);
|
||||
let b = self.get_bv_uint(b);
|
||||
self.bv_cmp_eq(&a, &b, n)
|
||||
}
|
||||
s => panic!("Unimplemented sort for Eq: {:?}", s),
|
||||
@@ -259,18 +262,18 @@ impl ToMilp {
|
||||
}
|
||||
Op::Ite => {
|
||||
let c = self.get_bool(&bv.cs[0]).clone();
|
||||
let t = self.get_bv_uint(&bv.cs[1]).clone();
|
||||
let f = self.get_bv_uint(&bv.cs[2]).clone();
|
||||
let t = self.get_bv_uint(&bv.cs[1]);
|
||||
let f = self.get_bv_uint(&bv.cs[2]);
|
||||
let ite = self.bv_ite(&c, &t, &f, n);
|
||||
self.set_bv_uint(bv, ite, n);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Not) => {
|
||||
let bits = self.get_bv_bits(&bv.cs[0]).clone();
|
||||
let bits = self.get_bv_bits(&bv.cs[0]);
|
||||
let not_bits = bits.iter().map(|bit| self.bit_not(bit)).collect();
|
||||
self.set_bv_bits(bv, not_bits);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Neg) => {
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
let x = self.get_bv_uint(&bv.cs[0]);
|
||||
// Wrong for x == 0
|
||||
let almost_neg_x = 2f64.powi(n as i32) - x.clone();
|
||||
let is_zero = self.bv_cmp_eq(&x, &0.into(), n);
|
||||
@@ -283,15 +286,14 @@ impl ToMilp {
|
||||
let ext_bits = std::iter::repeat(Expression::from(0)).take(*extra_n);
|
||||
self.set_bv_bits(bv, bits.into_iter().chain(ext_bits).collect());
|
||||
} else {
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
let x = self.get_bv_uint(&bv.cs[0]);
|
||||
self.set_bv_uint(bv, x, n);
|
||||
}
|
||||
}
|
||||
Op::BvSext(extra_n) => {
|
||||
let mut bits = self.get_bv_bits(&bv.cs[0]).into_iter().rev();
|
||||
let ext_bits =
|
||||
std::iter::repeat(bits.next().expect("sign ext empty").clone())
|
||||
.take(extra_n + 1);
|
||||
let ext_bits = std::iter::repeat(bits.next().expect("sign ext empty"))
|
||||
.take(extra_n + 1);
|
||||
|
||||
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
|
||||
}
|
||||
@@ -307,7 +309,7 @@ impl ToMilp {
|
||||
.map(|c| self.get_bv_bits(c))
|
||||
.collect::<Vec<_>>();
|
||||
let mut bits_bv_idx: Vec<Vec<Expression>> = Vec::new();
|
||||
while bits_by_bv[0].len() > 0 {
|
||||
while !bits_by_bv[0].is_empty() {
|
||||
bits_bv_idx.push(
|
||||
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
|
||||
);
|
||||
@@ -327,7 +329,7 @@ impl ToMilp {
|
||||
let values = bv
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| self.get_bv_uint(c).clone())
|
||||
.map(|c| self.get_bv_uint(c))
|
||||
.collect::<Vec<_>>();
|
||||
let r = match o {
|
||||
BvNaryOp::Add => self.bv_add(&values, n),
|
||||
@@ -463,18 +465,18 @@ impl ToMilp {
|
||||
let r = self.fresh_bv("bv_ite", n_bits);
|
||||
let m = bv_modulus(n_bits);
|
||||
self.ilp
|
||||
.new_constraint(r.clone() - a.clone() - m * (1 - s.clone()) << 0);
|
||||
.new_constraint((r.clone() - a.clone() - m * (1 - s.clone())) << 0);
|
||||
self.ilp
|
||||
.new_constraint(a.clone() - r.clone() - m * (1 - s.clone()) << 0);
|
||||
.new_constraint((a.clone() - r.clone() - m * (1 - s.clone())) << 0);
|
||||
self.ilp
|
||||
.new_constraint(r.clone() - b.clone() - m * s.clone() << 0);
|
||||
.new_constraint((r.clone() - b.clone() - m * s.clone()) << 0);
|
||||
self.ilp
|
||||
.new_constraint(b.clone() - r.clone() - m * s.clone() << 0);
|
||||
.new_constraint((b.clone() - r.clone() - m * s.clone()) << 0);
|
||||
r
|
||||
}
|
||||
|
||||
/// [Equations 7](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=915055).
|
||||
fn bv_bin_mul<'a>(&mut self, a: &Expression, b: &Expression, n_bits: usize) -> Expression {
|
||||
fn bv_bin_mul(&mut self, a: &Expression, b: &Expression, n_bits: usize) -> Expression {
|
||||
debug!("({:?}) * ({:?})", a, b);
|
||||
let a_bits = self.bit_decomp(a, n_bits);
|
||||
let bit_prods: Vec<_> = a_bits
|
||||
@@ -512,9 +514,9 @@ impl ToMilp {
|
||||
let s = self.fresh_bit("bv_le");
|
||||
let m = bv_modulus(n_bits);
|
||||
self.ilp
|
||||
.new_constraint(a.clone() - b.clone() - m * (1 - s.clone()) << -1);
|
||||
.new_constraint((a.clone() - b.clone() - m * (1 - s.clone())) << -1);
|
||||
self.ilp
|
||||
.new_constraint(a.clone() - b.clone() + m * s.clone() >> 0);
|
||||
.new_constraint((a.clone() - b.clone() + m * s.clone()) >> 0);
|
||||
s
|
||||
}
|
||||
|
||||
@@ -531,12 +533,12 @@ impl ToMilp {
|
||||
let a = if signed {
|
||||
self.get_bv_signed_int(a)
|
||||
} else {
|
||||
self.get_bv_uint(a).clone()
|
||||
self.get_bv_uint(a)
|
||||
};
|
||||
let b = if signed {
|
||||
self.get_bv_signed_int(b)
|
||||
} else {
|
||||
self.get_bv_uint(b).clone()
|
||||
self.get_bv_uint(b)
|
||||
};
|
||||
if strict {
|
||||
self.bv_cmp_lt(&b, &a, w)
|
||||
@@ -571,7 +573,7 @@ impl ToMilp {
|
||||
.get(t)
|
||||
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
|
||||
{
|
||||
EmbeddedTerm::Bool(b) => &b,
|
||||
EmbeddedTerm::Bool(b) => b,
|
||||
_ => panic!("Non-bool for {:?}", t),
|
||||
}
|
||||
}
|
||||
@@ -614,7 +616,7 @@ impl ToMilp {
|
||||
}
|
||||
|
||||
fn bv_has_bits(&self, t: &Term) -> bool {
|
||||
self.get_bv(t).borrow().bits.len() > 0
|
||||
!self.get_bv(t).borrow().bits.is_empty()
|
||||
}
|
||||
|
||||
fn get_bv_uint(&self, t: &Term) -> Expression {
|
||||
@@ -622,14 +624,14 @@ impl ToMilp {
|
||||
}
|
||||
|
||||
fn get_bv_signed_int(&mut self, t: &Term) -> Expression {
|
||||
let bits = self.get_bv_bits(t).clone();
|
||||
let bits = self.get_bv_bits(t);
|
||||
self.debitify(bits.into_iter(), true)
|
||||
}
|
||||
|
||||
fn get_bv_bits(&mut self, t: &Term) -> Vec<Expression> {
|
||||
let entry_rc = self.get_bv(t);
|
||||
let mut entry = entry_rc.borrow_mut();
|
||||
if entry.bits.len() == 0 {
|
||||
if entry.bits.is_empty() {
|
||||
entry.bits = self.bit_decomp(&entry.uint, entry.width);
|
||||
}
|
||||
entry.bits.clone()
|
||||
@@ -644,7 +646,7 @@ impl ToMilp {
|
||||
}
|
||||
|
||||
fn bv_modulus(n_bits: usize) -> f64 {
|
||||
2.0f64.powi(n_bits.try_into().unwrap()).into()
|
||||
2.0f64.powi(n_bits.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Convert this (IR) constraint system `cs` to an MILP.
|
||||
@@ -663,7 +665,7 @@ pub fn to_ilp(cs: Computation) -> Ilp {
|
||||
converter.ilp.maximize(converter.get_bool(&opt).clone());
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
converter.ilp.maximize(converter.get_bv_uint(&opt).clone());
|
||||
converter.ilp.maximize(converter.get_bv_uint(&opt));
|
||||
}
|
||||
s => panic!("Cannot optimize term of sort {}", s),
|
||||
};
|
||||
|
||||
@@ -39,7 +39,7 @@ fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
|
||||
for (v, c) in &lc.monomials {
|
||||
// ditto
|
||||
if c != &0 {
|
||||
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
|
||||
lc_bellman = lc_bellman + (int_to_ff(c), *vars.get(v).unwrap());
|
||||
}
|
||||
}
|
||||
lc_bellman
|
||||
@@ -49,7 +49,7 @@ fn modulus_as_int<F: PrimeFieldBits>() -> Integer {
|
||||
let mut bits = F::char_le_bits().to_bitvec();
|
||||
let mut acc = Integer::from(0);
|
||||
while let Some(b) = bits.pop() {
|
||||
acc = acc << 1;
|
||||
acc <<= 1;
|
||||
acc += b as u8;
|
||||
}
|
||||
acc
|
||||
@@ -133,7 +133,7 @@ pub fn parse_instance<P: AsRef<Path>, F: PrimeField>(path: P) -> Vec<F> {
|
||||
f.lines()
|
||||
.map(|line| {
|
||||
let s = line.unwrap();
|
||||
let i = Integer::from_str(&s.trim()).unwrap();
|
||||
let i = Integer::from_str(s.trim()).unwrap();
|
||||
int_to_ff(&i)
|
||||
})
|
||||
.collect()
|
||||
|
||||
@@ -36,7 +36,7 @@ pub struct Lc {
|
||||
impl Lc {
|
||||
/// Is this the zero combination?
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.monomials.len() == 0 && &self.constant == &0
|
||||
self.monomials.is_empty() && self.constant == 0
|
||||
}
|
||||
/// Make this the zero combination.
|
||||
pub fn clear(&mut self) {
|
||||
@@ -55,7 +55,7 @@ impl Lc {
|
||||
}
|
||||
/// Is this a constant? If so, return that constant.
|
||||
pub fn as_const(&self) -> Option<&Integer> {
|
||||
(self.monomials.len() == 0).then(|| &self.constant)
|
||||
self.monomials.is_empty().then(|| &self.constant)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ impl std::ops::Neg for Lc {
|
||||
fn neg(mut self) -> Lc {
|
||||
self.constant = -self.constant;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (_, v) in &mut self.monomials {
|
||||
for v in &mut self.monomials.values_mut() {
|
||||
*v *= Integer::from(-1);
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
@@ -210,7 +210,7 @@ impl std::ops::MulAssign<&Integer> for Lc {
|
||||
if other == &Integer::from(0) {
|
||||
self.monomials.clear();
|
||||
} else {
|
||||
for (_, v) in &mut self.monomials {
|
||||
for v in &mut self.monomials.values_mut() {
|
||||
*v *= other;
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
@@ -233,7 +233,7 @@ impl std::ops::MulAssign<isize> for Lc {
|
||||
if other == 0 {
|
||||
self.monomials.clear();
|
||||
} else {
|
||||
for (_, v) in &mut self.monomials {
|
||||
for v in &mut self.monomials.values_mut() {
|
||||
*v *= Integer::from(other);
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
@@ -283,7 +283,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
let n = self.next_idx;
|
||||
self.next_idx += 1;
|
||||
self.signal_idxs.insert(s.clone(), n);
|
||||
self.idxs_signals.insert(n, s.clone());
|
||||
self.idxs_signals.insert(n, s);
|
||||
match (self.values.as_mut(), v) {
|
||||
(Some(vs), Some(v)) => {
|
||||
//println!("{} -> {}", &s, &v);
|
||||
@@ -331,7 +331,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
let sign = |i: &Integer| if i < &half_m { "+" } else { "-" };
|
||||
let format_i = |i: &Integer| format!("{}{}", sign(i), abs(i));
|
||||
|
||||
s.extend(format_i(&Integer::from(&a.constant)).chars());
|
||||
s.push_str(&format_i(&Integer::from(&a.constant)));
|
||||
for (idx, coeff) in &a.monomials {
|
||||
s.extend(
|
||||
format!(
|
||||
@@ -361,7 +361,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
let av = self.eval(a).unwrap();
|
||||
let bv = self.eval(b).unwrap();
|
||||
let cv = self.eval(c).unwrap();
|
||||
if &((av.clone() * &bv).rem_floor(&*self.modulus)) != &cv {
|
||||
if (av.clone() * &bv).rem_floor(&*self.modulus) != cv {
|
||||
panic!(
|
||||
"Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})",
|
||||
self.format_lc(a),
|
||||
|
||||
@@ -112,12 +112,11 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
|
||||
);
|
||||
self.clear_constraint(con_id);
|
||||
for use_id in self.uses[&var].clone() {
|
||||
if self.sub_in(var, &lc, use_id) {
|
||||
if self.r1cs.constraints[use_id].0.is_zero()
|
||||
|| self.r1cs.constraints[use_id].1.is_zero()
|
||||
{
|
||||
self.queue.push(use_id);
|
||||
}
|
||||
if self.sub_in(var, &lc, use_id)
|
||||
&& (self.r1cs.constraints[use_id].0.is_zero()
|
||||
|| self.r1cs.constraints[use_id].1.is_zero())
|
||||
{
|
||||
self.queue.push(use_id);
|
||||
}
|
||||
}
|
||||
debug_assert_eq!(0, self.uses[&var].len());
|
||||
|
||||
@@ -13,10 +13,9 @@ use rug::ops::Pow;
|
||||
use rug::Integer;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
use std::fmt::Display;
|
||||
use std::iter::ExactSizeIterator;
|
||||
use std::rc::Rc;
|
||||
|
||||
struct BvEntry {
|
||||
width: usize,
|
||||
@@ -97,7 +96,7 @@ impl ToR1cs {
|
||||
if x == 0 {
|
||||
Integer::from(0)
|
||||
} else {
|
||||
Integer::from(x.invert(&self.r1cs.modulus()).unwrap())
|
||||
x.invert(self.r1cs.modulus()).unwrap()
|
||||
}
|
||||
}),
|
||||
false,
|
||||
@@ -179,7 +178,7 @@ impl ToR1cs {
|
||||
/// Constrains `x` to fit in `n` (`signed`) bits.
|
||||
/// The LSB is at index 0.
|
||||
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Vec<Lc> {
|
||||
debug!("Bitify({}): {}", n, self.r1cs.format_lc(&x));
|
||||
debug!("Bitify({}): {}", n, self.r1cs.format_lc(x));
|
||||
let bits = self.decomp(d, x, n);
|
||||
let sum = self.debitify(bits.iter().cloned(), signed);
|
||||
self.assert_zero(sum - x);
|
||||
@@ -246,9 +245,11 @@ impl ToR1cs {
|
||||
fn nary_and<I: ExactSizeIterator<Item = Lc>>(&mut self, mut xs: I) -> Lc {
|
||||
let n = xs.len();
|
||||
if n <= 3 {
|
||||
let first = xs.next().expect("empty AND").clone();
|
||||
let first = xs.next().expect("empty AND");
|
||||
xs.fold(first, |a, x| self.mul(a, x))
|
||||
} else {
|
||||
// Needed to end the closures borrow of self, before the next line.
|
||||
#[allow(clippy::needless_collect)]
|
||||
let negs: Vec<Lc> = xs.map(|x| self.bool_not(&x)).collect();
|
||||
let a = self.nary_or(negs.into_iter());
|
||||
self.bool_not(&a)
|
||||
@@ -259,6 +260,8 @@ impl ToR1cs {
|
||||
fn nary_or<I: ExactSizeIterator<Item = Lc>>(&mut self, xs: I) -> Lc {
|
||||
let n = xs.len();
|
||||
if n <= 3 {
|
||||
// Needed to end the closures borrow of self, before the next line.
|
||||
#[allow(clippy::needless_collect)]
|
||||
let negs: Vec<Lc> = xs.map(|x| self.bool_not(&x)).collect();
|
||||
let a = self.nary_and(negs.into_iter());
|
||||
self.bool_not(&a)
|
||||
@@ -281,6 +284,8 @@ impl ToR1cs {
|
||||
debug!("Embed op: {}", c.op);
|
||||
// Handle field access once and for all
|
||||
if let Op::Field(i) = &c.op {
|
||||
// Need to borrow self in between search and insert. Could refactor.
|
||||
#[allow(clippy::map_entry)]
|
||||
if !self.cache.contains_key(&c) {
|
||||
let t = self.get_field(&c.cs[0], *i);
|
||||
self.cache.insert(c, t);
|
||||
@@ -297,7 +302,8 @@ impl ToR1cs {
|
||||
self.embed_pf(c);
|
||||
}
|
||||
Sort::Tuple(_) => {
|
||||
self.embed_tuple(c);
|
||||
// custom ops?
|
||||
panic!("Cannot embed tuple term: {}", c)
|
||||
}
|
||||
s => panic!("Unsupported sort in embed: {:?}", s),
|
||||
}
|
||||
@@ -305,18 +311,6 @@ impl ToR1cs {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unreachable_code)]
|
||||
#[allow(unused_variables)]
|
||||
fn embed_tuple(&mut self, a: Term) {
|
||||
if !self.cache.contains_key(&a) {
|
||||
let t = match &a.op {
|
||||
// May want to support cunstor operators here...
|
||||
_ => panic!("Cannot embed tuple term: {}", a),
|
||||
};
|
||||
self.cache.insert(a, t);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_field(&self, tuple_term: &Term, field: usize) -> EmbeddedTerm {
|
||||
match self.cache.get(tuple_term) {
|
||||
Some(EmbeddedTerm::Tuple(v)) => v[field].clone(),
|
||||
@@ -332,8 +326,8 @@ impl ToR1cs {
|
||||
self.bits_are_equal(&a, &b)
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
let a = self.get_bv_uint(a).clone();
|
||||
let b = self.get_bv_uint(b).clone();
|
||||
let a = self.get_bv_uint(a);
|
||||
let b = self.get_bv_uint(b);
|
||||
self.are_equal(a, &b)
|
||||
}
|
||||
Sort::Field(_) => {
|
||||
@@ -344,8 +338,7 @@ impl ToR1cs {
|
||||
Sort::Tuple(sorts) => {
|
||||
let n = sorts.len();
|
||||
let eqs: Vec<Term> = (0..n).map(|i| {
|
||||
let t = term![Op::Eq; term![Op::Field(i); a.clone()], term![Op::Field(i); b.clone()]];
|
||||
t
|
||||
term![Op::Eq; term![Op::Field(i); a.clone()], term![Op::Field(i); b.clone()]]
|
||||
}).collect();
|
||||
let conj = term(Op::BoolNaryOp(BoolNaryOp::And), eqs);
|
||||
self.embed(conj.clone());
|
||||
@@ -463,16 +456,16 @@ impl ToR1cs {
|
||||
fn assert_bool(&mut self, t: &Term) {
|
||||
//println!("Embed: {}", c);
|
||||
// TODO: skip if already embedded
|
||||
if &t.op == &Op::Eq {
|
||||
if t.op == Op::Eq {
|
||||
t.cs.iter().for_each(|c| self.embed(c.clone()));
|
||||
self.assert_eq(&t.cs[0], &t.cs[1]);
|
||||
} else if &t.op == &AND {
|
||||
} else if t.op == AND {
|
||||
for c in &t.cs {
|
||||
self.assert_bool(c);
|
||||
}
|
||||
} else {
|
||||
self.embed(t.clone());
|
||||
let lc = self.get_bool(&t).clone();
|
||||
let lc = self.get_bool(t).clone();
|
||||
self.assert_zero(lc - 1);
|
||||
}
|
||||
}
|
||||
@@ -489,12 +482,12 @@ impl ToR1cs {
|
||||
let a = if signed {
|
||||
self.get_bv_signed_int(a)
|
||||
} else {
|
||||
self.get_bv_uint(a).clone()
|
||||
self.get_bv_uint(a)
|
||||
};
|
||||
let b = if signed {
|
||||
self.get_bv_signed_int(b)
|
||||
} else {
|
||||
self.get_bv_uint(b).clone()
|
||||
self.get_bv_uint(b)
|
||||
};
|
||||
// Use the fact: a > b <=> a - 1 >= b
|
||||
self.bv_ge(if strict { a - 1 } else { a }, &b, w)
|
||||
@@ -551,18 +544,18 @@ impl ToR1cs {
|
||||
}
|
||||
Op::Ite => {
|
||||
let c = self.get_bool(&bv.cs[0]).clone();
|
||||
let t = self.get_bv_uint(&bv.cs[1]).clone();
|
||||
let f = self.get_bv_uint(&bv.cs[2]).clone();
|
||||
let t = self.get_bv_uint(&bv.cs[1]);
|
||||
let f = self.get_bv_uint(&bv.cs[2]);
|
||||
let ite = self.ite(c, t, &f);
|
||||
self.set_bv_uint(bv, ite, n);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Not) => {
|
||||
let bits = self.get_bv_bits(&bv.cs[0]).clone();
|
||||
let bits = self.get_bv_bits(&bv.cs[0]);
|
||||
let not_bits = bits.iter().map(|bit| self.bool_not(bit)).collect();
|
||||
self.set_bv_bits(bv, not_bits);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Neg) => {
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
let x = self.get_bv_uint(&bv.cs[0]);
|
||||
// Wrong for x == 0
|
||||
let almost_neg_x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - &x;
|
||||
let is_zero = self.is_zero(x);
|
||||
@@ -575,15 +568,14 @@ impl ToR1cs {
|
||||
let ext_bits = std::iter::repeat(self.r1cs.zero()).take(*extra_n);
|
||||
self.set_bv_bits(bv, bits.into_iter().chain(ext_bits).collect());
|
||||
} else {
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
let x = self.get_bv_uint(&bv.cs[0]);
|
||||
self.set_bv_uint(bv, x, n);
|
||||
}
|
||||
}
|
||||
Op::BvSext(extra_n) => {
|
||||
let mut bits = self.get_bv_bits(&bv.cs[0]).into_iter().rev();
|
||||
let ext_bits =
|
||||
std::iter::repeat(bits.next().expect("sign ext empty").clone())
|
||||
.take(extra_n + 1);
|
||||
let ext_bits = std::iter::repeat(bits.next().expect("sign ext empty"))
|
||||
.take(extra_n + 1);
|
||||
|
||||
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
|
||||
}
|
||||
@@ -604,7 +596,7 @@ impl ToR1cs {
|
||||
.map(|c| self.get_bv_bits(c))
|
||||
.collect::<Vec<_>>();
|
||||
let mut bits_bv_idx: Vec<Vec<Lc>> = Vec::new();
|
||||
while bits_by_bv[0].len() > 0 {
|
||||
while !bits_by_bv[0].is_empty() {
|
||||
bits_bv_idx.push(
|
||||
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
|
||||
);
|
||||
@@ -624,7 +616,7 @@ impl ToR1cs {
|
||||
let values = bv
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| self.get_bv_uint(c).clone())
|
||||
.map(|c| self.get_bv_uint(c))
|
||||
.collect::<Vec<_>>();
|
||||
let (res, width) = match o {
|
||||
BvNaryOp::Add => {
|
||||
@@ -663,14 +655,12 @@ impl ToR1cs {
|
||||
let b = self.get_bv_uint(&bv.cs[1]);
|
||||
match o {
|
||||
BvBinOp::Sub => {
|
||||
let sum = a.clone() + &(Integer::from(1) << n as u32) - &b;
|
||||
let sum = a + &(Integer::from(1) << n as u32) - &b;
|
||||
let mut bits = self.bitify("sub", &sum, n + 1, false);
|
||||
bits.truncate(n);
|
||||
self.set_bv_bits(bv, bits);
|
||||
}
|
||||
BvBinOp::Udiv | BvBinOp::Urem => {
|
||||
let b = b.clone();
|
||||
let a = a.clone();
|
||||
let is_zero = self.is_zero(b.clone());
|
||||
let (q_v, r_v) = self
|
||||
.r1cs
|
||||
@@ -705,8 +695,7 @@ impl ToR1cs {
|
||||
}
|
||||
// Shift cases
|
||||
_ => {
|
||||
let r = b.clone();
|
||||
let a = a.clone();
|
||||
let r = b;
|
||||
let b = bitsize(n - 1);
|
||||
assert!(1 << b == n);
|
||||
let mut rb = self.get_bv_bits(&bv.cs[1]);
|
||||
@@ -776,7 +765,7 @@ impl ToR1cs {
|
||||
.get(t)
|
||||
.unwrap_or_else(|| panic!("Missing wire for {:?}", t))
|
||||
{
|
||||
EmbeddedTerm::Bool(b) => &b,
|
||||
EmbeddedTerm::Bool(b) => b,
|
||||
_ => panic!("Non-boolean for {:?}", t),
|
||||
}
|
||||
}
|
||||
@@ -818,7 +807,7 @@ impl ToR1cs {
|
||||
}
|
||||
|
||||
fn bv_has_bits(&self, t: &Term) -> bool {
|
||||
self.get_bv(t).borrow().bits.len() > 0
|
||||
!self.get_bv(t).borrow().bits.is_empty()
|
||||
}
|
||||
|
||||
fn get_bv_uint(&self, t: &Term) -> Lc {
|
||||
@@ -826,14 +815,14 @@ impl ToR1cs {
|
||||
}
|
||||
|
||||
fn get_bv_signed_int(&mut self, t: &Term) -> Lc {
|
||||
let bits = self.get_bv_bits(t).clone();
|
||||
let bits = self.get_bv_bits(t);
|
||||
self.debitify(bits.into_iter(), true)
|
||||
}
|
||||
|
||||
fn get_bv_bits(&mut self, t: &Term) -> Vec<Lc> {
|
||||
let entry_rc = self.get_bv(t);
|
||||
let mut entry = entry_rc.borrow_mut();
|
||||
if entry.bits.len() == 0 {
|
||||
if entry.bits.is_empty() {
|
||||
entry.bits = self.bitify("getbits", &entry.uint, entry.width, false);
|
||||
}
|
||||
entry.bits.clone()
|
||||
@@ -871,14 +860,16 @@ impl ToR1cs {
|
||||
match o {
|
||||
PfNaryOp::Add => args.fold(self.r1cs.zero(), std::ops::Add::add),
|
||||
PfNaryOp::Mul => {
|
||||
// Needed to end the above closures borrow of self, before the mul call
|
||||
#[allow(clippy::needless_collect)]
|
||||
let args = args.cloned().collect::<Vec<_>>();
|
||||
let mut args_iter = args.into_iter();
|
||||
let first = args_iter.next().unwrap();
|
||||
args_iter.fold(first, |a, b| self.mul(a, b.clone()))
|
||||
args_iter.fold(first, |a, b| self.mul(a, b))
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::UbvToPf(_) => self.get_bv_uint(&c.cs[0]).clone(),
|
||||
Op::UbvToPf(_) => self.get_bv_uint(&c.cs[0]),
|
||||
Op::PfUnOp(PfUnOp::Neg) => -self.get_pf(&c.cs[0]).clone(),
|
||||
Op::PfUnOp(PfUnOp::Recip) => {
|
||||
let x = self.get_pf(&c.cs[0]).clone();
|
||||
|
||||
@@ -24,7 +24,7 @@ struct SmtDisp<'a, T>(pub &'a T);
|
||||
impl<'a, T: Expr2Smt<()> + 'a> Display for SmtDisp<'a, T> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
let mut s = Vec::new();
|
||||
<T as Expr2Smt<()>>::expr_to_smt2(&self.0, &mut s, ()).unwrap();
|
||||
<T as Expr2Smt<()>>::expr_to_smt2(self.0, &mut s, ()).unwrap();
|
||||
write!(f, "{}", std::str::from_utf8(&s).unwrap())?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -34,7 +34,7 @@ struct SmtSortDisp<'a, T>(pub &'a T);
|
||||
impl<'a, T: Sort2Smt + 'a> Display for SmtSortDisp<'a, T> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
let mut s = Vec::new();
|
||||
<T as Sort2Smt>::sort_to_smt2(&self.0, &mut s).unwrap();
|
||||
<T as Sort2Smt>::sort_to_smt2(self.0, &mut s).unwrap();
|
||||
write!(f, "{}", std::str::from_utf8(&s).unwrap())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -14,6 +14,12 @@ pub struct OnceQueue<T> {
|
||||
set: FxHashSet<T>,
|
||||
}
|
||||
|
||||
impl<T: Eq + Hash + Clone> Default for OnceQueue<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq + Hash + Clone> OnceQueue<T> {
|
||||
/// Add to the queue. If `t` is already present, it is dropped.
|
||||
pub fn push(&mut self, t: T) {
|
||||
|
||||
Reference in New Issue
Block a user