Reasonably complete BV->R1CS

This commit is contained in:
Alex Ozdemir
2021-02-06 22:48:01 -08:00
parent 224b056ef4
commit ec9f1c464a
8 changed files with 869 additions and 268 deletions

View File

@@ -1,11 +1,15 @@
use circ::ir;
use rand::distributions::Distribution;
fn main() {
let mut rng = rand::thread_rng();
for i in 0..100 {
let t = ir::term::BoolDist(6).sample(&mut rng);
println!("Term: {:#?}", t)
for _ in 0..100 {
let d = ir::term::dist::FixedSizeDist {
bv_width: 4,
sort: ir::term::Sort::Bool,
size: 6,
};
let t = d.sample(&mut rng);
println!("Term: {}", t)
}
}

208
src/ir/term/bv.rs Normal file
View File

@@ -0,0 +1,208 @@
use rug::Integer;
use std::fmt::{self, Display, Formatter};
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BitVector {
uint: Integer,
width: usize,
}
macro_rules! bv_arith_impl {
($Trait:path, $fn:ident) => {
impl $Trait for BitVector {
type Output = Self;
fn $fn(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
let r = BitVector {
uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32),
width: self.width,
};
r.check(std::stringify!($fn));
r
}
}
};
}
bv_arith_impl!(std::ops::Add, add);
bv_arith_impl!(std::ops::Sub, sub);
bv_arith_impl!(std::ops::Mul, mul);
//bv_arith_impl!(std::ops::Div, div);
//bv_arith_impl!(std::ops::Rem, rem);
bv_arith_impl!(std::ops::BitAnd, bitand);
bv_arith_impl!(std::ops::BitOr, bitor);
bv_arith_impl!(std::ops::BitXor, bitxor);
/// SMT-semantics implementation of unsigned division (udiv).
///
/// If the divisor is zero, returns the all-ones vector.
impl std::ops::Div<&BitVector> for BitVector {
type Output = Self;
fn div(self, other: &Self) -> Self {
assert_eq!(self.width, other.width);
if other.uint == 0 {
let r = BitVector {
uint: (Integer::from(1) << self.width as u32) - 1,
width: self.width,
};
r.check("div");
r
} else {
let r = BitVector {
uint: self.uint / &other.uint,
width: self.width,
};
r.check("div");
r
}
}
}
/// SMT-semantics implementation of unsigned remainder (urem).
///
/// If the divisor is zero, returns the all-ones vector.
impl std::ops::Rem<&BitVector> for BitVector {
type Output = Self;
fn rem(self, other: &Self) -> Self {
assert_eq!(self.width, other.width);
if other.uint == 0 {
self
} else {
let r = BitVector {
uint: self.uint % &other.uint,
width: self.width,
};
r.check("rem");
r
}
}
}
impl std::ops::Shl for BitVector {
type Output = Self;
fn shl(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
let r = BitVector {
uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32),
width: self.width,
};
r.check("shl");
r
}
}
impl std::ops::Neg for BitVector {
type Output = Self;
fn neg(self) -> Self {
let r = BitVector {
uint: ((Integer::from(1) << self.width as u32) - self.uint)
.keep_bits(self.width as u32),
width: self.width,
};
r.check("neg");
r
}
}
impl std::ops::Not for BitVector {
type Output = Self;
fn not(self) -> Self {
let r = BitVector {
uint: (Integer::from(1) << self.width as u32) - 1 - self.uint,
width: self.width,
};
r.check("not");
r
}
}
impl BitVector {
#[track_caller]
#[inline]
pub fn check(&self, location: &str) {
debug_assert!(
self.uint >= 0,
"Too small bitvector: {:?}, {}\n at {}",
self,
self.uint.significant_bits(),
location
);
debug_assert!(
(self.uint.significant_bits() as usize) <= self.width,
"Too big bitvector: {:?}, {}\n at {}",
self,
self.uint.significant_bits(),
location
);
}
pub fn ashr(mut self, other: Self) -> Self {
assert_eq!(self.width, other.width);
let n = other.uint.to_u32().unwrap();
let b = self.uint.get_bit(self.width as u32 - 1);
self.uint >>= n;
for i in 0..n {
self.uint.set_bit(self.width as u32 - 1 - i, b);
}
self.check("ashr");
self
}
pub fn lshr(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
let r = BitVector {
uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32),
width: self.width,
};
r.check("lshr");
r
}
pub fn concat(self, other: Self) -> Self {
let r = BitVector {
uint: (self.uint << other.width as u32) | other.uint,
width: self.width + other.width,
};
r.check("concat");
r
}
pub fn extract(self, high: usize, low: usize) -> Self {
let r = BitVector {
uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32),
width: high - low + 1,
};
r.check("extract");
r
}
pub fn as_sint(&self) -> Integer {
if self.uint.significant_bits() as usize == self.width {
self.uint.clone() - (Integer::from(1) << self.width as u32)
} else {
self.uint.clone()
}
}
pub fn uint(&self) -> &Integer {
&self.uint
}
pub fn width(&self) -> usize {
self.width
}
#[track_caller]
pub fn new(uint: Integer, width: usize) -> BitVector {
let r = BitVector { uint, width };
r.check("new");
r
}
}
impl Display for BitVector {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "#b")?;
for i in 0..self.width {
write!(
f,
"{}",
self.uint.get_bit((self.width - i - 1) as u32) as u8
)?;
}
Ok(())
}
}

207
src/ir/term/dist.rs Normal file
View File

@@ -0,0 +1,207 @@
use super::*;
// A distribution of boolean terms with some size.
// All subterms are booleans.
pub struct PureBoolDist(pub usize);
// A distribution of n usizes that sum to this value.
// (n, sum)
pub struct Sum(usize, usize);
impl rand::distributions::Distribution<Vec<usize>> for Sum {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<usize> {
use rand::seq::SliceRandom;
let mut acc = self.1;
let mut ns = Vec::new();
assert!(acc == 0 || self.0 > 0);
while acc > 0 && ns.len() < self.0 {
let x = rng.gen_range(0..acc);
acc -= x;
ns.push(x);
}
while ns.len() < self.0 {
ns.push(0);
}
if acc > 0 {
*ns.last_mut().unwrap() += acc;
}
ns.shuffle(rng);
ns
}
}
impl rand::distributions::Distribution<Term> for PureBoolDist {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
use rand::seq::SliceRandom;
let ops = &[
Op::Const(Value::Bool(rng.gen())),
Op::Var(
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
.unwrap()
.to_owned(),
Sort::Bool,
),
Op::Not,
Op::Implies,
Op::BoolNaryOp(BoolNaryOp::Or),
Op::BoolNaryOp(BoolNaryOp::And),
Op::BoolNaryOp(BoolNaryOp::Xor),
];
let o = match self.0 {
1 => ops[..2].choose(rng), // arity 0
2 => ops[2..3].choose(rng), // arity 1
_ => ops[2..].choose(rng), // others
}
.unwrap()
.clone();
// Now, self.0 is a least arity+1
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
let excess = self.0 - 1 - a;
let ns = Sum(a, excess).sample(rng);
let subterms = ns
.into_iter()
.map(|n| PureBoolDist(n + 1).sample(rng))
.collect::<Vec<_>>();
term(o, subterms)
}
}
pub struct FixedSizeDist {
pub size: usize,
pub bv_width: usize,
pub sort: Sort,
}
impl FixedSizeDist {
fn with_size(&self, size: usize) -> Self {
FixedSizeDist {
size,
sort: self.sort.clone(),
bv_width: self.bv_width,
}
}
fn with_sort(&self, sort: Sort) -> Self {
FixedSizeDist {
size: self.size,
sort,
bv_width: self.bv_width,
}
}
}
pub struct UniformBitVector(pub usize);
impl rand::distributions::Distribution<BitVector> for UniformBitVector {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BitVector {
let mut rug_rng = rug::rand::RandState::new_mersenne_twister();
rug_rng.seed(&Integer::from(rng.next_u32()));
BitVector::new(
Integer::from(Integer::random_bits(self.0 as u32, &mut rug_rng)),
self.0,
)
}
}
impl rand::distributions::Distribution<Term> for FixedSizeDist {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
use rand::seq::SliceRandom;
match self.sort.clone() {
Sort::Bool => {
let ops = &[
Op::Const(Value::Bool(rng.gen())),
Op::Var(
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
.unwrap()
.to_owned(),
Sort::Bool,
),
Op::Not, // 2
Op::Implies,
Op::Eq,
Op::BvBinPred(BvBinPred::Sge),
Op::BvBinPred(BvBinPred::Sgt),
Op::BvBinPred(BvBinPred::Sle),
Op::BvBinPred(BvBinPred::Slt),
Op::BvBinPred(BvBinPred::Uge),
Op::BvBinPred(BvBinPred::Ugt),
Op::BvBinPred(BvBinPred::Ule),
Op::BvBinPred(BvBinPred::Ult),
Op::BoolNaryOp(BoolNaryOp::Or),
Op::BoolNaryOp(BoolNaryOp::And),
Op::BoolNaryOp(BoolNaryOp::Xor),
Op::Ite,
];
let o = match self.size {
1 => ops[..2].choose(rng), // arity 0
2 => ops[2..3].choose(rng), // arity 1
3 => ops[2..16].choose(rng), // arity 2
_ => ops[2..].choose(rng), // others
}
.unwrap()
.clone();
// Now, self.0 is a least arity+1
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size));
let excess = self.size - 1 - a;
let ns = Sum(a, excess).sample(rng);
let sort = match o {
Op::Eq => [Sort::Bool, Sort::BitVector(self.bv_width)]
.choose(rng)
.unwrap()
.clone(),
Op::BvBinPred(_) => Sort::BitVector(self.bv_width),
_ => Sort::Bool,
};
let subterms = ns
.into_iter()
.map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng))
.collect::<Vec<_>>();
term(o, subterms)
}
Sort::BitVector(w) => {
let ops = &[
Op::Const(Value::BitVector(UniformBitVector(w).sample(rng))),
Op::Var(
format!(
"{}_bv{}",
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]).unwrap(),
w
),
Sort::BitVector(w),
),
Op::BvUnOp(BvUnOp::Neg),
Op::BvUnOp(BvUnOp::Not),
Op::BvUext(rng.gen_range(0..w)),
Op::BvSext(rng.gen_range(0..w)),
Op::BvBinOp(BvBinOp::Sub),
Op::BvNaryOp(BvNaryOp::Or),
Op::BvNaryOp(BvNaryOp::And),
Op::BvNaryOp(BvNaryOp::Xor),
Op::BvNaryOp(BvNaryOp::Add),
Op::BvNaryOp(BvNaryOp::Mul),
// Add ITEs
];
let o = match self.size {
1 => ops[..2].choose(rng), // arity 0
2 => ops[2..4].choose(rng), // arity 1
_ => ops[2..].choose(rng), // others
}
.unwrap()
.clone();
let sort = match o {
Op::BvUext(ww) => Sort::BitVector(w - ww),
Op::BvSext(ww) => Sort::BitVector(w - ww),
_ => Sort::BitVector(w),
};
// Now, self.0 is a least arity+1
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size));
let excess = self.size - 1 - a;
let ns = Sum(a, excess).sample(rng);
let subterms = ns
.into_iter()
.map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng))
.collect::<Vec<_>>();
term(o, subterms)
}
s => panic!("Unsampleabl sort: {}", s),
}
}
}

View File

@@ -6,6 +6,11 @@ use std::collections::HashSet;
use std::fmt::{self, Display, Formatter};
use std::sync::{Arc, RwLock};
pub mod bv;
pub mod dist;
pub use bv::BitVector;
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Op {
Ite,
@@ -352,16 +357,16 @@ impl Display for PfUnOp {
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct TermData {
pub op: Op,
pub children: Vec<Term>,
pub cs: Vec<Term>,
}
impl Display for TermData {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
if let Op::Var(..) = &self.op {
if self.op.arity() == Some(0) {
write!(f, "{}", self.op)
} else {
write!(f, "({}", self.op)?;
for c in &self.children {
for c in &self.cs {
write!(f, " {}", c)?;
}
write!(f, ")")
@@ -385,90 +390,6 @@ impl Display for FieldElem {
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct BitVector {
pub uint: Integer,
pub width: usize,
}
macro_rules! bv_arith_impl {
($Trait:path, $fn:ident) => {
impl $Trait for BitVector {
type Output = Self;
fn $fn(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
BitVector {
uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32),
width: self.width,
}
}
}
};
}
bv_arith_impl!(std::ops::Add, add);
bv_arith_impl!(std::ops::Sub, sub);
bv_arith_impl!(std::ops::Mul, mul);
bv_arith_impl!(std::ops::Div, div);
bv_arith_impl!(std::ops::Rem, rem);
impl std::ops::Shl for BitVector {
type Output = Self;
fn shl(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
BitVector {
uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32),
width: self.width,
}
}
}
impl BitVector {
pub fn ashr(mut self, other: Self) -> Self {
assert_eq!(self.width, other.width);
let n = other.uint.to_u32().unwrap();
let b = self.uint.get_bit(self.width as u32 - 1);
self.uint >>= n;
for i in 0..n {
self.uint.set_bit(self.width as u32 - 1 - i, b);
}
self
}
pub fn lshr(self, other: Self) -> Self {
assert_eq!(self.width, other.width);
BitVector {
uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32),
width: self.width,
}
}
pub fn concat(self, other: Self) -> Self {
BitVector {
uint: (self.uint << other.width as u32) | other.uint,
width: self.width + other.width,
}
}
pub fn extract(self, high: usize, low: usize) -> Self {
BitVector {
uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32),
width: high - low + 1,
}
}
}
impl Display for BitVector {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "#b")?;
for i in 0..self.width {
write!(
f,
"#{}",
self.uint.get_bit((self.width - i - 1) as u32) as u8
)?;
}
Ok(())
}
}
#[derive(Clone, PartialEq, Debug)]
pub enum Value {
BitVector(BitVector),
@@ -517,6 +438,30 @@ pub enum Sort {
Array(Box<Sort>, Box<Sort>),
}
impl Sort {
pub fn as_bv(&self) -> usize {
if let Sort::BitVector(w) = self {
*w
} else {
panic!("{} is not a bit-vector", self)
}
}
}
impl Display for Sort {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Sort::Bool => write!(f, "bool"),
Sort::BitVector(n) => write!(f, "(bv {})", n),
Sort::Int => write!(f, "int"),
Sort::F32 => write!(f, "f32"),
Sort::F64 => write!(f, "f64"),
Sort::Field(i) => write!(f, "(mod {})", i),
Sort::Array(k, v) => write!(f, "(array {} {})", k, v),
}
}
}
pub type Term = HConsed<TermData>;
// "Temporary" terms.
pub type TTerm = WHConsed<TermData>;
@@ -537,9 +482,10 @@ impl Value {
Value::Int(_) => Sort::Int,
Value::F64(_) => Sort::F64,
Value::F32(_) => Sort::F32,
Value::BitVector(b) => Sort::BitVector(b.width),
Value::BitVector(b) => Sort::BitVector(b.width()),
}
}
#[track_caller]
pub fn as_bool(&self) -> bool {
if let Value::Bool(b) = self {
*b
@@ -547,6 +493,7 @@ impl Value {
panic!("Not a bool: {}", self)
}
}
#[track_caller]
pub fn as_bv(&self) -> &BitVector {
if let Value::BitVector(b) = self {
b
@@ -621,7 +568,7 @@ pub fn check(t: Term) -> Result<Sort, TypeError> {
}
{
let mut term_tys = TERM_TYPES.write().unwrap();
// to_check is a stack of (node, children checked) pairs.
// to_check is a stack of (node, cs checked) pairs.
let mut to_check = vec![(t.clone(), false)];
while to_check.len() > 0 {
let back = to_check.last_mut().unwrap();
@@ -640,13 +587,13 @@ pub fn check(t: Term) -> Result<Sort, TypeError> {
}
if !back.1 {
back.1 = true;
for c in back.0.children.clone() {
for c in back.0.cs.clone() {
to_check.push((c, false));
}
} else {
let tys = back
.0
.children
.cs
.iter()
.map(|c| term_tys.get(&c.to_weak()).unwrap())
.collect::<Vec<_>>();
@@ -781,51 +728,101 @@ pub fn eval(t: &Term, h: &HashMap<String, Value>) -> Value {
for c in PostOrderIter::new(t.clone()) {
let v = match &c.op {
Op::Var(n, _) => h.get(n).unwrap().clone(),
Op::Eq => {
Value::Bool(vs.get(&c.children[0]).unwrap() == vs.get(&c.children[1]).unwrap())
}
Op::Not => Value::Bool(!vs.get(&c.children[0]).unwrap().as_bool()),
Op::Eq => Value::Bool(vs.get(&c.cs[0]).unwrap() == vs.get(&c.cs[1]).unwrap()),
Op::Not => Value::Bool(!vs.get(&c.cs[0]).unwrap().as_bool()),
Op::Implies => Value::Bool(
!vs.get(&c.children[0]).unwrap().as_bool()
|| vs.get(&c.children[1]).unwrap().as_bool(),
!vs.get(&c.cs[0]).unwrap().as_bool() || vs.get(&c.cs[1]).unwrap().as_bool(),
),
Op::BoolNaryOp(BoolNaryOp::Or) => {
Value::Bool(c.children.iter().any(|c| vs.get(c).unwrap().as_bool()))
Value::Bool(c.cs.iter().any(|c| vs.get(c).unwrap().as_bool()))
}
Op::BoolNaryOp(BoolNaryOp::And) => {
Value::Bool(c.children.iter().all(|c| vs.get(c).unwrap().as_bool()))
Value::Bool(c.cs.iter().all(|c| vs.get(c).unwrap().as_bool()))
}
Op::BoolNaryOp(BoolNaryOp::Xor) => Value::Bool(
c.children
.iter()
c.cs.iter()
.map(|c| vs.get(c).unwrap().as_bool())
.fold(false, std::ops::BitXor::bitxor),
),
Op::BvBit(i) => Value::Bool(
vs.get(&c.children[0])
.unwrap()
.as_bv()
.uint
.get_bit(*i as u32),
),
Op::BvBit(i) => {
Value::Bool(vs.get(&c.cs[0]).unwrap().as_bv().uint().get_bit(*i as u32))
}
Op::BvConcat => Value::BitVector({
let mut it = c
.children
.iter()
.map(|c| vs.get(c).unwrap().as_bv().clone());
let mut it = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone());
let f = it.next().unwrap();
it.fold(f, BitVector::concat)
}),
Op::BvExtract(h, l) => Value::BitVector(
vs.get(&c.children[0])
.unwrap()
.as_bv()
.clone()
.extract(*h, *l),
),
Op::BvExtract(h, l) => {
Value::BitVector(vs.get(&c.cs[0]).unwrap().as_bv().clone().extract(*h, *l))
}
Op::Const(v) => v.clone(),
Op::BvBinOp(o) => Value::BitVector({
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
let b = vs.get(&c.cs[1]).unwrap().as_bv().clone();
match o {
BvBinOp::Udiv => a / &b,
BvBinOp::Urem => a % &b,
BvBinOp::Sub => a - b,
BvBinOp::Ashr => a.ashr(b),
BvBinOp::Lshr => a.lshr(b),
BvBinOp::Shl => a << b,
}
}),
Op::BvUnOp(o) => Value::BitVector({
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
match o {
BvUnOp::Not => !a,
BvUnOp::Neg => -a,
}
}),
Op::BvNaryOp(o) => Value::BitVector({
let mut xs = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone());
let f = xs.next().unwrap();
xs.fold(
f,
match o {
BvNaryOp::Add => std::ops::Add::add,
BvNaryOp::Mul => std::ops::Mul::mul,
BvNaryOp::Xor => std::ops::BitXor::bitxor,
BvNaryOp::Or => std::ops::BitOr::bitor,
BvNaryOp::And => std::ops::BitAnd::bitand,
},
)
}),
Op::BvSext(w) => Value::BitVector({
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
let mask = ((Integer::from(1) << *w as u32) - 1)
* Integer::from(a.uint().get_bit(a.width() as u32 - 1));
BitVector::new(a.uint() | (mask << a.width() as u32), a.width() + w)
}),
Op::BvUext(w) => Value::BitVector({
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
BitVector::new(a.uint().clone(), a.width() + w)
}),
Op::Ite => if vs.get(&c.cs[0]).unwrap().as_bool() {
vs.get(&c.cs[1])
} else {
vs.get(&c.cs[2])
}
.unwrap()
.clone(),
Op::BvBinPred(o) => Value::Bool({
let a = vs.get(&c.cs[0]).unwrap().as_bv();
let b = vs.get(&c.cs[1]).unwrap().as_bv();
match o {
BvBinPred::Sge => a.as_sint() >= b.as_sint(),
BvBinPred::Sgt => a.as_sint() > b.as_sint(),
BvBinPred::Sle => a.as_sint() <= b.as_sint(),
BvBinPred::Slt => a.as_sint() < b.as_sint(),
BvBinPred::Uge => a.uint() >= b.uint(),
BvBinPred::Ugt => a.uint() > b.uint(),
BvBinPred::Ule => a.uint() <= b.uint(),
BvBinPred::Ult => a.uint() < b.uint(),
}
}),
o => unimplemented!("eval: {:?}", o),
};
//println!("Eval {}\nAs {}", c, v);
vs.insert(c.clone(), v);
}
vs.get(t).unwrap().clone()
@@ -854,9 +851,9 @@ pub fn leaf_term(op: Op) -> Term {
term(op, Vec::new())
}
pub fn term(op: Op, children: Vec<Term>) -> Term {
pub fn term(op: Op, cs: Vec<Term>) -> Term {
use hashconsing::HashConsign;
let t = TERM_FACTORY.mk(TermData { op, children });
let t = TERM_FACTORY.mk(TermData { op, cs });
check(t.clone()).unwrap();
t
}
@@ -868,76 +865,11 @@ macro_rules! term {
};
}
// A distribution of boolean terms with some size.
// All subterms are booleans.
pub struct BoolDist(pub usize);
// A distribution of n usizes that sum to this value.
// (n, sum)
pub struct Sum(usize, usize);
impl rand::distributions::Distribution<Vec<usize>> for Sum {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<usize> {
use rand::seq::SliceRandom;
let mut acc = self.1;
let mut ns = Vec::new();
assert!(acc == 0 || self.0 > 0);
while acc > 0 && ns.len() < self.0 {
let x = rng.gen_range(0..acc);
acc -= x;
ns.push(x);
}
while ns.len() < self.0 {
ns.push(0);
}
if acc > 0 {
*ns.last_mut().unwrap() += acc;
}
ns.shuffle(rng);
ns
}
}
impl rand::distributions::Distribution<Term> for BoolDist {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
use rand::seq::SliceRandom;
let ops = &[
Op::Const(Value::Bool(rng.gen())),
Op::Var(
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
.unwrap()
.to_owned(),
Sort::Bool,
),
Op::Not,
Op::Implies,
Op::BoolNaryOp(BoolNaryOp::Or),
Op::BoolNaryOp(BoolNaryOp::And),
Op::BoolNaryOp(BoolNaryOp::Xor),
];
let o = match self.0 {
1 => ops[..2].choose(rng), // arity 0
2 => ops[2..3].choose(rng), // arity 1
_ => ops[2..].choose(rng), // others
}
.unwrap()
.clone();
// Now, self.0 is a least arity+1
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
let excess = self.0 - 1 - a;
let ns = Sum(a, excess).sample(rng);
let subterms = ns
.into_iter()
.map(|n| BoolDist(n + 1).sample(rng))
.collect::<Vec<_>>();
term(o, subterms)
}
}
pub type TermMap<T> = hashconsing::coll::HConMap<Term, T>;
pub type TermSet = hashconsing::coll::HConSet<Term>;
pub struct PostOrderIter {
// (children stacked, term)
// (cs stacked, term)
stack: Vec<(bool, Term)>,
visited: TermSet,
}
@@ -959,7 +891,7 @@ impl std::iter::Iterator for PostOrderIter {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
let cs = self.stack.last().unwrap().1.children.clone();
let cs = self.stack.last().unwrap().1.cs.clone();
self.stack.extend(cs.into_iter().map(|c| (false, c)));
} else {
break;

View File

@@ -118,32 +118,45 @@ impl ToR1cs {
self.values
.as_ref()
.map(|vs| match vs.get(var).expect("missing value") {
Value::BitVector(BitVector { uint, .. }) => uint.clone(),
Value::BitVector(b) => b.uint().clone(),
v => panic!("{} should be a bool, but is {:?}", var, v),
})
}
/// Given wire `x`, returns a vector of `n` wires which are the bits of `x`.
/// Constrains `x` to fit in `n` (unsigned) bits.
/// They *have not* been constrained to sum to `x`.
/// They have values according the the (infinite) two's complement representation of `x`.
/// The LSB is at index 0.
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize) -> Vec<Lc> {
fn decomp<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize) -> Vec<Lc> {
let x_val = self.r1cs.eval(x);
let bits = (0..n)
(0..n)
.map(|i| {
self.fresh_bit(
// We get the right repr here because of infinite two's complement.
&format!("{}_b{}", d, i),
x_val.as_ref().map(|x| Integer::from(x.get_bit(i as u32))),
)
})
.collect::<Vec<_>>();
let sum = bits.iter().enumerate().fold(self.r1cs.zero(), |s, (i, b)| {
s + &(b.clone() * &Integer::from(2).pow(i as u32))
});
self.r1cs
.constraint(self.r1cs.zero(), self.r1cs.zero(), sum - x);
.collect::<Vec<_>>()
}
/// Given wire `x`, returns a vector of `n` wires which are the bits of `x`.
/// Constrains `x` to fit in `n` (`signed`) bits.
/// The LSB is at index 0.
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Vec<Lc> {
let bits = self.decomp(d, x, n);
let sum = self.debitify(bits.iter().cloned(), signed);
self.assert_zero(sum - x);
bits
}
/// Given wire `x`, returns whether `x` fits in `n` `signed` bits.
fn fits_in_bits<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Lc {
let bits = self.decomp(d, x, n);
let sum = self.debitify(bits.iter().cloned(), signed);
self.are_equal(sum, x)
}
/// Given a sequence of `bits`, returns a wire which represents their sum,
/// `\sum_{i>0} b_i2^i`.
///
@@ -164,7 +177,7 @@ impl ToR1cs {
fn nary_xor<I: ExactSizeIterator<Item = Lc>>(&mut self, xs: I) -> Lc {
let n = xs.len();
let sum = xs.into_iter().fold(self.r1cs.zero(), |s, i| s + &i);
let sum_bits = self.bitify("sum", &sum, bitsize(n));
let sum_bits = self.bitify("sum", &sum, bitsize(n), false);
assert!(n > 0);
assert!(self.r1cs.modulus() > &n);
sum_bits.into_iter().next().unwrap() // safe b/c assert
@@ -262,46 +275,116 @@ impl ToR1cs {
v
}
Op::Const(Value::Bool(b)) => self.r1cs.zero() + *b as isize,
Op::Eq => self.embed_eq(&c.children[0], &c.children[1]),
Op::Eq => self.embed_eq(&c.cs[0], &c.cs[1]),
Op::Ite => {
let a = self.get_bool(&c.children[0]).clone();
let b = self.get_bool(&c.children[1]).clone();
let c = self.get_bool(&c.children[2]).clone();
let a = self.get_bool(&c.cs[0]).clone();
let b = self.get_bool(&c.cs[1]).clone();
let c = self.get_bool(&c.cs[2]).clone();
self.ite(a, b, &c)
}
Op::Not => {
let a = self.get_bool(&c.children[0]);
let a = self.get_bool(&c.cs[0]);
self.bool_not(a)
}
Op::Implies => {
let a = self.get_bool(&c.children[0]).clone();
let b = self.get_bool(&c.children[1]).clone();
let a = self.get_bool(&c.cs[0]).clone();
let b = self.get_bool(&c.cs[1]).clone();
let not_a = self.bool_not(&a);
self.nary_or(vec![not_a, b].into_iter())
}
Op::BoolNaryOp(o) => {
let args = c
.children
.iter()
.map(|c| self.get_bool(c).clone())
.collect::<Vec<_>>();
let args =
c.cs.iter()
.map(|c| self.get_bool(c).clone())
.collect::<Vec<_>>();
match o {
BoolNaryOp::Or => self.nary_or(args.into_iter()),
BoolNaryOp::And => self.nary_and(args.into_iter()),
BoolNaryOp::Xor => self.nary_xor(args.into_iter()),
}
}
Op::BvBit(i) => {
let a = self.get_bv_bits(&c.cs[0]);
a[*i].clone()
}
Op::BvBinPred(o) => {
let n = check(c.cs[0].clone()).unwrap().as_bv();
use BvBinPred::*;
match o {
Sge => self.bv_cmp(n, true, false, &c.cs[0], &c.cs[1]),
Sgt => self.bv_cmp(n, true, true, &c.cs[0], &c.cs[1]),
Uge => self.bv_cmp(n, false, false, &c.cs[0], &c.cs[1]),
Ugt => self.bv_cmp(n, false, true, &c.cs[0], &c.cs[1]),
Sle => self.bv_cmp(n, true, false, &c.cs[1], &c.cs[0]),
Slt => self.bv_cmp(n, true, true, &c.cs[1], &c.cs[0]),
Ule => self.bv_cmp(n, false, false, &c.cs[1], &c.cs[0]),
Ult => self.bv_cmp(n, false, true, &c.cs[1], &c.cs[0]),
}
}
_ => panic!("Non-boolean in embed_bool: {:?}", c),
};
self.bools.insert(c.clone(), lc);
}
// self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| {
// println!("-> {}", v);
// });
// self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| {
// println!("-> {}", v);
// });
self.bools.get(&c).unwrap()
}
/// Returns whether `a - b` fits in `size` non-negative bits.
/// i.e. is in `{0, 1, ..., 2^n-1}`.
fn bv_ge(&mut self, a: Lc, b: &Lc, size: usize) -> Lc {
self.fits_in_bits("ge", &(a - b), size, false)
}
/// Returns whether `a` is (`strict`ly) (`signed`ly) greater than `b`.
/// Assumes they are each `w`-bit bit-vectors.
fn bv_cmp(&mut self, w: usize, signed: bool, strict: bool, a: &Term, b: &Term) -> Lc {
let a = if signed {
self.get_bv_signed_int(a)
} else {
self.get_bv_uint(a).clone()
};
let b = if signed {
self.get_bv_signed_int(b)
} else {
self.get_bv_uint(b).clone()
};
// Use the fact: a > b <=> a - 1 >= b
self.bv_ge(if strict { a - 1 } else { a }, &b, w)
}
/// Shift `x` left by `2^y`, if bit-valued `c` is true.
fn const_pow_shift_bv(&mut self, x: &Lc, y: usize, c: Lc) -> Lc {
self.ite(c, x.clone() * (1 << y), x)
}
/// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`.
/// Returns an *oversized* number
fn shift_bv(&mut self, x: Lc, y: Vec<Lc>, ext_bit: Option<Lc>) -> Lc {
if let Some(b) = ext_bit {
let left = self.shift_bv(x, y.clone(), None);
let right = self.shift_bv(b.clone(), y, None) - 1;
left + &self.mul(b, right)
} else {
y.into_iter()
.enumerate()
.fold(x, |x, (i, yi)| self.const_pow_shift_bv(&x, i, yi))
}
}
/// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`.
/// Returns a bit sequence.
fn shift_bv_bits(&mut self, x: Lc, y: Vec<Lc>, ext_bit: Option<Lc>, n: usize) -> Vec<Lc> {
let s = self.shift_bv(x, y, ext_bit);
let mut bits = self.bitify("shift", &s, 2 * n - 1, false);
bits.truncate(n);
bits
}
fn embed_bv(&mut self, bv: Term) {
//println!("Embed: {}", bv);
if let Sort::BitVector(n) = check(bv.clone()).unwrap() {
if !self.bvs.contains_key(&bv) {
match &bv.op {
@@ -313,47 +396,53 @@ impl ToR1cs {
self.get_bv_bits(&bv);
}
}
Op::Const(Value::BitVector(BitVector { uint, width })) => {
let bit_lcs = (0..*width)
.map(|i| self.r1cs.zero() + uint.get_bit(i as u32) as isize)
Op::Const(Value::BitVector(b)) => {
let bit_lcs = (0..b.width())
.map(|i| self.r1cs.zero() + b.uint().get_bit(i as u32) as isize)
.collect();
self.set_bv_bits(bv, bit_lcs);
}
Op::Ite => {
let c = self.get_bool(&bv.children[0]).clone();
let t = self.get_bv_uint(&bv.children[1]).clone();
let f = self.get_bv_uint(&bv.children[2]).clone();
let c = self.get_bool(&bv.cs[0]).clone();
let t = self.get_bv_uint(&bv.cs[1]).clone();
let f = self.get_bv_uint(&bv.cs[2]).clone();
let ite = self.ite(c, t, &f);
self.set_bv_uint(bv, ite, n);
}
Op::BvUnOp(BvUnOp::Not) => {
let bits = self.get_bv_bits(&bv.children[0]).clone();
let bits = self.get_bv_bits(&bv.cs[0]).clone();
let not_bits = bits.iter().map(|bit| self.bool_not(bit)).collect();
self.set_bv_bits(bv, not_bits);
}
Op::BvUnOp(BvUnOp::Neg) => {
let x = self.r1cs.zero() + &Integer::from(2).pow(n as u32)
- self.get_bv_uint(&bv.children[0]);
let x = self.get_bv_uint(&bv.cs[0]).clone();
// Wrong for x == 0
let almost_neg_x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - &x;
let is_zero = self.is_zero(x);
let neg_x = self.ite(is_zero, self.r1cs.zero(), &almost_neg_x);
self.set_bv_uint(bv, neg_x, n);
}
Op::BvUext(_) => {
// TODO: carry over bits if possible.
let x = self.get_bv_uint(&bv.cs[0]).clone();
self.set_bv_uint(bv, x, n);
}
Op::BvUext(extra_n) => {
// TODO: carry over bits if possible.
let x = self.get_bv_uint(&bv.children[0]).clone();
let new_n = n + extra_n;
self.set_bv_uint(bv, x, new_n);
}
Op::BvSext(extra_n) => {
let mut bits = self.get_bv_bits(&bv.children[0]).clone().into_iter();
let mut bits = self.get_bv_bits(&bv.cs[0]).clone().into_iter().rev();
let ext_bits =
std::iter::repeat(bits.next().expect("sign ext empty").clone())
.take(extra_n + 1);
self.set_bv_bits(bv, ext_bits.chain(bits).collect());
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
}
Op::BoolToBv => {
let b = self.get_bool(&bv.cs[0]).clone();
self.set_bv_bits(bv, vec![b]);
}
Op::BvNaryOp(o) => match o {
BvNaryOp::Xor | BvNaryOp::Or | BvNaryOp::And => {
let mut bits_by_bv = bv
.children
.cs
.iter()
.map(|c| self.get_bv_bits(c).clone())
.collect::<Vec<_>>();
@@ -363,6 +452,7 @@ impl ToR1cs {
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
);
}
bits_bv_idx.reverse();
let f = |v: Vec<Lc>| match o {
BvNaryOp::And => self.nary_and(v.into_iter()),
BvNaryOp::Or => self.nary_or(v.into_iter()),
@@ -375,7 +465,7 @@ impl ToR1cs {
BvNaryOp::Add | BvNaryOp::Mul => {
let f_width = self.r1cs.modulus().significant_bits() as usize - 1;
let values = bv
.children
.cs
.iter()
.map(|c| self.get_bv_uint(c).clone())
.collect::<Vec<_>>();
@@ -383,21 +473,21 @@ impl ToR1cs {
BvNaryOp::Add => {
let sum =
values.into_iter().fold(self.r1cs.zero(), |s, v| s + &v);
let extra_width = bitsize(bv.children.len()) - 1;
let extra_width = bitsize(bv.cs.len().saturating_sub(1));
(sum, n + extra_width)
}
BvNaryOp::Mul => {
if bv.children.len() * n < f_width {
if bv.cs.len() * n < f_width {
let z = self.r1cs.zero() + 1;
(
values.into_iter().fold(z, |acc, v| self.mul(acc, v)),
bv.children.len() * n,
bv.cs.len() * n,
)
} else {
let z = self.r1cs.zero() + 1;
let p = values.into_iter().fold(z, |acc, v| {
let p = self.mul(acc, v);
let mut bits = self.bitify("binMul", &p, 2 * n);
let mut bits = self.bitify("binMul", &p, 2 * n, false);
bits.truncate(n);
self.debitify(bits.into_iter(), false)
});
@@ -406,17 +496,94 @@ impl ToR1cs {
}
_ => unreachable!(),
};
let mut bits = self.bitify("arith", &res, width);
let mut bits = self.bitify("arith", &res, width, false);
bits.truncate(n);
self.set_bv_bits(bv, bits);
}
},
Op::BvBinOp(o) => unimplemented!(),
_ => panic!("Non-boolean in embed_bool: {:?}", bv),
Op::BvBinOp(o) => {
let a = self.get_bv_uint(&bv.cs[0]);
let b = self.get_bv_uint(&bv.cs[1]);
match o {
BvBinOp::Sub => {
let sum = a.clone() + &(Integer::from(1) << n as u32) - b;
let mut bits = self.bitify("sub", &sum, n + 1, false);
bits.truncate(n);
self.set_bv_bits(bv, bits);
}
BvBinOp::Udiv | BvBinOp::Urem => {
let b = b.clone();
let a = a.clone();
let is_zero = self.is_zero(b.clone());
let (q_v, r_v) = self
.r1cs
.eval(&a)
.and_then(|a| {
self.r1cs.eval(&b).map(|b| {
if b == 0 {
((Integer::from(1) << n as u32) - 1, a)
} else {
(a.clone() / &b, a % b)
}
})
})
.map(|(a, b)| (Some(a), Some(b)))
.unwrap_or((None, None));
let q = self.fresh_var("div_q", q_v);
let r = self.fresh_var("div_q", r_v);
let qb = self.bitify("div_q", &q, n, false);
let rb = self.bitify("div_r", &q, n, false);
self.r1cs.constraint(q.clone(), b.clone(), a - &r);
let is_gt = self.bv_ge(b - 1, &r, n);
let is_not_ge = self.bool_not(&is_gt);
let is_not_zero = self.bool_not(&is_zero);
self.r1cs
.constraint(is_not_ge, is_not_zero, self.r1cs.zero());
let bits = match o {
BvBinOp::Udiv => qb,
BvBinOp::Urem => rb,
_ => unreachable!(),
};
self.set_bv_bits(bv, bits);
}
// Shift cases
_ => {
let r = b.clone();
let a = a.clone();
let b = bitsize(n - 1);
assert!(1 << b == n);
let mut rb = self.get_bv_bits(&bv.cs[1]).clone();
rb.truncate(b);
let sum = self.debitify(rb.clone().into_iter(), false);
self.assert_zero(sum - &r);
let bits = match o {
BvBinOp::Shl => self.shift_bv_bits(a, rb, None, n),
BvBinOp::Lshr | BvBinOp::Ashr => {
let mut lb = self.get_bv_bits(&bv.cs[0]).clone();
lb.reverse();
let ext_bit = match o {
BvBinOp::Ashr => Some(lb.first().unwrap().clone()),
_ => None,
};
let l = self.debitify(lb.into_iter(), false);
let mut bits = self.shift_bv_bits(l, rb, ext_bit, n);
bits.reverse();
bits
}
_ => unreachable!(),
};
self.set_bv_bits(bv, bits);
}
}
}
_ => panic!("Non-bv in embed_bv: {}", bv),
}
}
// self.r1cs.eval(self.get_bv_uint(&bv2)).map(|v| {
// println!("-> {:b}", v);
// });
} else {
panic!("{:?} is not a bit-vector in embed_bv", bv);
panic!("{} is not a bit-vector in embed_bv", bv);
}
}
@@ -455,21 +622,28 @@ impl ToR1cs {
&self.bvs.get(t).expect("Missing term").uint
}
fn get_bv_signed_int(&mut self, t: &Term) -> Lc {
let bits = self.get_bv_bits(t).clone();
self.debitify(bits.into_iter(), true)
}
fn get_bv_bits(&mut self, t: &Term) -> &Vec<Lc> {
let mut bvs = std::mem::take(&mut self.bvs);
let entry = bvs.get_mut(t).expect("Missing bit-vec");
if entry.bits.len() == 0 {
entry.bits = self.bitify("getbits", &entry.uint, entry.width);
entry.bits = self.bitify("getbits", &entry.uint, entry.width, false);
}
self.bvs = bvs;
&self.bvs.get(t).unwrap().bits
}
fn assert_zero(&mut self, x: Lc) {
self.r1cs.constraint(self.r1cs.zero(), self.r1cs.zero(), x);
}
fn assert(&mut self, t: Term) {
self.embed(t.clone());
let lc = self.get_bool(&t).clone();
self.r1cs
.constraint(self.r1cs.zero(), self.r1cs.zero(), lc - 1);
self.assert_zero(lc - 1);
}
}
@@ -494,6 +668,7 @@ fn bitsize(mut n: usize) -> usize {
#[cfg(test)]
mod test {
use super::*;
use crate::ir::term::dist::*;
use quickcheck::{Arbitrary, Gen};
use quickcheck_macros::quickcheck;
use rand::distributions::Distribution;
@@ -526,7 +701,7 @@ mod test {
impl Arbitrary for PureBool {
fn arbitrary(g: &mut Gen) -> Self {
let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g));
let t = BoolDist(g.size()).sample(&mut rng);
let t = PureBoolDist(g.size()).sample(&mut rng);
let values: HashMap<String, Value> = PostOrderIter::new(t.clone())
.filter_map(|c| {
if let Op::Var(n, _) = &c.op {
@@ -553,7 +728,7 @@ mod test {
}
#[quickcheck]
fn random_bool(PureBool(t, values): PureBool) {
fn random_pure_bool(PureBool(t, values): PureBool) {
let t = if eval(&t, &values).as_bool() {
t
} else {
@@ -567,4 +742,76 @@ mod test {
let r1cs = to_r1cs(cs, Integer::from(1014088787));
r1cs.check_all();
}
#[derive(Clone, Debug)]
struct Bool(Term, HashMap<String, Value>);
impl Arbitrary for Bool {
fn arbitrary(g: &mut Gen) -> Self {
let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g));
let d = FixedSizeDist {
bv_width: 8,
size: g.size(),
sort: Sort::Bool,
};
let t = d.sample(&mut rng);
let values: HashMap<String, Value> = PostOrderIter::new(t.clone())
.filter_map(|c| match &c.op {
Op::Var(n, Sort::Bool) => Some((n.clone(), Value::Bool(bool::arbitrary(g)))),
Op::Var(n, Sort::BitVector(w)) => Some((
n.clone(),
Value::BitVector(UniformBitVector(*w).sample(&mut rng)),
)),
_ => None,
})
.collect();
Bool(t, values)
}
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
let vs = self.1.clone();
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
Box::new(
ts.into_iter()
.rev()
.skip(1)
.map(move |t| Bool(t, vs.clone())),
)
}
}
#[quickcheck]
fn random_bool(Bool(t, values): Bool) {
let v = eval(&t, &values);
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs = Constraints {
public_inputs: HashSet::new(),
values: Some(values),
assertions: vec![t],
};
let r1cs = to_r1cs(cs, Integer::from(1014088787));
r1cs.check_all();
}
#[test]
fn eq_test() {
let cs = Constraints {
public_inputs: vec!["a"].into_iter().map(|a| a.to_owned()).collect(),
values: Some(
vec![(
"b".to_owned(),
Value::BitVector(BitVector::new(Integer::from(152), 8)),
)]
.into_iter()
.collect(),
),
assertions: vec![
term![Op::Not; term![Op::Eq; leaf_term(Op::Const(Value::BitVector(BitVector::new(Integer::from(0b10110), 8)))),
term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]],
],
};
let r1cs = to_r1cs(cs, Integer::from(1014088787));
r1cs.check_all();
}
}

View File

@@ -2,5 +2,5 @@
#[macro_use]
pub mod ir;
pub mod util;
pub mod target;
pub mod util;

View File

@@ -1,4 +1,5 @@
use rug::Integer;
use rug::ops::{RemRounding, RemRoundingAssign};
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Display;
@@ -34,13 +35,13 @@ impl std::ops::AddAssign<&Lc> for Lc {
fn add_assign(&mut self, other: &Lc) {
assert_eq!(&self.modulus, &other.modulus);
self.constant += &other.constant;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
for (i, v) in &other.monomials {
self.monomials
.entry(*i)
.and_modify(|u| {
*u += v;
*u %= &*other.modulus;
u.rem_floor_assign(&*other.modulus);
})
.or_insert_with(|| v.clone());
}
@@ -58,7 +59,7 @@ impl std::ops::Add<&Integer> for Lc {
impl std::ops::AddAssign<&Integer> for Lc {
fn add_assign(&mut self, other: &Integer) {
self.constant += other;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
}
}
@@ -73,7 +74,7 @@ impl std::ops::Add<isize> for Lc {
impl std::ops::AddAssign<isize> for Lc {
fn add_assign(&mut self, other: isize) {
self.constant += Integer::from(other);
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
}
}
@@ -89,13 +90,13 @@ impl std::ops::SubAssign<&Lc> for Lc {
fn sub_assign(&mut self, other: &Lc) {
assert_eq!(&self.modulus, &other.modulus);
self.constant -= &other.constant;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
for (i, v) in &other.monomials {
self.monomials
.entry(*i)
.and_modify(|u| {
*u -= v;
*u %= &*other.modulus;
u.rem_floor_assign(&*other.modulus);
})
.or_insert_with(|| -v.clone());
}
@@ -113,7 +114,7 @@ impl std::ops::Sub<&Integer> for Lc {
impl std::ops::SubAssign<&Integer> for Lc {
fn sub_assign(&mut self, other: &Integer) {
self.constant -= other;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
}
}
@@ -128,7 +129,7 @@ impl std::ops::Sub<isize> for Lc {
impl std::ops::SubAssign<isize> for Lc {
fn sub_assign(&mut self, other: isize) {
self.constant -= Integer::from(other);
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
}
}
@@ -136,10 +137,10 @@ impl std::ops::Neg for Lc {
type Output = Lc;
fn neg(mut self) -> Lc {
self.constant = -self.constant;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
for (_, v) in &mut self.monomials {
*v *= Integer::from(-1);
*v %= &*self.modulus;
v.rem_floor_assign(&*self.modulus);
}
self
}
@@ -156,10 +157,10 @@ impl std::ops::Mul<&Integer> for Lc {
impl std::ops::MulAssign<&Integer> for Lc {
fn mul_assign(&mut self, other: &Integer) {
self.constant *= other;
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
for (_, v) in &mut self.monomials {
*v *= other;
*v %= &*self.modulus;
v.rem_floor_assign(&*self.modulus);
}
}
}
@@ -175,10 +176,10 @@ impl std::ops::Mul<isize> for Lc {
impl std::ops::MulAssign<isize> for Lc {
fn mul_assign(&mut self, other: isize) {
self.constant *= Integer::from(other);
self.constant %= &*self.modulus;
self.constant.rem_floor_assign(&*self.modulus);
for (_, v) in &mut self.monomials {
*v *= Integer::from(other);
*v %= &*self.modulus;
v.rem_floor_assign(&*self.modulus);
}
}
}
@@ -253,7 +254,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
let sign = |i: &Integer| if i < &half_m { "+" } else { "-" };
let format_i = |i: &Integer| format!("{}{}", sign(i), abs(i));
s.extend(format_i(&a.constant).chars());
s.extend(format_i(&Integer::from(&a.constant)).chars());
for (idx, coeff) in &a.monomials {
s.extend(
format!(
@@ -271,7 +272,8 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
let av = self.eval(a).unwrap();
let bv = self.eval(b).unwrap();
let cv = self.eval(c).unwrap();
if &(av.clone() * &bv % &*self.modulus) != &cv {
dbg!(&av, &bv, &cv, &self.modulus);
if &((av.clone() * &bv).rem_floor(&*self.modulus)) != &cv {
panic!(
"Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})",
self.format_lc(a),
@@ -292,7 +294,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
.expect("Missing value in R1cs::eval")
.clone();
acc += val * coeff;
acc %= &*self.modulus;
acc.rem_floor_assign(&*self.modulus);
}
acc
})

View File

@@ -0,0 +1 @@