From ec9f1c464aae2c718df1e15ea12fb4ab77ee1402 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Sat, 6 Feb 2021 22:48:01 -0800 Subject: [PATCH] Reasonably complete BV->R1CS --- examples/gen.rs | 12 +- src/ir/term/bv.rs | 208 ++++++++++++++++++ src/ir/term/dist.rs | 207 ++++++++++++++++++ src/ir/{term.rs => term/mod.rs} | 308 +++++++++++---------------- src/ir/to_r1cs.rs | 363 +++++++++++++++++++++++++++----- src/lib.rs | 2 +- src/target/r1cs/mod.rs | 36 ++-- src/util/hc/mod.rs | 1 + 8 files changed, 869 insertions(+), 268 deletions(-) create mode 100644 src/ir/term/bv.rs create mode 100644 src/ir/term/dist.rs rename src/ir/{term.rs => term/mod.rs} (81%) diff --git a/examples/gen.rs b/examples/gen.rs index a978949d..afa3f4dd 100644 --- a/examples/gen.rs +++ b/examples/gen.rs @@ -1,11 +1,15 @@ use circ::ir; use rand::distributions::Distribution; - fn main() { let mut rng = rand::thread_rng(); - for i in 0..100 { - let t = ir::term::BoolDist(6).sample(&mut rng); - println!("Term: {:#?}", t) + for _ in 0..100 { + let d = ir::term::dist::FixedSizeDist { + bv_width: 4, + sort: ir::term::Sort::Bool, + size: 6, + }; + let t = d.sample(&mut rng); + println!("Term: {}", t) } } diff --git a/src/ir/term/bv.rs b/src/ir/term/bv.rs new file mode 100644 index 00000000..b777d559 --- /dev/null +++ b/src/ir/term/bv.rs @@ -0,0 +1,208 @@ +use rug::Integer; + +use std::fmt::{self, Display, Formatter}; + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct BitVector { + uint: Integer, + width: usize, +} + +macro_rules! bv_arith_impl { + ($Trait:path, $fn:ident) => { + impl $Trait for BitVector { + type Output = Self; + fn $fn(self, other: Self) -> Self { + assert_eq!(self.width, other.width); + let r = BitVector { + uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32), + width: self.width, + }; + r.check(std::stringify!($fn)); + r + } + } + }; +} + +bv_arith_impl!(std::ops::Add, add); +bv_arith_impl!(std::ops::Sub, sub); +bv_arith_impl!(std::ops::Mul, mul); +//bv_arith_impl!(std::ops::Div, div); +//bv_arith_impl!(std::ops::Rem, rem); +bv_arith_impl!(std::ops::BitAnd, bitand); +bv_arith_impl!(std::ops::BitOr, bitor); +bv_arith_impl!(std::ops::BitXor, bitxor); + +/// SMT-semantics implementation of unsigned division (udiv). +/// +/// If the divisor is zero, returns the all-ones vector. +impl std::ops::Div<&BitVector> for BitVector { + type Output = Self; + fn div(self, other: &Self) -> Self { + assert_eq!(self.width, other.width); + if other.uint == 0 { + let r = BitVector { + uint: (Integer::from(1) << self.width as u32) - 1, + width: self.width, + }; + r.check("div"); + r + } else { + let r = BitVector { + uint: self.uint / &other.uint, + width: self.width, + }; + r.check("div"); + r + } + } +} + +/// SMT-semantics implementation of unsigned remainder (urem). +/// +/// If the divisor is zero, returns the all-ones vector. +impl std::ops::Rem<&BitVector> for BitVector { + type Output = Self; + fn rem(self, other: &Self) -> Self { + assert_eq!(self.width, other.width); + if other.uint == 0 { + self + } else { + let r = BitVector { + uint: self.uint % &other.uint, + width: self.width, + }; + r.check("rem"); + r + } + } +} + +impl std::ops::Shl for BitVector { + type Output = Self; + fn shl(self, other: Self) -> Self { + assert_eq!(self.width, other.width); + let r = BitVector { + uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32), + width: self.width, + }; + r.check("shl"); + r + } +} + +impl std::ops::Neg for BitVector { + type Output = Self; + fn neg(self) -> Self { + let r = BitVector { + uint: ((Integer::from(1) << self.width as u32) - self.uint) + .keep_bits(self.width as u32), + width: self.width, + }; + r.check("neg"); + r + } +} + +impl std::ops::Not for BitVector { + type Output = Self; + fn not(self) -> Self { + let r = BitVector { + uint: (Integer::from(1) << self.width as u32) - 1 - self.uint, + width: self.width, + }; + r.check("not"); + r + } +} + +impl BitVector { + #[track_caller] + #[inline] + pub fn check(&self, location: &str) { + debug_assert!( + self.uint >= 0, + "Too small bitvector: {:?}, {}\n at {}", + self, + self.uint.significant_bits(), + location + ); + debug_assert!( + (self.uint.significant_bits() as usize) <= self.width, + "Too big bitvector: {:?}, {}\n at {}", + self, + self.uint.significant_bits(), + location + ); + } + pub fn ashr(mut self, other: Self) -> Self { + assert_eq!(self.width, other.width); + let n = other.uint.to_u32().unwrap(); + let b = self.uint.get_bit(self.width as u32 - 1); + self.uint >>= n; + for i in 0..n { + self.uint.set_bit(self.width as u32 - 1 - i, b); + } + self.check("ashr"); + self + } + pub fn lshr(self, other: Self) -> Self { + assert_eq!(self.width, other.width); + let r = BitVector { + uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32), + width: self.width, + }; + r.check("lshr"); + r + } + pub fn concat(self, other: Self) -> Self { + let r = BitVector { + uint: (self.uint << other.width as u32) | other.uint, + width: self.width + other.width, + }; + r.check("concat"); + r + } + pub fn extract(self, high: usize, low: usize) -> Self { + let r = BitVector { + uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32), + width: high - low + 1, + }; + r.check("extract"); + r + } + pub fn as_sint(&self) -> Integer { + if self.uint.significant_bits() as usize == self.width { + self.uint.clone() - (Integer::from(1) << self.width as u32) + } else { + self.uint.clone() + } + } + pub fn uint(&self) -> &Integer { + &self.uint + } + pub fn width(&self) -> usize { + self.width + } + #[track_caller] + pub fn new(uint: Integer, width: usize) -> BitVector { + let r = BitVector { uint, width }; + r.check("new"); + r + } +} + +impl Display for BitVector { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "#b")?; + for i in 0..self.width { + write!( + f, + "{}", + self.uint.get_bit((self.width - i - 1) as u32) as u8 + )?; + } + Ok(()) + } +} diff --git a/src/ir/term/dist.rs b/src/ir/term/dist.rs new file mode 100644 index 00000000..03e6d3fd --- /dev/null +++ b/src/ir/term/dist.rs @@ -0,0 +1,207 @@ +use super::*; + +// A distribution of boolean terms with some size. +// All subterms are booleans. +pub struct PureBoolDist(pub usize); + +// A distribution of n usizes that sum to this value. +// (n, sum) +pub struct Sum(usize, usize); +impl rand::distributions::Distribution> for Sum { + fn sample(&self, rng: &mut R) -> Vec { + use rand::seq::SliceRandom; + let mut acc = self.1; + let mut ns = Vec::new(); + assert!(acc == 0 || self.0 > 0); + while acc > 0 && ns.len() < self.0 { + let x = rng.gen_range(0..acc); + acc -= x; + ns.push(x); + } + while ns.len() < self.0 { + ns.push(0); + } + if acc > 0 { + *ns.last_mut().unwrap() += acc; + } + ns.shuffle(rng); + ns + } +} + +impl rand::distributions::Distribution for PureBoolDist { + fn sample(&self, rng: &mut R) -> Term { + use rand::seq::SliceRandom; + let ops = &[ + Op::Const(Value::Bool(rng.gen())), + Op::Var( + std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]) + .unwrap() + .to_owned(), + Sort::Bool, + ), + Op::Not, + Op::Implies, + Op::BoolNaryOp(BoolNaryOp::Or), + Op::BoolNaryOp(BoolNaryOp::And), + Op::BoolNaryOp(BoolNaryOp::Xor), + ]; + let o = match self.0 { + 1 => ops[..2].choose(rng), // arity 0 + 2 => ops[2..3].choose(rng), // arity 1 + _ => ops[2..].choose(rng), // others + } + .unwrap() + .clone(); + // Now, self.0 is a least arity+1 + let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0)); + let excess = self.0 - 1 - a; + let ns = Sum(a, excess).sample(rng); + let subterms = ns + .into_iter() + .map(|n| PureBoolDist(n + 1).sample(rng)) + .collect::>(); + term(o, subterms) + } +} + +pub struct FixedSizeDist { + pub size: usize, + pub bv_width: usize, + pub sort: Sort, +} + +impl FixedSizeDist { + fn with_size(&self, size: usize) -> Self { + FixedSizeDist { + size, + sort: self.sort.clone(), + bv_width: self.bv_width, + } + } + fn with_sort(&self, sort: Sort) -> Self { + FixedSizeDist { + size: self.size, + sort, + bv_width: self.bv_width, + } + } +} + +pub struct UniformBitVector(pub usize); + +impl rand::distributions::Distribution for UniformBitVector { + fn sample(&self, rng: &mut R) -> BitVector { + let mut rug_rng = rug::rand::RandState::new_mersenne_twister(); + rug_rng.seed(&Integer::from(rng.next_u32())); + BitVector::new( + Integer::from(Integer::random_bits(self.0 as u32, &mut rug_rng)), + self.0, + ) + } +} + +impl rand::distributions::Distribution for FixedSizeDist { + fn sample(&self, rng: &mut R) -> Term { + use rand::seq::SliceRandom; + match self.sort.clone() { + Sort::Bool => { + let ops = &[ + Op::Const(Value::Bool(rng.gen())), + Op::Var( + std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]) + .unwrap() + .to_owned(), + Sort::Bool, + ), + Op::Not, // 2 + Op::Implies, + Op::Eq, + Op::BvBinPred(BvBinPred::Sge), + Op::BvBinPred(BvBinPred::Sgt), + Op::BvBinPred(BvBinPred::Sle), + Op::BvBinPred(BvBinPred::Slt), + Op::BvBinPred(BvBinPred::Uge), + Op::BvBinPred(BvBinPred::Ugt), + Op::BvBinPred(BvBinPred::Ule), + Op::BvBinPred(BvBinPred::Ult), + Op::BoolNaryOp(BoolNaryOp::Or), + Op::BoolNaryOp(BoolNaryOp::And), + Op::BoolNaryOp(BoolNaryOp::Xor), + Op::Ite, + ]; + let o = match self.size { + 1 => ops[..2].choose(rng), // arity 0 + 2 => ops[2..3].choose(rng), // arity 1 + 3 => ops[2..16].choose(rng), // arity 2 + _ => ops[2..].choose(rng), // others + } + .unwrap() + .clone(); + // Now, self.0 is a least arity+1 + let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size)); + let excess = self.size - 1 - a; + let ns = Sum(a, excess).sample(rng); + let sort = match o { + Op::Eq => [Sort::Bool, Sort::BitVector(self.bv_width)] + .choose(rng) + .unwrap() + .clone(), + Op::BvBinPred(_) => Sort::BitVector(self.bv_width), + _ => Sort::Bool, + }; + let subterms = ns + .into_iter() + .map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng)) + .collect::>(); + term(o, subterms) + } + Sort::BitVector(w) => { + let ops = &[ + Op::Const(Value::BitVector(UniformBitVector(w).sample(rng))), + Op::Var( + format!( + "{}_bv{}", + std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]).unwrap(), + w + ), + Sort::BitVector(w), + ), + Op::BvUnOp(BvUnOp::Neg), + Op::BvUnOp(BvUnOp::Not), + Op::BvUext(rng.gen_range(0..w)), + Op::BvSext(rng.gen_range(0..w)), + Op::BvBinOp(BvBinOp::Sub), + Op::BvNaryOp(BvNaryOp::Or), + Op::BvNaryOp(BvNaryOp::And), + Op::BvNaryOp(BvNaryOp::Xor), + Op::BvNaryOp(BvNaryOp::Add), + Op::BvNaryOp(BvNaryOp::Mul), + // Add ITEs + ]; + let o = match self.size { + 1 => ops[..2].choose(rng), // arity 0 + 2 => ops[2..4].choose(rng), // arity 1 + _ => ops[2..].choose(rng), // others + } + .unwrap() + .clone(); + let sort = match o { + Op::BvUext(ww) => Sort::BitVector(w - ww), + Op::BvSext(ww) => Sort::BitVector(w - ww), + _ => Sort::BitVector(w), + }; + // Now, self.0 is a least arity+1 + let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size)); + let excess = self.size - 1 - a; + let ns = Sum(a, excess).sample(rng); + let subterms = ns + .into_iter() + .map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng)) + .collect::>(); + term(o, subterms) + } + s => panic!("Unsampleabl sort: {}", s), + } + } +} diff --git a/src/ir/term.rs b/src/ir/term/mod.rs similarity index 81% rename from src/ir/term.rs rename to src/ir/term/mod.rs index fd5930a9..2d59f49a 100644 --- a/src/ir/term.rs +++ b/src/ir/term/mod.rs @@ -6,6 +6,11 @@ use std::collections::HashSet; use std::fmt::{self, Display, Formatter}; use std::sync::{Arc, RwLock}; +pub mod bv; +pub mod dist; + +pub use bv::BitVector; + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Op { Ite, @@ -352,16 +357,16 @@ impl Display for PfUnOp { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct TermData { pub op: Op, - pub children: Vec, + pub cs: Vec, } impl Display for TermData { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - if let Op::Var(..) = &self.op { + if self.op.arity() == Some(0) { write!(f, "{}", self.op) } else { write!(f, "({}", self.op)?; - for c in &self.children { + for c in &self.cs { write!(f, " {}", c)?; } write!(f, ")") @@ -385,90 +390,6 @@ impl Display for FieldElem { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct BitVector { - pub uint: Integer, - pub width: usize, -} - -macro_rules! bv_arith_impl { - ($Trait:path, $fn:ident) => { - impl $Trait for BitVector { - type Output = Self; - fn $fn(self, other: Self) -> Self { - assert_eq!(self.width, other.width); - BitVector { - uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32), - width: self.width, - } - } - } - }; -} - -bv_arith_impl!(std::ops::Add, add); -bv_arith_impl!(std::ops::Sub, sub); -bv_arith_impl!(std::ops::Mul, mul); -bv_arith_impl!(std::ops::Div, div); -bv_arith_impl!(std::ops::Rem, rem); - -impl std::ops::Shl for BitVector { - type Output = Self; - fn shl(self, other: Self) -> Self { - assert_eq!(self.width, other.width); - BitVector { - uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32), - width: self.width, - } - } -} - -impl BitVector { - pub fn ashr(mut self, other: Self) -> Self { - assert_eq!(self.width, other.width); - let n = other.uint.to_u32().unwrap(); - let b = self.uint.get_bit(self.width as u32 - 1); - self.uint >>= n; - for i in 0..n { - self.uint.set_bit(self.width as u32 - 1 - i, b); - } - self - } - pub fn lshr(self, other: Self) -> Self { - assert_eq!(self.width, other.width); - BitVector { - uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32), - width: self.width, - } - } - pub fn concat(self, other: Self) -> Self { - BitVector { - uint: (self.uint << other.width as u32) | other.uint, - width: self.width + other.width, - } - } - pub fn extract(self, high: usize, low: usize) -> Self { - BitVector { - uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32), - width: high - low + 1, - } - } -} - -impl Display for BitVector { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "#b")?; - for i in 0..self.width { - write!( - f, - "#{}", - self.uint.get_bit((self.width - i - 1) as u32) as u8 - )?; - } - Ok(()) - } -} - #[derive(Clone, PartialEq, Debug)] pub enum Value { BitVector(BitVector), @@ -517,6 +438,30 @@ pub enum Sort { Array(Box, Box), } +impl Sort { + pub fn as_bv(&self) -> usize { + if let Sort::BitVector(w) = self { + *w + } else { + panic!("{} is not a bit-vector", self) + } + } +} + +impl Display for Sort { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + Sort::Bool => write!(f, "bool"), + Sort::BitVector(n) => write!(f, "(bv {})", n), + Sort::Int => write!(f, "int"), + Sort::F32 => write!(f, "f32"), + Sort::F64 => write!(f, "f64"), + Sort::Field(i) => write!(f, "(mod {})", i), + Sort::Array(k, v) => write!(f, "(array {} {})", k, v), + } + } +} + pub type Term = HConsed; // "Temporary" terms. pub type TTerm = WHConsed; @@ -537,9 +482,10 @@ impl Value { Value::Int(_) => Sort::Int, Value::F64(_) => Sort::F64, Value::F32(_) => Sort::F32, - Value::BitVector(b) => Sort::BitVector(b.width), + Value::BitVector(b) => Sort::BitVector(b.width()), } } + #[track_caller] pub fn as_bool(&self) -> bool { if let Value::Bool(b) = self { *b @@ -547,6 +493,7 @@ impl Value { panic!("Not a bool: {}", self) } } + #[track_caller] pub fn as_bv(&self) -> &BitVector { if let Value::BitVector(b) = self { b @@ -621,7 +568,7 @@ pub fn check(t: Term) -> Result { } { let mut term_tys = TERM_TYPES.write().unwrap(); - // to_check is a stack of (node, children checked) pairs. + // to_check is a stack of (node, cs checked) pairs. let mut to_check = vec![(t.clone(), false)]; while to_check.len() > 0 { let back = to_check.last_mut().unwrap(); @@ -640,13 +587,13 @@ pub fn check(t: Term) -> Result { } if !back.1 { back.1 = true; - for c in back.0.children.clone() { + for c in back.0.cs.clone() { to_check.push((c, false)); } } else { let tys = back .0 - .children + .cs .iter() .map(|c| term_tys.get(&c.to_weak()).unwrap()) .collect::>(); @@ -781,51 +728,101 @@ pub fn eval(t: &Term, h: &HashMap) -> Value { for c in PostOrderIter::new(t.clone()) { let v = match &c.op { Op::Var(n, _) => h.get(n).unwrap().clone(), - Op::Eq => { - Value::Bool(vs.get(&c.children[0]).unwrap() == vs.get(&c.children[1]).unwrap()) - } - Op::Not => Value::Bool(!vs.get(&c.children[0]).unwrap().as_bool()), + Op::Eq => Value::Bool(vs.get(&c.cs[0]).unwrap() == vs.get(&c.cs[1]).unwrap()), + Op::Not => Value::Bool(!vs.get(&c.cs[0]).unwrap().as_bool()), Op::Implies => Value::Bool( - !vs.get(&c.children[0]).unwrap().as_bool() - || vs.get(&c.children[1]).unwrap().as_bool(), + !vs.get(&c.cs[0]).unwrap().as_bool() || vs.get(&c.cs[1]).unwrap().as_bool(), ), Op::BoolNaryOp(BoolNaryOp::Or) => { - Value::Bool(c.children.iter().any(|c| vs.get(c).unwrap().as_bool())) + Value::Bool(c.cs.iter().any(|c| vs.get(c).unwrap().as_bool())) } Op::BoolNaryOp(BoolNaryOp::And) => { - Value::Bool(c.children.iter().all(|c| vs.get(c).unwrap().as_bool())) + Value::Bool(c.cs.iter().all(|c| vs.get(c).unwrap().as_bool())) } Op::BoolNaryOp(BoolNaryOp::Xor) => Value::Bool( - c.children - .iter() + c.cs.iter() .map(|c| vs.get(c).unwrap().as_bool()) .fold(false, std::ops::BitXor::bitxor), ), - Op::BvBit(i) => Value::Bool( - vs.get(&c.children[0]) - .unwrap() - .as_bv() - .uint - .get_bit(*i as u32), - ), + Op::BvBit(i) => { + Value::Bool(vs.get(&c.cs[0]).unwrap().as_bv().uint().get_bit(*i as u32)) + } Op::BvConcat => Value::BitVector({ - let mut it = c - .children - .iter() - .map(|c| vs.get(c).unwrap().as_bv().clone()); + let mut it = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone()); let f = it.next().unwrap(); it.fold(f, BitVector::concat) }), - Op::BvExtract(h, l) => Value::BitVector( - vs.get(&c.children[0]) - .unwrap() - .as_bv() - .clone() - .extract(*h, *l), - ), + Op::BvExtract(h, l) => { + Value::BitVector(vs.get(&c.cs[0]).unwrap().as_bv().clone().extract(*h, *l)) + } Op::Const(v) => v.clone(), + Op::BvBinOp(o) => Value::BitVector({ + let a = vs.get(&c.cs[0]).unwrap().as_bv().clone(); + let b = vs.get(&c.cs[1]).unwrap().as_bv().clone(); + match o { + BvBinOp::Udiv => a / &b, + BvBinOp::Urem => a % &b, + BvBinOp::Sub => a - b, + BvBinOp::Ashr => a.ashr(b), + BvBinOp::Lshr => a.lshr(b), + BvBinOp::Shl => a << b, + } + }), + Op::BvUnOp(o) => Value::BitVector({ + let a = vs.get(&c.cs[0]).unwrap().as_bv().clone(); + match o { + BvUnOp::Not => !a, + BvUnOp::Neg => -a, + } + }), + Op::BvNaryOp(o) => Value::BitVector({ + let mut xs = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone()); + let f = xs.next().unwrap(); + xs.fold( + f, + match o { + BvNaryOp::Add => std::ops::Add::add, + BvNaryOp::Mul => std::ops::Mul::mul, + BvNaryOp::Xor => std::ops::BitXor::bitxor, + BvNaryOp::Or => std::ops::BitOr::bitor, + BvNaryOp::And => std::ops::BitAnd::bitand, + }, + ) + }), + Op::BvSext(w) => Value::BitVector({ + let a = vs.get(&c.cs[0]).unwrap().as_bv().clone(); + let mask = ((Integer::from(1) << *w as u32) - 1) + * Integer::from(a.uint().get_bit(a.width() as u32 - 1)); + BitVector::new(a.uint() | (mask << a.width() as u32), a.width() + w) + }), + Op::BvUext(w) => Value::BitVector({ + let a = vs.get(&c.cs[0]).unwrap().as_bv().clone(); + BitVector::new(a.uint().clone(), a.width() + w) + }), + Op::Ite => if vs.get(&c.cs[0]).unwrap().as_bool() { + vs.get(&c.cs[1]) + } else { + vs.get(&c.cs[2]) + } + .unwrap() + .clone(), + Op::BvBinPred(o) => Value::Bool({ + let a = vs.get(&c.cs[0]).unwrap().as_bv(); + let b = vs.get(&c.cs[1]).unwrap().as_bv(); + match o { + BvBinPred::Sge => a.as_sint() >= b.as_sint(), + BvBinPred::Sgt => a.as_sint() > b.as_sint(), + BvBinPred::Sle => a.as_sint() <= b.as_sint(), + BvBinPred::Slt => a.as_sint() < b.as_sint(), + BvBinPred::Uge => a.uint() >= b.uint(), + BvBinPred::Ugt => a.uint() > b.uint(), + BvBinPred::Ule => a.uint() <= b.uint(), + BvBinPred::Ult => a.uint() < b.uint(), + } + }), o => unimplemented!("eval: {:?}", o), }; + //println!("Eval {}\nAs {}", c, v); vs.insert(c.clone(), v); } vs.get(t).unwrap().clone() @@ -854,9 +851,9 @@ pub fn leaf_term(op: Op) -> Term { term(op, Vec::new()) } -pub fn term(op: Op, children: Vec) -> Term { +pub fn term(op: Op, cs: Vec) -> Term { use hashconsing::HashConsign; - let t = TERM_FACTORY.mk(TermData { op, children }); + let t = TERM_FACTORY.mk(TermData { op, cs }); check(t.clone()).unwrap(); t } @@ -868,76 +865,11 @@ macro_rules! term { }; } -// A distribution of boolean terms with some size. -// All subterms are booleans. -pub struct BoolDist(pub usize); - -// A distribution of n usizes that sum to this value. -// (n, sum) -pub struct Sum(usize, usize); -impl rand::distributions::Distribution> for Sum { - fn sample(&self, rng: &mut R) -> Vec { - use rand::seq::SliceRandom; - let mut acc = self.1; - let mut ns = Vec::new(); - assert!(acc == 0 || self.0 > 0); - while acc > 0 && ns.len() < self.0 { - let x = rng.gen_range(0..acc); - acc -= x; - ns.push(x); - } - while ns.len() < self.0 { - ns.push(0); - } - if acc > 0 { - *ns.last_mut().unwrap() += acc; - } - ns.shuffle(rng); - ns - } -} - -impl rand::distributions::Distribution for BoolDist { - fn sample(&self, rng: &mut R) -> Term { - use rand::seq::SliceRandom; - let ops = &[ - Op::Const(Value::Bool(rng.gen())), - Op::Var( - std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]) - .unwrap() - .to_owned(), - Sort::Bool, - ), - Op::Not, - Op::Implies, - Op::BoolNaryOp(BoolNaryOp::Or), - Op::BoolNaryOp(BoolNaryOp::And), - Op::BoolNaryOp(BoolNaryOp::Xor), - ]; - let o = match self.0 { - 1 => ops[..2].choose(rng), // arity 0 - 2 => ops[2..3].choose(rng), // arity 1 - _ => ops[2..].choose(rng), // others - } - .unwrap() - .clone(); - // Now, self.0 is a least arity+1 - let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0)); - let excess = self.0 - 1 - a; - let ns = Sum(a, excess).sample(rng); - let subterms = ns - .into_iter() - .map(|n| BoolDist(n + 1).sample(rng)) - .collect::>(); - term(o, subterms) - } -} - pub type TermMap = hashconsing::coll::HConMap; pub type TermSet = hashconsing::coll::HConSet; pub struct PostOrderIter { - // (children stacked, term) + // (cs stacked, term) stack: Vec<(bool, Term)>, visited: TermSet, } @@ -959,7 +891,7 @@ impl std::iter::Iterator for PostOrderIter { self.stack.pop(); } else if !children_pushed { self.stack.last_mut().unwrap().0 = true; - let cs = self.stack.last().unwrap().1.children.clone(); + let cs = self.stack.last().unwrap().1.cs.clone(); self.stack.extend(cs.into_iter().map(|c| (false, c))); } else { break; diff --git a/src/ir/to_r1cs.rs b/src/ir/to_r1cs.rs index a653d501..bb5a5bcd 100644 --- a/src/ir/to_r1cs.rs +++ b/src/ir/to_r1cs.rs @@ -118,32 +118,45 @@ impl ToR1cs { self.values .as_ref() .map(|vs| match vs.get(var).expect("missing value") { - Value::BitVector(BitVector { uint, .. }) => uint.clone(), + Value::BitVector(b) => b.uint().clone(), v => panic!("{} should be a bool, but is {:?}", var, v), }) } /// Given wire `x`, returns a vector of `n` wires which are the bits of `x`. - /// Constrains `x` to fit in `n` (unsigned) bits. + /// They *have not* been constrained to sum to `x`. + /// They have values according the the (infinite) two's complement representation of `x`. /// The LSB is at index 0. - fn bitify(&mut self, d: &D, x: &Lc, n: usize) -> Vec { + fn decomp(&mut self, d: &D, x: &Lc, n: usize) -> Vec { let x_val = self.r1cs.eval(x); - let bits = (0..n) + (0..n) .map(|i| { self.fresh_bit( + // We get the right repr here because of infinite two's complement. &format!("{}_b{}", d, i), x_val.as_ref().map(|x| Integer::from(x.get_bit(i as u32))), ) }) - .collect::>(); - let sum = bits.iter().enumerate().fold(self.r1cs.zero(), |s, (i, b)| { - s + &(b.clone() * &Integer::from(2).pow(i as u32)) - }); - self.r1cs - .constraint(self.r1cs.zero(), self.r1cs.zero(), sum - x); + .collect::>() + } + + /// Given wire `x`, returns a vector of `n` wires which are the bits of `x`. + /// Constrains `x` to fit in `n` (`signed`) bits. + /// The LSB is at index 0. + fn bitify(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Vec { + let bits = self.decomp(d, x, n); + let sum = self.debitify(bits.iter().cloned(), signed); + self.assert_zero(sum - x); bits } + /// Given wire `x`, returns whether `x` fits in `n` `signed` bits. + fn fits_in_bits(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Lc { + let bits = self.decomp(d, x, n); + let sum = self.debitify(bits.iter().cloned(), signed); + self.are_equal(sum, x) + } + /// Given a sequence of `bits`, returns a wire which represents their sum, /// `\sum_{i>0} b_i2^i`. /// @@ -164,7 +177,7 @@ impl ToR1cs { fn nary_xor>(&mut self, xs: I) -> Lc { let n = xs.len(); let sum = xs.into_iter().fold(self.r1cs.zero(), |s, i| s + &i); - let sum_bits = self.bitify("sum", &sum, bitsize(n)); + let sum_bits = self.bitify("sum", &sum, bitsize(n), false); assert!(n > 0); assert!(self.r1cs.modulus() > &n); sum_bits.into_iter().next().unwrap() // safe b/c assert @@ -262,46 +275,116 @@ impl ToR1cs { v } Op::Const(Value::Bool(b)) => self.r1cs.zero() + *b as isize, - Op::Eq => self.embed_eq(&c.children[0], &c.children[1]), + Op::Eq => self.embed_eq(&c.cs[0], &c.cs[1]), Op::Ite => { - let a = self.get_bool(&c.children[0]).clone(); - let b = self.get_bool(&c.children[1]).clone(); - let c = self.get_bool(&c.children[2]).clone(); + let a = self.get_bool(&c.cs[0]).clone(); + let b = self.get_bool(&c.cs[1]).clone(); + let c = self.get_bool(&c.cs[2]).clone(); self.ite(a, b, &c) } Op::Not => { - let a = self.get_bool(&c.children[0]); + let a = self.get_bool(&c.cs[0]); self.bool_not(a) } Op::Implies => { - let a = self.get_bool(&c.children[0]).clone(); - let b = self.get_bool(&c.children[1]).clone(); + let a = self.get_bool(&c.cs[0]).clone(); + let b = self.get_bool(&c.cs[1]).clone(); let not_a = self.bool_not(&a); self.nary_or(vec![not_a, b].into_iter()) } Op::BoolNaryOp(o) => { - let args = c - .children - .iter() - .map(|c| self.get_bool(c).clone()) - .collect::>(); + let args = + c.cs.iter() + .map(|c| self.get_bool(c).clone()) + .collect::>(); match o { BoolNaryOp::Or => self.nary_or(args.into_iter()), BoolNaryOp::And => self.nary_and(args.into_iter()), BoolNaryOp::Xor => self.nary_xor(args.into_iter()), } } + Op::BvBit(i) => { + let a = self.get_bv_bits(&c.cs[0]); + a[*i].clone() + } + Op::BvBinPred(o) => { + let n = check(c.cs[0].clone()).unwrap().as_bv(); + use BvBinPred::*; + match o { + Sge => self.bv_cmp(n, true, false, &c.cs[0], &c.cs[1]), + Sgt => self.bv_cmp(n, true, true, &c.cs[0], &c.cs[1]), + Uge => self.bv_cmp(n, false, false, &c.cs[0], &c.cs[1]), + Ugt => self.bv_cmp(n, false, true, &c.cs[0], &c.cs[1]), + Sle => self.bv_cmp(n, true, false, &c.cs[1], &c.cs[0]), + Slt => self.bv_cmp(n, true, true, &c.cs[1], &c.cs[0]), + Ule => self.bv_cmp(n, false, false, &c.cs[1], &c.cs[0]), + Ult => self.bv_cmp(n, false, true, &c.cs[1], &c.cs[0]), + } + } + _ => panic!("Non-boolean in embed_bool: {:?}", c), }; self.bools.insert(c.clone(), lc); } -// self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| { -// println!("-> {}", v); -// }); + // self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| { + // println!("-> {}", v); + // }); self.bools.get(&c).unwrap() } + /// Returns whether `a - b` fits in `size` non-negative bits. + /// i.e. is in `{0, 1, ..., 2^n-1}`. + fn bv_ge(&mut self, a: Lc, b: &Lc, size: usize) -> Lc { + self.fits_in_bits("ge", &(a - b), size, false) + } + + /// Returns whether `a` is (`strict`ly) (`signed`ly) greater than `b`. + /// Assumes they are each `w`-bit bit-vectors. + fn bv_cmp(&mut self, w: usize, signed: bool, strict: bool, a: &Term, b: &Term) -> Lc { + let a = if signed { + self.get_bv_signed_int(a) + } else { + self.get_bv_uint(a).clone() + }; + let b = if signed { + self.get_bv_signed_int(b) + } else { + self.get_bv_uint(b).clone() + }; + // Use the fact: a > b <=> a - 1 >= b + self.bv_ge(if strict { a - 1 } else { a }, &b, w) + } + + /// Shift `x` left by `2^y`, if bit-valued `c` is true. + fn const_pow_shift_bv(&mut self, x: &Lc, y: usize, c: Lc) -> Lc { + self.ite(c, x.clone() * (1 << y), x) + } + + /// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`. + /// Returns an *oversized* number + fn shift_bv(&mut self, x: Lc, y: Vec, ext_bit: Option) -> Lc { + if let Some(b) = ext_bit { + let left = self.shift_bv(x, y.clone(), None); + let right = self.shift_bv(b.clone(), y, None) - 1; + left + &self.mul(b, right) + } else { + y.into_iter() + .enumerate() + .fold(x, |x, (i, yi)| self.const_pow_shift_bv(&x, i, yi)) + } + } + + /// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`. + /// Returns a bit sequence. + fn shift_bv_bits(&mut self, x: Lc, y: Vec, ext_bit: Option, n: usize) -> Vec { + let s = self.shift_bv(x, y, ext_bit); + let mut bits = self.bitify("shift", &s, 2 * n - 1, false); + bits.truncate(n); + bits + } + fn embed_bv(&mut self, bv: Term) { + //println!("Embed: {}", bv); if let Sort::BitVector(n) = check(bv.clone()).unwrap() { if !self.bvs.contains_key(&bv) { match &bv.op { @@ -313,47 +396,53 @@ impl ToR1cs { self.get_bv_bits(&bv); } } - Op::Const(Value::BitVector(BitVector { uint, width })) => { - let bit_lcs = (0..*width) - .map(|i| self.r1cs.zero() + uint.get_bit(i as u32) as isize) + Op::Const(Value::BitVector(b)) => { + let bit_lcs = (0..b.width()) + .map(|i| self.r1cs.zero() + b.uint().get_bit(i as u32) as isize) .collect(); self.set_bv_bits(bv, bit_lcs); } Op::Ite => { - let c = self.get_bool(&bv.children[0]).clone(); - let t = self.get_bv_uint(&bv.children[1]).clone(); - let f = self.get_bv_uint(&bv.children[2]).clone(); + let c = self.get_bool(&bv.cs[0]).clone(); + let t = self.get_bv_uint(&bv.cs[1]).clone(); + let f = self.get_bv_uint(&bv.cs[2]).clone(); let ite = self.ite(c, t, &f); self.set_bv_uint(bv, ite, n); } Op::BvUnOp(BvUnOp::Not) => { - let bits = self.get_bv_bits(&bv.children[0]).clone(); + let bits = self.get_bv_bits(&bv.cs[0]).clone(); let not_bits = bits.iter().map(|bit| self.bool_not(bit)).collect(); self.set_bv_bits(bv, not_bits); } Op::BvUnOp(BvUnOp::Neg) => { - let x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - - self.get_bv_uint(&bv.children[0]); + let x = self.get_bv_uint(&bv.cs[0]).clone(); + // Wrong for x == 0 + let almost_neg_x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - &x; + let is_zero = self.is_zero(x); + let neg_x = self.ite(is_zero, self.r1cs.zero(), &almost_neg_x); + self.set_bv_uint(bv, neg_x, n); + } + Op::BvUext(_) => { + // TODO: carry over bits if possible. + let x = self.get_bv_uint(&bv.cs[0]).clone(); self.set_bv_uint(bv, x, n); } - Op::BvUext(extra_n) => { - // TODO: carry over bits if possible. - let x = self.get_bv_uint(&bv.children[0]).clone(); - let new_n = n + extra_n; - self.set_bv_uint(bv, x, new_n); - } Op::BvSext(extra_n) => { - let mut bits = self.get_bv_bits(&bv.children[0]).clone().into_iter(); + let mut bits = self.get_bv_bits(&bv.cs[0]).clone().into_iter().rev(); let ext_bits = std::iter::repeat(bits.next().expect("sign ext empty").clone()) .take(extra_n + 1); - self.set_bv_bits(bv, ext_bits.chain(bits).collect()); + self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect()); + } + Op::BoolToBv => { + let b = self.get_bool(&bv.cs[0]).clone(); + self.set_bv_bits(bv, vec![b]); } Op::BvNaryOp(o) => match o { BvNaryOp::Xor | BvNaryOp::Or | BvNaryOp::And => { let mut bits_by_bv = bv - .children + .cs .iter() .map(|c| self.get_bv_bits(c).clone()) .collect::>(); @@ -363,6 +452,7 @@ impl ToR1cs { bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(), ); } + bits_bv_idx.reverse(); let f = |v: Vec| match o { BvNaryOp::And => self.nary_and(v.into_iter()), BvNaryOp::Or => self.nary_or(v.into_iter()), @@ -375,7 +465,7 @@ impl ToR1cs { BvNaryOp::Add | BvNaryOp::Mul => { let f_width = self.r1cs.modulus().significant_bits() as usize - 1; let values = bv - .children + .cs .iter() .map(|c| self.get_bv_uint(c).clone()) .collect::>(); @@ -383,21 +473,21 @@ impl ToR1cs { BvNaryOp::Add => { let sum = values.into_iter().fold(self.r1cs.zero(), |s, v| s + &v); - let extra_width = bitsize(bv.children.len()) - 1; + let extra_width = bitsize(bv.cs.len().saturating_sub(1)); (sum, n + extra_width) } BvNaryOp::Mul => { - if bv.children.len() * n < f_width { + if bv.cs.len() * n < f_width { let z = self.r1cs.zero() + 1; ( values.into_iter().fold(z, |acc, v| self.mul(acc, v)), - bv.children.len() * n, + bv.cs.len() * n, ) } else { let z = self.r1cs.zero() + 1; let p = values.into_iter().fold(z, |acc, v| { let p = self.mul(acc, v); - let mut bits = self.bitify("binMul", &p, 2 * n); + let mut bits = self.bitify("binMul", &p, 2 * n, false); bits.truncate(n); self.debitify(bits.into_iter(), false) }); @@ -406,17 +496,94 @@ impl ToR1cs { } _ => unreachable!(), }; - let mut bits = self.bitify("arith", &res, width); + let mut bits = self.bitify("arith", &res, width, false); bits.truncate(n); self.set_bv_bits(bv, bits); } }, - Op::BvBinOp(o) => unimplemented!(), - _ => panic!("Non-boolean in embed_bool: {:?}", bv), + Op::BvBinOp(o) => { + let a = self.get_bv_uint(&bv.cs[0]); + let b = self.get_bv_uint(&bv.cs[1]); + match o { + BvBinOp::Sub => { + let sum = a.clone() + &(Integer::from(1) << n as u32) - b; + let mut bits = self.bitify("sub", &sum, n + 1, false); + bits.truncate(n); + self.set_bv_bits(bv, bits); + } + BvBinOp::Udiv | BvBinOp::Urem => { + let b = b.clone(); + let a = a.clone(); + let is_zero = self.is_zero(b.clone()); + let (q_v, r_v) = self + .r1cs + .eval(&a) + .and_then(|a| { + self.r1cs.eval(&b).map(|b| { + if b == 0 { + ((Integer::from(1) << n as u32) - 1, a) + } else { + (a.clone() / &b, a % b) + } + }) + }) + .map(|(a, b)| (Some(a), Some(b))) + .unwrap_or((None, None)); + let q = self.fresh_var("div_q", q_v); + let r = self.fresh_var("div_q", r_v); + let qb = self.bitify("div_q", &q, n, false); + let rb = self.bitify("div_r", &q, n, false); + self.r1cs.constraint(q.clone(), b.clone(), a - &r); + let is_gt = self.bv_ge(b - 1, &r, n); + let is_not_ge = self.bool_not(&is_gt); + let is_not_zero = self.bool_not(&is_zero); + self.r1cs + .constraint(is_not_ge, is_not_zero, self.r1cs.zero()); + let bits = match o { + BvBinOp::Udiv => qb, + BvBinOp::Urem => rb, + _ => unreachable!(), + }; + self.set_bv_bits(bv, bits); + } + // Shift cases + _ => { + let r = b.clone(); + let a = a.clone(); + let b = bitsize(n - 1); + assert!(1 << b == n); + let mut rb = self.get_bv_bits(&bv.cs[1]).clone(); + rb.truncate(b); + let sum = self.debitify(rb.clone().into_iter(), false); + self.assert_zero(sum - &r); + let bits = match o { + BvBinOp::Shl => self.shift_bv_bits(a, rb, None, n), + BvBinOp::Lshr | BvBinOp::Ashr => { + let mut lb = self.get_bv_bits(&bv.cs[0]).clone(); + lb.reverse(); + let ext_bit = match o { + BvBinOp::Ashr => Some(lb.first().unwrap().clone()), + _ => None, + }; + let l = self.debitify(lb.into_iter(), false); + let mut bits = self.shift_bv_bits(l, rb, ext_bit, n); + bits.reverse(); + bits + } + _ => unreachable!(), + }; + self.set_bv_bits(bv, bits); + } + } + } + _ => panic!("Non-bv in embed_bv: {}", bv), } } + // self.r1cs.eval(self.get_bv_uint(&bv2)).map(|v| { + // println!("-> {:b}", v); + // }); } else { - panic!("{:?} is not a bit-vector in embed_bv", bv); + panic!("{} is not a bit-vector in embed_bv", bv); } } @@ -455,21 +622,28 @@ impl ToR1cs { &self.bvs.get(t).expect("Missing term").uint } + fn get_bv_signed_int(&mut self, t: &Term) -> Lc { + let bits = self.get_bv_bits(t).clone(); + self.debitify(bits.into_iter(), true) + } + fn get_bv_bits(&mut self, t: &Term) -> &Vec { let mut bvs = std::mem::take(&mut self.bvs); let entry = bvs.get_mut(t).expect("Missing bit-vec"); if entry.bits.len() == 0 { - entry.bits = self.bitify("getbits", &entry.uint, entry.width); + entry.bits = self.bitify("getbits", &entry.uint, entry.width, false); } self.bvs = bvs; &self.bvs.get(t).unwrap().bits } + fn assert_zero(&mut self, x: Lc) { + self.r1cs.constraint(self.r1cs.zero(), self.r1cs.zero(), x); + } fn assert(&mut self, t: Term) { self.embed(t.clone()); let lc = self.get_bool(&t).clone(); - self.r1cs - .constraint(self.r1cs.zero(), self.r1cs.zero(), lc - 1); + self.assert_zero(lc - 1); } } @@ -494,6 +668,7 @@ fn bitsize(mut n: usize) -> usize { #[cfg(test)] mod test { use super::*; + use crate::ir::term::dist::*; use quickcheck::{Arbitrary, Gen}; use quickcheck_macros::quickcheck; use rand::distributions::Distribution; @@ -526,7 +701,7 @@ mod test { impl Arbitrary for PureBool { fn arbitrary(g: &mut Gen) -> Self { let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g)); - let t = BoolDist(g.size()).sample(&mut rng); + let t = PureBoolDist(g.size()).sample(&mut rng); let values: HashMap = PostOrderIter::new(t.clone()) .filter_map(|c| { if let Op::Var(n, _) = &c.op { @@ -553,7 +728,7 @@ mod test { } #[quickcheck] - fn random_bool(PureBool(t, values): PureBool) { + fn random_pure_bool(PureBool(t, values): PureBool) { let t = if eval(&t, &values).as_bool() { t } else { @@ -567,4 +742,76 @@ mod test { let r1cs = to_r1cs(cs, Integer::from(1014088787)); r1cs.check_all(); } + + #[derive(Clone, Debug)] + struct Bool(Term, HashMap); + + impl Arbitrary for Bool { + fn arbitrary(g: &mut Gen) -> Self { + let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g)); + let d = FixedSizeDist { + bv_width: 8, + size: g.size(), + sort: Sort::Bool, + }; + let t = d.sample(&mut rng); + let values: HashMap = PostOrderIter::new(t.clone()) + .filter_map(|c| match &c.op { + Op::Var(n, Sort::Bool) => Some((n.clone(), Value::Bool(bool::arbitrary(g)))), + Op::Var(n, Sort::BitVector(w)) => Some(( + n.clone(), + Value::BitVector(UniformBitVector(*w).sample(&mut rng)), + )), + _ => None, + }) + .collect(); + Bool(t, values) + } + + fn shrink(&self) -> Box> { + let vs = self.1.clone(); + let ts = PostOrderIter::new(self.0.clone()).collect::>(); + + Box::new( + ts.into_iter() + .rev() + .skip(1) + .map(move |t| Bool(t, vs.clone())), + ) + } + } + + #[quickcheck] + fn random_bool(Bool(t, values): Bool) { + let v = eval(&t, &values); + let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; + let cs = Constraints { + public_inputs: HashSet::new(), + values: Some(values), + assertions: vec![t], + }; + let r1cs = to_r1cs(cs, Integer::from(1014088787)); + r1cs.check_all(); + } + + #[test] + fn eq_test() { + let cs = Constraints { + public_inputs: vec!["a"].into_iter().map(|a| a.to_owned()).collect(), + values: Some( + vec![( + "b".to_owned(), + Value::BitVector(BitVector::new(Integer::from(152), 8)), + )] + .into_iter() + .collect(), + ), + assertions: vec![ + term![Op::Not; term![Op::Eq; leaf_term(Op::Const(Value::BitVector(BitVector::new(Integer::from(0b10110), 8)))), + term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]], + ], + }; + let r1cs = to_r1cs(cs, Integer::from(1014088787)); + r1cs.check_all(); + } } diff --git a/src/lib.rs b/src/lib.rs index 25ac40cb..53525974 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,5 +2,5 @@ #[macro_use] pub mod ir; -pub mod util; pub mod target; +pub mod util; diff --git a/src/target/r1cs/mod.rs b/src/target/r1cs/mod.rs index 2c99f1b9..e064f4a5 100644 --- a/src/target/r1cs/mod.rs +++ b/src/target/r1cs/mod.rs @@ -1,4 +1,5 @@ use rug::Integer; +use rug::ops::{RemRounding, RemRoundingAssign}; use std::collections::HashMap; use std::collections::HashSet; use std::fmt::Display; @@ -34,13 +35,13 @@ impl std::ops::AddAssign<&Lc> for Lc { fn add_assign(&mut self, other: &Lc) { assert_eq!(&self.modulus, &other.modulus); self.constant += &other.constant; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); for (i, v) in &other.monomials { self.monomials .entry(*i) .and_modify(|u| { *u += v; - *u %= &*other.modulus; + u.rem_floor_assign(&*other.modulus); }) .or_insert_with(|| v.clone()); } @@ -58,7 +59,7 @@ impl std::ops::Add<&Integer> for Lc { impl std::ops::AddAssign<&Integer> for Lc { fn add_assign(&mut self, other: &Integer) { self.constant += other; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); } } @@ -73,7 +74,7 @@ impl std::ops::Add for Lc { impl std::ops::AddAssign for Lc { fn add_assign(&mut self, other: isize) { self.constant += Integer::from(other); - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); } } @@ -89,13 +90,13 @@ impl std::ops::SubAssign<&Lc> for Lc { fn sub_assign(&mut self, other: &Lc) { assert_eq!(&self.modulus, &other.modulus); self.constant -= &other.constant; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); for (i, v) in &other.monomials { self.monomials .entry(*i) .and_modify(|u| { *u -= v; - *u %= &*other.modulus; + u.rem_floor_assign(&*other.modulus); }) .or_insert_with(|| -v.clone()); } @@ -113,7 +114,7 @@ impl std::ops::Sub<&Integer> for Lc { impl std::ops::SubAssign<&Integer> for Lc { fn sub_assign(&mut self, other: &Integer) { self.constant -= other; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); } } @@ -128,7 +129,7 @@ impl std::ops::Sub for Lc { impl std::ops::SubAssign for Lc { fn sub_assign(&mut self, other: isize) { self.constant -= Integer::from(other); - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); } } @@ -136,10 +137,10 @@ impl std::ops::Neg for Lc { type Output = Lc; fn neg(mut self) -> Lc { self.constant = -self.constant; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); for (_, v) in &mut self.monomials { *v *= Integer::from(-1); - *v %= &*self.modulus; + v.rem_floor_assign(&*self.modulus); } self } @@ -156,10 +157,10 @@ impl std::ops::Mul<&Integer> for Lc { impl std::ops::MulAssign<&Integer> for Lc { fn mul_assign(&mut self, other: &Integer) { self.constant *= other; - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); for (_, v) in &mut self.monomials { *v *= other; - *v %= &*self.modulus; + v.rem_floor_assign(&*self.modulus); } } } @@ -175,10 +176,10 @@ impl std::ops::Mul for Lc { impl std::ops::MulAssign for Lc { fn mul_assign(&mut self, other: isize) { self.constant *= Integer::from(other); - self.constant %= &*self.modulus; + self.constant.rem_floor_assign(&*self.modulus); for (_, v) in &mut self.monomials { *v *= Integer::from(other); - *v %= &*self.modulus; + v.rem_floor_assign(&*self.modulus); } } } @@ -253,7 +254,7 @@ impl R1cs { let sign = |i: &Integer| if i < &half_m { "+" } else { "-" }; let format_i = |i: &Integer| format!("{}{}", sign(i), abs(i)); - s.extend(format_i(&a.constant).chars()); + s.extend(format_i(&Integer::from(&a.constant)).chars()); for (idx, coeff) in &a.monomials { s.extend( format!( @@ -271,7 +272,8 @@ impl R1cs { let av = self.eval(a).unwrap(); let bv = self.eval(b).unwrap(); let cv = self.eval(c).unwrap(); - if &(av.clone() * &bv % &*self.modulus) != &cv { + dbg!(&av, &bv, &cv, &self.modulus); + if &((av.clone() * &bv).rem_floor(&*self.modulus)) != &cv { panic!( "Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})", self.format_lc(a), @@ -292,7 +294,7 @@ impl R1cs { .expect("Missing value in R1cs::eval") .clone(); acc += val * coeff; - acc %= &*self.modulus; + acc.rem_floor_assign(&*self.modulus); } acc }) diff --git a/src/util/hc/mod.rs b/src/util/hc/mod.rs index e69de29b..8b137891 100644 --- a/src/util/hc/mod.rs +++ b/src/util/hc/mod.rs @@ -0,0 +1 @@ +