mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
Reasonably complete BV->R1CS
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
use circ::ir;
|
||||
use rand::distributions::Distribution;
|
||||
|
||||
|
||||
fn main() {
|
||||
let mut rng = rand::thread_rng();
|
||||
for i in 0..100 {
|
||||
let t = ir::term::BoolDist(6).sample(&mut rng);
|
||||
println!("Term: {:#?}", t)
|
||||
for _ in 0..100 {
|
||||
let d = ir::term::dist::FixedSizeDist {
|
||||
bv_width: 4,
|
||||
sort: ir::term::Sort::Bool,
|
||||
size: 6,
|
||||
};
|
||||
let t = d.sample(&mut rng);
|
||||
println!("Term: {}", t)
|
||||
}
|
||||
}
|
||||
|
||||
208
src/ir/term/bv.rs
Normal file
208
src/ir/term/bv.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use rug::Integer;
|
||||
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct BitVector {
|
||||
uint: Integer,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
macro_rules! bv_arith_impl {
|
||||
($Trait:path, $fn:ident) => {
|
||||
impl $Trait for BitVector {
|
||||
type Output = Self;
|
||||
fn $fn(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
let r = BitVector {
|
||||
uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
};
|
||||
r.check(std::stringify!($fn));
|
||||
r
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
bv_arith_impl!(std::ops::Add, add);
|
||||
bv_arith_impl!(std::ops::Sub, sub);
|
||||
bv_arith_impl!(std::ops::Mul, mul);
|
||||
//bv_arith_impl!(std::ops::Div, div);
|
||||
//bv_arith_impl!(std::ops::Rem, rem);
|
||||
bv_arith_impl!(std::ops::BitAnd, bitand);
|
||||
bv_arith_impl!(std::ops::BitOr, bitor);
|
||||
bv_arith_impl!(std::ops::BitXor, bitxor);
|
||||
|
||||
/// SMT-semantics implementation of unsigned division (udiv).
|
||||
///
|
||||
/// If the divisor is zero, returns the all-ones vector.
|
||||
impl std::ops::Div<&BitVector> for BitVector {
|
||||
type Output = Self;
|
||||
fn div(self, other: &Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
if other.uint == 0 {
|
||||
let r = BitVector {
|
||||
uint: (Integer::from(1) << self.width as u32) - 1,
|
||||
width: self.width,
|
||||
};
|
||||
r.check("div");
|
||||
r
|
||||
} else {
|
||||
let r = BitVector {
|
||||
uint: self.uint / &other.uint,
|
||||
width: self.width,
|
||||
};
|
||||
r.check("div");
|
||||
r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// SMT-semantics implementation of unsigned remainder (urem).
|
||||
///
|
||||
/// If the divisor is zero, returns the all-ones vector.
|
||||
impl std::ops::Rem<&BitVector> for BitVector {
|
||||
type Output = Self;
|
||||
fn rem(self, other: &Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
if other.uint == 0 {
|
||||
self
|
||||
} else {
|
||||
let r = BitVector {
|
||||
uint: self.uint % &other.uint,
|
||||
width: self.width,
|
||||
};
|
||||
r.check("rem");
|
||||
r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Shl for BitVector {
|
||||
type Output = Self;
|
||||
fn shl(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
let r = BitVector {
|
||||
uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
};
|
||||
r.check("shl");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Neg for BitVector {
|
||||
type Output = Self;
|
||||
fn neg(self) -> Self {
|
||||
let r = BitVector {
|
||||
uint: ((Integer::from(1) << self.width as u32) - self.uint)
|
||||
.keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
};
|
||||
r.check("neg");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Not for BitVector {
|
||||
type Output = Self;
|
||||
fn not(self) -> Self {
|
||||
let r = BitVector {
|
||||
uint: (Integer::from(1) << self.width as u32) - 1 - self.uint,
|
||||
width: self.width,
|
||||
};
|
||||
r.check("not");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
impl BitVector {
|
||||
#[track_caller]
|
||||
#[inline]
|
||||
pub fn check(&self, location: &str) {
|
||||
debug_assert!(
|
||||
self.uint >= 0,
|
||||
"Too small bitvector: {:?}, {}\n at {}",
|
||||
self,
|
||||
self.uint.significant_bits(),
|
||||
location
|
||||
);
|
||||
debug_assert!(
|
||||
(self.uint.significant_bits() as usize) <= self.width,
|
||||
"Too big bitvector: {:?}, {}\n at {}",
|
||||
self,
|
||||
self.uint.significant_bits(),
|
||||
location
|
||||
);
|
||||
}
|
||||
pub fn ashr(mut self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
let n = other.uint.to_u32().unwrap();
|
||||
let b = self.uint.get_bit(self.width as u32 - 1);
|
||||
self.uint >>= n;
|
||||
for i in 0..n {
|
||||
self.uint.set_bit(self.width as u32 - 1 - i, b);
|
||||
}
|
||||
self.check("ashr");
|
||||
self
|
||||
}
|
||||
pub fn lshr(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
let r = BitVector {
|
||||
uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
};
|
||||
r.check("lshr");
|
||||
r
|
||||
}
|
||||
pub fn concat(self, other: Self) -> Self {
|
||||
let r = BitVector {
|
||||
uint: (self.uint << other.width as u32) | other.uint,
|
||||
width: self.width + other.width,
|
||||
};
|
||||
r.check("concat");
|
||||
r
|
||||
}
|
||||
pub fn extract(self, high: usize, low: usize) -> Self {
|
||||
let r = BitVector {
|
||||
uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32),
|
||||
width: high - low + 1,
|
||||
};
|
||||
r.check("extract");
|
||||
r
|
||||
}
|
||||
pub fn as_sint(&self) -> Integer {
|
||||
if self.uint.significant_bits() as usize == self.width {
|
||||
self.uint.clone() - (Integer::from(1) << self.width as u32)
|
||||
} else {
|
||||
self.uint.clone()
|
||||
}
|
||||
}
|
||||
pub fn uint(&self) -> &Integer {
|
||||
&self.uint
|
||||
}
|
||||
pub fn width(&self) -> usize {
|
||||
self.width
|
||||
}
|
||||
#[track_caller]
|
||||
pub fn new(uint: Integer, width: usize) -> BitVector {
|
||||
let r = BitVector { uint, width };
|
||||
r.check("new");
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for BitVector {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(f, "#b")?;
|
||||
for i in 0..self.width {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
self.uint.get_bit((self.width - i - 1) as u32) as u8
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
207
src/ir/term/dist.rs
Normal file
207
src/ir/term/dist.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
use super::*;
|
||||
|
||||
// A distribution of boolean terms with some size.
|
||||
// All subterms are booleans.
|
||||
pub struct PureBoolDist(pub usize);
|
||||
|
||||
// A distribution of n usizes that sum to this value.
|
||||
// (n, sum)
|
||||
pub struct Sum(usize, usize);
|
||||
impl rand::distributions::Distribution<Vec<usize>> for Sum {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<usize> {
|
||||
use rand::seq::SliceRandom;
|
||||
let mut acc = self.1;
|
||||
let mut ns = Vec::new();
|
||||
assert!(acc == 0 || self.0 > 0);
|
||||
while acc > 0 && ns.len() < self.0 {
|
||||
let x = rng.gen_range(0..acc);
|
||||
acc -= x;
|
||||
ns.push(x);
|
||||
}
|
||||
while ns.len() < self.0 {
|
||||
ns.push(0);
|
||||
}
|
||||
if acc > 0 {
|
||||
*ns.last_mut().unwrap() += acc;
|
||||
}
|
||||
ns.shuffle(rng);
|
||||
ns
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<Term> for PureBoolDist {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
|
||||
use rand::seq::SliceRandom;
|
||||
let ops = &[
|
||||
Op::Const(Value::Bool(rng.gen())),
|
||||
Op::Var(
|
||||
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
|
||||
.unwrap()
|
||||
.to_owned(),
|
||||
Sort::Bool,
|
||||
),
|
||||
Op::Not,
|
||||
Op::Implies,
|
||||
Op::BoolNaryOp(BoolNaryOp::Or),
|
||||
Op::BoolNaryOp(BoolNaryOp::And),
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor),
|
||||
];
|
||||
let o = match self.0 {
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
2 => ops[2..3].choose(rng), // arity 1
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}
|
||||
.unwrap()
|
||||
.clone();
|
||||
// Now, self.0 is a least arity+1
|
||||
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
|
||||
let excess = self.0 - 1 - a;
|
||||
let ns = Sum(a, excess).sample(rng);
|
||||
let subterms = ns
|
||||
.into_iter()
|
||||
.map(|n| PureBoolDist(n + 1).sample(rng))
|
||||
.collect::<Vec<_>>();
|
||||
term(o, subterms)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FixedSizeDist {
|
||||
pub size: usize,
|
||||
pub bv_width: usize,
|
||||
pub sort: Sort,
|
||||
}
|
||||
|
||||
impl FixedSizeDist {
|
||||
fn with_size(&self, size: usize) -> Self {
|
||||
FixedSizeDist {
|
||||
size,
|
||||
sort: self.sort.clone(),
|
||||
bv_width: self.bv_width,
|
||||
}
|
||||
}
|
||||
fn with_sort(&self, sort: Sort) -> Self {
|
||||
FixedSizeDist {
|
||||
size: self.size,
|
||||
sort,
|
||||
bv_width: self.bv_width,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UniformBitVector(pub usize);
|
||||
|
||||
impl rand::distributions::Distribution<BitVector> for UniformBitVector {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BitVector {
|
||||
let mut rug_rng = rug::rand::RandState::new_mersenne_twister();
|
||||
rug_rng.seed(&Integer::from(rng.next_u32()));
|
||||
BitVector::new(
|
||||
Integer::from(Integer::random_bits(self.0 as u32, &mut rug_rng)),
|
||||
self.0,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<Term> for FixedSizeDist {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
|
||||
use rand::seq::SliceRandom;
|
||||
match self.sort.clone() {
|
||||
Sort::Bool => {
|
||||
let ops = &[
|
||||
Op::Const(Value::Bool(rng.gen())),
|
||||
Op::Var(
|
||||
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
|
||||
.unwrap()
|
||||
.to_owned(),
|
||||
Sort::Bool,
|
||||
),
|
||||
Op::Not, // 2
|
||||
Op::Implies,
|
||||
Op::Eq,
|
||||
Op::BvBinPred(BvBinPred::Sge),
|
||||
Op::BvBinPred(BvBinPred::Sgt),
|
||||
Op::BvBinPred(BvBinPred::Sle),
|
||||
Op::BvBinPred(BvBinPred::Slt),
|
||||
Op::BvBinPred(BvBinPred::Uge),
|
||||
Op::BvBinPred(BvBinPred::Ugt),
|
||||
Op::BvBinPred(BvBinPred::Ule),
|
||||
Op::BvBinPred(BvBinPred::Ult),
|
||||
Op::BoolNaryOp(BoolNaryOp::Or),
|
||||
Op::BoolNaryOp(BoolNaryOp::And),
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor),
|
||||
Op::Ite,
|
||||
];
|
||||
let o = match self.size {
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
2 => ops[2..3].choose(rng), // arity 1
|
||||
3 => ops[2..16].choose(rng), // arity 2
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}
|
||||
.unwrap()
|
||||
.clone();
|
||||
// Now, self.0 is a least arity+1
|
||||
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size));
|
||||
let excess = self.size - 1 - a;
|
||||
let ns = Sum(a, excess).sample(rng);
|
||||
let sort = match o {
|
||||
Op::Eq => [Sort::Bool, Sort::BitVector(self.bv_width)]
|
||||
.choose(rng)
|
||||
.unwrap()
|
||||
.clone(),
|
||||
Op::BvBinPred(_) => Sort::BitVector(self.bv_width),
|
||||
_ => Sort::Bool,
|
||||
};
|
||||
let subterms = ns
|
||||
.into_iter()
|
||||
.map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng))
|
||||
.collect::<Vec<_>>();
|
||||
term(o, subterms)
|
||||
}
|
||||
Sort::BitVector(w) => {
|
||||
let ops = &[
|
||||
Op::Const(Value::BitVector(UniformBitVector(w).sample(rng))),
|
||||
Op::Var(
|
||||
format!(
|
||||
"{}_bv{}",
|
||||
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]).unwrap(),
|
||||
w
|
||||
),
|
||||
Sort::BitVector(w),
|
||||
),
|
||||
Op::BvUnOp(BvUnOp::Neg),
|
||||
Op::BvUnOp(BvUnOp::Not),
|
||||
Op::BvUext(rng.gen_range(0..w)),
|
||||
Op::BvSext(rng.gen_range(0..w)),
|
||||
Op::BvBinOp(BvBinOp::Sub),
|
||||
Op::BvNaryOp(BvNaryOp::Or),
|
||||
Op::BvNaryOp(BvNaryOp::And),
|
||||
Op::BvNaryOp(BvNaryOp::Xor),
|
||||
Op::BvNaryOp(BvNaryOp::Add),
|
||||
Op::BvNaryOp(BvNaryOp::Mul),
|
||||
// Add ITEs
|
||||
];
|
||||
let o = match self.size {
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
2 => ops[2..4].choose(rng), // arity 1
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}
|
||||
.unwrap()
|
||||
.clone();
|
||||
let sort = match o {
|
||||
Op::BvUext(ww) => Sort::BitVector(w - ww),
|
||||
Op::BvSext(ww) => Sort::BitVector(w - ww),
|
||||
_ => Sort::BitVector(w),
|
||||
};
|
||||
// Now, self.0 is a least arity+1
|
||||
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.size));
|
||||
let excess = self.size - 1 - a;
|
||||
let ns = Sum(a, excess).sample(rng);
|
||||
let subterms = ns
|
||||
.into_iter()
|
||||
.map(|n| self.with_size(n + 1).with_sort(sort.clone()).sample(rng))
|
||||
.collect::<Vec<_>>();
|
||||
term(o, subterms)
|
||||
}
|
||||
s => panic!("Unsampleabl sort: {}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,11 @@ use std::collections::HashSet;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
pub mod bv;
|
||||
pub mod dist;
|
||||
|
||||
pub use bv::BitVector;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum Op {
|
||||
Ite,
|
||||
@@ -352,16 +357,16 @@ impl Display for PfUnOp {
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct TermData {
|
||||
pub op: Op,
|
||||
pub children: Vec<Term>,
|
||||
pub cs: Vec<Term>,
|
||||
}
|
||||
|
||||
impl Display for TermData {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
if let Op::Var(..) = &self.op {
|
||||
if self.op.arity() == Some(0) {
|
||||
write!(f, "{}", self.op)
|
||||
} else {
|
||||
write!(f, "({}", self.op)?;
|
||||
for c in &self.children {
|
||||
for c in &self.cs {
|
||||
write!(f, " {}", c)?;
|
||||
}
|
||||
write!(f, ")")
|
||||
@@ -385,90 +390,6 @@ impl Display for FieldElem {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct BitVector {
|
||||
pub uint: Integer,
|
||||
pub width: usize,
|
||||
}
|
||||
|
||||
macro_rules! bv_arith_impl {
|
||||
($Trait:path, $fn:ident) => {
|
||||
impl $Trait for BitVector {
|
||||
type Output = Self;
|
||||
fn $fn(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
BitVector {
|
||||
uint: (self.uint.$fn(other.uint)).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
bv_arith_impl!(std::ops::Add, add);
|
||||
bv_arith_impl!(std::ops::Sub, sub);
|
||||
bv_arith_impl!(std::ops::Mul, mul);
|
||||
bv_arith_impl!(std::ops::Div, div);
|
||||
bv_arith_impl!(std::ops::Rem, rem);
|
||||
|
||||
impl std::ops::Shl for BitVector {
|
||||
type Output = Self;
|
||||
fn shl(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
BitVector {
|
||||
uint: (self.uint.shl(other.uint.to_u32().unwrap())).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BitVector {
|
||||
pub fn ashr(mut self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
let n = other.uint.to_u32().unwrap();
|
||||
let b = self.uint.get_bit(self.width as u32 - 1);
|
||||
self.uint >>= n;
|
||||
for i in 0..n {
|
||||
self.uint.set_bit(self.width as u32 - 1 - i, b);
|
||||
}
|
||||
self
|
||||
}
|
||||
pub fn lshr(self, other: Self) -> Self {
|
||||
assert_eq!(self.width, other.width);
|
||||
BitVector {
|
||||
uint: (self.uint >> other.uint.to_u32().unwrap()).keep_bits(self.width as u32),
|
||||
width: self.width,
|
||||
}
|
||||
}
|
||||
pub fn concat(self, other: Self) -> Self {
|
||||
BitVector {
|
||||
uint: (self.uint << other.width as u32) | other.uint,
|
||||
width: self.width + other.width,
|
||||
}
|
||||
}
|
||||
pub fn extract(self, high: usize, low: usize) -> Self {
|
||||
BitVector {
|
||||
uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32),
|
||||
width: high - low + 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for BitVector {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(f, "#b")?;
|
||||
for i in 0..self.width {
|
||||
write!(
|
||||
f,
|
||||
"#{}",
|
||||
self.uint.get_bit((self.width - i - 1) as u32) as u8
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Value {
|
||||
BitVector(BitVector),
|
||||
@@ -517,6 +438,30 @@ pub enum Sort {
|
||||
Array(Box<Sort>, Box<Sort>),
|
||||
}
|
||||
|
||||
impl Sort {
|
||||
pub fn as_bv(&self) -> usize {
|
||||
if let Sort::BitVector(w) = self {
|
||||
*w
|
||||
} else {
|
||||
panic!("{} is not a bit-vector", self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Sort {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Sort::Bool => write!(f, "bool"),
|
||||
Sort::BitVector(n) => write!(f, "(bv {})", n),
|
||||
Sort::Int => write!(f, "int"),
|
||||
Sort::F32 => write!(f, "f32"),
|
||||
Sort::F64 => write!(f, "f64"),
|
||||
Sort::Field(i) => write!(f, "(mod {})", i),
|
||||
Sort::Array(k, v) => write!(f, "(array {} {})", k, v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type Term = HConsed<TermData>;
|
||||
// "Temporary" terms.
|
||||
pub type TTerm = WHConsed<TermData>;
|
||||
@@ -537,9 +482,10 @@ impl Value {
|
||||
Value::Int(_) => Sort::Int,
|
||||
Value::F64(_) => Sort::F64,
|
||||
Value::F32(_) => Sort::F32,
|
||||
Value::BitVector(b) => Sort::BitVector(b.width),
|
||||
Value::BitVector(b) => Sort::BitVector(b.width()),
|
||||
}
|
||||
}
|
||||
#[track_caller]
|
||||
pub fn as_bool(&self) -> bool {
|
||||
if let Value::Bool(b) = self {
|
||||
*b
|
||||
@@ -547,6 +493,7 @@ impl Value {
|
||||
panic!("Not a bool: {}", self)
|
||||
}
|
||||
}
|
||||
#[track_caller]
|
||||
pub fn as_bv(&self) -> &BitVector {
|
||||
if let Value::BitVector(b) = self {
|
||||
b
|
||||
@@ -621,7 +568,7 @@ pub fn check(t: Term) -> Result<Sort, TypeError> {
|
||||
}
|
||||
{
|
||||
let mut term_tys = TERM_TYPES.write().unwrap();
|
||||
// to_check is a stack of (node, children checked) pairs.
|
||||
// to_check is a stack of (node, cs checked) pairs.
|
||||
let mut to_check = vec![(t.clone(), false)];
|
||||
while to_check.len() > 0 {
|
||||
let back = to_check.last_mut().unwrap();
|
||||
@@ -640,13 +587,13 @@ pub fn check(t: Term) -> Result<Sort, TypeError> {
|
||||
}
|
||||
if !back.1 {
|
||||
back.1 = true;
|
||||
for c in back.0.children.clone() {
|
||||
for c in back.0.cs.clone() {
|
||||
to_check.push((c, false));
|
||||
}
|
||||
} else {
|
||||
let tys = back
|
||||
.0
|
||||
.children
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| term_tys.get(&c.to_weak()).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
@@ -781,51 +728,101 @@ pub fn eval(t: &Term, h: &HashMap<String, Value>) -> Value {
|
||||
for c in PostOrderIter::new(t.clone()) {
|
||||
let v = match &c.op {
|
||||
Op::Var(n, _) => h.get(n).unwrap().clone(),
|
||||
Op::Eq => {
|
||||
Value::Bool(vs.get(&c.children[0]).unwrap() == vs.get(&c.children[1]).unwrap())
|
||||
}
|
||||
Op::Not => Value::Bool(!vs.get(&c.children[0]).unwrap().as_bool()),
|
||||
Op::Eq => Value::Bool(vs.get(&c.cs[0]).unwrap() == vs.get(&c.cs[1]).unwrap()),
|
||||
Op::Not => Value::Bool(!vs.get(&c.cs[0]).unwrap().as_bool()),
|
||||
Op::Implies => Value::Bool(
|
||||
!vs.get(&c.children[0]).unwrap().as_bool()
|
||||
|| vs.get(&c.children[1]).unwrap().as_bool(),
|
||||
!vs.get(&c.cs[0]).unwrap().as_bool() || vs.get(&c.cs[1]).unwrap().as_bool(),
|
||||
),
|
||||
Op::BoolNaryOp(BoolNaryOp::Or) => {
|
||||
Value::Bool(c.children.iter().any(|c| vs.get(c).unwrap().as_bool()))
|
||||
Value::Bool(c.cs.iter().any(|c| vs.get(c).unwrap().as_bool()))
|
||||
}
|
||||
Op::BoolNaryOp(BoolNaryOp::And) => {
|
||||
Value::Bool(c.children.iter().all(|c| vs.get(c).unwrap().as_bool()))
|
||||
Value::Bool(c.cs.iter().all(|c| vs.get(c).unwrap().as_bool()))
|
||||
}
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor) => Value::Bool(
|
||||
c.children
|
||||
.iter()
|
||||
c.cs.iter()
|
||||
.map(|c| vs.get(c).unwrap().as_bool())
|
||||
.fold(false, std::ops::BitXor::bitxor),
|
||||
),
|
||||
Op::BvBit(i) => Value::Bool(
|
||||
vs.get(&c.children[0])
|
||||
.unwrap()
|
||||
.as_bv()
|
||||
.uint
|
||||
.get_bit(*i as u32),
|
||||
),
|
||||
Op::BvBit(i) => {
|
||||
Value::Bool(vs.get(&c.cs[0]).unwrap().as_bv().uint().get_bit(*i as u32))
|
||||
}
|
||||
Op::BvConcat => Value::BitVector({
|
||||
let mut it = c
|
||||
.children
|
||||
.iter()
|
||||
.map(|c| vs.get(c).unwrap().as_bv().clone());
|
||||
let mut it = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone());
|
||||
let f = it.next().unwrap();
|
||||
it.fold(f, BitVector::concat)
|
||||
}),
|
||||
Op::BvExtract(h, l) => Value::BitVector(
|
||||
vs.get(&c.children[0])
|
||||
.unwrap()
|
||||
.as_bv()
|
||||
.clone()
|
||||
.extract(*h, *l),
|
||||
),
|
||||
Op::BvExtract(h, l) => {
|
||||
Value::BitVector(vs.get(&c.cs[0]).unwrap().as_bv().clone().extract(*h, *l))
|
||||
}
|
||||
Op::Const(v) => v.clone(),
|
||||
Op::BvBinOp(o) => Value::BitVector({
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
|
||||
let b = vs.get(&c.cs[1]).unwrap().as_bv().clone();
|
||||
match o {
|
||||
BvBinOp::Udiv => a / &b,
|
||||
BvBinOp::Urem => a % &b,
|
||||
BvBinOp::Sub => a - b,
|
||||
BvBinOp::Ashr => a.ashr(b),
|
||||
BvBinOp::Lshr => a.lshr(b),
|
||||
BvBinOp::Shl => a << b,
|
||||
}
|
||||
}),
|
||||
Op::BvUnOp(o) => Value::BitVector({
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
|
||||
match o {
|
||||
BvUnOp::Not => !a,
|
||||
BvUnOp::Neg => -a,
|
||||
}
|
||||
}),
|
||||
Op::BvNaryOp(o) => Value::BitVector({
|
||||
let mut xs = c.cs.iter().map(|c| vs.get(c).unwrap().as_bv().clone());
|
||||
let f = xs.next().unwrap();
|
||||
xs.fold(
|
||||
f,
|
||||
match o {
|
||||
BvNaryOp::Add => std::ops::Add::add,
|
||||
BvNaryOp::Mul => std::ops::Mul::mul,
|
||||
BvNaryOp::Xor => std::ops::BitXor::bitxor,
|
||||
BvNaryOp::Or => std::ops::BitOr::bitor,
|
||||
BvNaryOp::And => std::ops::BitAnd::bitand,
|
||||
},
|
||||
)
|
||||
}),
|
||||
Op::BvSext(w) => Value::BitVector({
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
|
||||
let mask = ((Integer::from(1) << *w as u32) - 1)
|
||||
* Integer::from(a.uint().get_bit(a.width() as u32 - 1));
|
||||
BitVector::new(a.uint() | (mask << a.width() as u32), a.width() + w)
|
||||
}),
|
||||
Op::BvUext(w) => Value::BitVector({
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
|
||||
BitVector::new(a.uint().clone(), a.width() + w)
|
||||
}),
|
||||
Op::Ite => if vs.get(&c.cs[0]).unwrap().as_bool() {
|
||||
vs.get(&c.cs[1])
|
||||
} else {
|
||||
vs.get(&c.cs[2])
|
||||
}
|
||||
.unwrap()
|
||||
.clone(),
|
||||
Op::BvBinPred(o) => Value::Bool({
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv();
|
||||
let b = vs.get(&c.cs[1]).unwrap().as_bv();
|
||||
match o {
|
||||
BvBinPred::Sge => a.as_sint() >= b.as_sint(),
|
||||
BvBinPred::Sgt => a.as_sint() > b.as_sint(),
|
||||
BvBinPred::Sle => a.as_sint() <= b.as_sint(),
|
||||
BvBinPred::Slt => a.as_sint() < b.as_sint(),
|
||||
BvBinPred::Uge => a.uint() >= b.uint(),
|
||||
BvBinPred::Ugt => a.uint() > b.uint(),
|
||||
BvBinPred::Ule => a.uint() <= b.uint(),
|
||||
BvBinPred::Ult => a.uint() < b.uint(),
|
||||
}
|
||||
}),
|
||||
o => unimplemented!("eval: {:?}", o),
|
||||
};
|
||||
//println!("Eval {}\nAs {}", c, v);
|
||||
vs.insert(c.clone(), v);
|
||||
}
|
||||
vs.get(t).unwrap().clone()
|
||||
@@ -854,9 +851,9 @@ pub fn leaf_term(op: Op) -> Term {
|
||||
term(op, Vec::new())
|
||||
}
|
||||
|
||||
pub fn term(op: Op, children: Vec<Term>) -> Term {
|
||||
pub fn term(op: Op, cs: Vec<Term>) -> Term {
|
||||
use hashconsing::HashConsign;
|
||||
let t = TERM_FACTORY.mk(TermData { op, children });
|
||||
let t = TERM_FACTORY.mk(TermData { op, cs });
|
||||
check(t.clone()).unwrap();
|
||||
t
|
||||
}
|
||||
@@ -868,76 +865,11 @@ macro_rules! term {
|
||||
};
|
||||
}
|
||||
|
||||
// A distribution of boolean terms with some size.
|
||||
// All subterms are booleans.
|
||||
pub struct BoolDist(pub usize);
|
||||
|
||||
// A distribution of n usizes that sum to this value.
|
||||
// (n, sum)
|
||||
pub struct Sum(usize, usize);
|
||||
impl rand::distributions::Distribution<Vec<usize>> for Sum {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Vec<usize> {
|
||||
use rand::seq::SliceRandom;
|
||||
let mut acc = self.1;
|
||||
let mut ns = Vec::new();
|
||||
assert!(acc == 0 || self.0 > 0);
|
||||
while acc > 0 && ns.len() < self.0 {
|
||||
let x = rng.gen_range(0..acc);
|
||||
acc -= x;
|
||||
ns.push(x);
|
||||
}
|
||||
while ns.len() < self.0 {
|
||||
ns.push(0);
|
||||
}
|
||||
if acc > 0 {
|
||||
*ns.last_mut().unwrap() += acc;
|
||||
}
|
||||
ns.shuffle(rng);
|
||||
ns
|
||||
}
|
||||
}
|
||||
|
||||
impl rand::distributions::Distribution<Term> for BoolDist {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term {
|
||||
use rand::seq::SliceRandom;
|
||||
let ops = &[
|
||||
Op::Const(Value::Bool(rng.gen())),
|
||||
Op::Var(
|
||||
std::str::from_utf8(&[b'a' + rng.gen_range(0..26)])
|
||||
.unwrap()
|
||||
.to_owned(),
|
||||
Sort::Bool,
|
||||
),
|
||||
Op::Not,
|
||||
Op::Implies,
|
||||
Op::BoolNaryOp(BoolNaryOp::Or),
|
||||
Op::BoolNaryOp(BoolNaryOp::And),
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor),
|
||||
];
|
||||
let o = match self.0 {
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
2 => ops[2..3].choose(rng), // arity 1
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}
|
||||
.unwrap()
|
||||
.clone();
|
||||
// Now, self.0 is a least arity+1
|
||||
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
|
||||
let excess = self.0 - 1 - a;
|
||||
let ns = Sum(a, excess).sample(rng);
|
||||
let subterms = ns
|
||||
.into_iter()
|
||||
.map(|n| BoolDist(n + 1).sample(rng))
|
||||
.collect::<Vec<_>>();
|
||||
term(o, subterms)
|
||||
}
|
||||
}
|
||||
|
||||
pub type TermMap<T> = hashconsing::coll::HConMap<Term, T>;
|
||||
pub type TermSet = hashconsing::coll::HConSet<Term>;
|
||||
|
||||
pub struct PostOrderIter {
|
||||
// (children stacked, term)
|
||||
// (cs stacked, term)
|
||||
stack: Vec<(bool, Term)>,
|
||||
visited: TermSet,
|
||||
}
|
||||
@@ -959,7 +891,7 @@ impl std::iter::Iterator for PostOrderIter {
|
||||
self.stack.pop();
|
||||
} else if !children_pushed {
|
||||
self.stack.last_mut().unwrap().0 = true;
|
||||
let cs = self.stack.last().unwrap().1.children.clone();
|
||||
let cs = self.stack.last().unwrap().1.cs.clone();
|
||||
self.stack.extend(cs.into_iter().map(|c| (false, c)));
|
||||
} else {
|
||||
break;
|
||||
@@ -118,32 +118,45 @@ impl ToR1cs {
|
||||
self.values
|
||||
.as_ref()
|
||||
.map(|vs| match vs.get(var).expect("missing value") {
|
||||
Value::BitVector(BitVector { uint, .. }) => uint.clone(),
|
||||
Value::BitVector(b) => b.uint().clone(),
|
||||
v => panic!("{} should be a bool, but is {:?}", var, v),
|
||||
})
|
||||
}
|
||||
|
||||
/// Given wire `x`, returns a vector of `n` wires which are the bits of `x`.
|
||||
/// Constrains `x` to fit in `n` (unsigned) bits.
|
||||
/// They *have not* been constrained to sum to `x`.
|
||||
/// They have values according the the (infinite) two's complement representation of `x`.
|
||||
/// The LSB is at index 0.
|
||||
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize) -> Vec<Lc> {
|
||||
fn decomp<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize) -> Vec<Lc> {
|
||||
let x_val = self.r1cs.eval(x);
|
||||
let bits = (0..n)
|
||||
(0..n)
|
||||
.map(|i| {
|
||||
self.fresh_bit(
|
||||
// We get the right repr here because of infinite two's complement.
|
||||
&format!("{}_b{}", d, i),
|
||||
x_val.as_ref().map(|x| Integer::from(x.get_bit(i as u32))),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let sum = bits.iter().enumerate().fold(self.r1cs.zero(), |s, (i, b)| {
|
||||
s + &(b.clone() * &Integer::from(2).pow(i as u32))
|
||||
});
|
||||
self.r1cs
|
||||
.constraint(self.r1cs.zero(), self.r1cs.zero(), sum - x);
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Given wire `x`, returns a vector of `n` wires which are the bits of `x`.
|
||||
/// Constrains `x` to fit in `n` (`signed`) bits.
|
||||
/// The LSB is at index 0.
|
||||
fn bitify<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Vec<Lc> {
|
||||
let bits = self.decomp(d, x, n);
|
||||
let sum = self.debitify(bits.iter().cloned(), signed);
|
||||
self.assert_zero(sum - x);
|
||||
bits
|
||||
}
|
||||
|
||||
/// Given wire `x`, returns whether `x` fits in `n` `signed` bits.
|
||||
fn fits_in_bits<D: Display + ?Sized>(&mut self, d: &D, x: &Lc, n: usize, signed: bool) -> Lc {
|
||||
let bits = self.decomp(d, x, n);
|
||||
let sum = self.debitify(bits.iter().cloned(), signed);
|
||||
self.are_equal(sum, x)
|
||||
}
|
||||
|
||||
/// Given a sequence of `bits`, returns a wire which represents their sum,
|
||||
/// `\sum_{i>0} b_i2^i`.
|
||||
///
|
||||
@@ -164,7 +177,7 @@ impl ToR1cs {
|
||||
fn nary_xor<I: ExactSizeIterator<Item = Lc>>(&mut self, xs: I) -> Lc {
|
||||
let n = xs.len();
|
||||
let sum = xs.into_iter().fold(self.r1cs.zero(), |s, i| s + &i);
|
||||
let sum_bits = self.bitify("sum", &sum, bitsize(n));
|
||||
let sum_bits = self.bitify("sum", &sum, bitsize(n), false);
|
||||
assert!(n > 0);
|
||||
assert!(self.r1cs.modulus() > &n);
|
||||
sum_bits.into_iter().next().unwrap() // safe b/c assert
|
||||
@@ -262,46 +275,116 @@ impl ToR1cs {
|
||||
v
|
||||
}
|
||||
Op::Const(Value::Bool(b)) => self.r1cs.zero() + *b as isize,
|
||||
Op::Eq => self.embed_eq(&c.children[0], &c.children[1]),
|
||||
Op::Eq => self.embed_eq(&c.cs[0], &c.cs[1]),
|
||||
Op::Ite => {
|
||||
let a = self.get_bool(&c.children[0]).clone();
|
||||
let b = self.get_bool(&c.children[1]).clone();
|
||||
let c = self.get_bool(&c.children[2]).clone();
|
||||
let a = self.get_bool(&c.cs[0]).clone();
|
||||
let b = self.get_bool(&c.cs[1]).clone();
|
||||
let c = self.get_bool(&c.cs[2]).clone();
|
||||
self.ite(a, b, &c)
|
||||
}
|
||||
Op::Not => {
|
||||
let a = self.get_bool(&c.children[0]);
|
||||
let a = self.get_bool(&c.cs[0]);
|
||||
self.bool_not(a)
|
||||
}
|
||||
Op::Implies => {
|
||||
let a = self.get_bool(&c.children[0]).clone();
|
||||
let b = self.get_bool(&c.children[1]).clone();
|
||||
let a = self.get_bool(&c.cs[0]).clone();
|
||||
let b = self.get_bool(&c.cs[1]).clone();
|
||||
let not_a = self.bool_not(&a);
|
||||
self.nary_or(vec![not_a, b].into_iter())
|
||||
}
|
||||
Op::BoolNaryOp(o) => {
|
||||
let args = c
|
||||
.children
|
||||
.iter()
|
||||
.map(|c| self.get_bool(c).clone())
|
||||
.collect::<Vec<_>>();
|
||||
let args =
|
||||
c.cs.iter()
|
||||
.map(|c| self.get_bool(c).clone())
|
||||
.collect::<Vec<_>>();
|
||||
match o {
|
||||
BoolNaryOp::Or => self.nary_or(args.into_iter()),
|
||||
BoolNaryOp::And => self.nary_and(args.into_iter()),
|
||||
BoolNaryOp::Xor => self.nary_xor(args.into_iter()),
|
||||
}
|
||||
}
|
||||
Op::BvBit(i) => {
|
||||
let a = self.get_bv_bits(&c.cs[0]);
|
||||
a[*i].clone()
|
||||
}
|
||||
Op::BvBinPred(o) => {
|
||||
let n = check(c.cs[0].clone()).unwrap().as_bv();
|
||||
use BvBinPred::*;
|
||||
match o {
|
||||
Sge => self.bv_cmp(n, true, false, &c.cs[0], &c.cs[1]),
|
||||
Sgt => self.bv_cmp(n, true, true, &c.cs[0], &c.cs[1]),
|
||||
Uge => self.bv_cmp(n, false, false, &c.cs[0], &c.cs[1]),
|
||||
Ugt => self.bv_cmp(n, false, true, &c.cs[0], &c.cs[1]),
|
||||
Sle => self.bv_cmp(n, true, false, &c.cs[1], &c.cs[0]),
|
||||
Slt => self.bv_cmp(n, true, true, &c.cs[1], &c.cs[0]),
|
||||
Ule => self.bv_cmp(n, false, false, &c.cs[1], &c.cs[0]),
|
||||
Ult => self.bv_cmp(n, false, true, &c.cs[1], &c.cs[0]),
|
||||
}
|
||||
}
|
||||
|
||||
_ => panic!("Non-boolean in embed_bool: {:?}", c),
|
||||
};
|
||||
self.bools.insert(c.clone(), lc);
|
||||
}
|
||||
// self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| {
|
||||
// println!("-> {}", v);
|
||||
// });
|
||||
// self.r1cs.eval(self.bools.get(&c).unwrap()).map(|v| {
|
||||
// println!("-> {}", v);
|
||||
// });
|
||||
self.bools.get(&c).unwrap()
|
||||
}
|
||||
|
||||
/// Returns whether `a - b` fits in `size` non-negative bits.
|
||||
/// i.e. is in `{0, 1, ..., 2^n-1}`.
|
||||
fn bv_ge(&mut self, a: Lc, b: &Lc, size: usize) -> Lc {
|
||||
self.fits_in_bits("ge", &(a - b), size, false)
|
||||
}
|
||||
|
||||
/// Returns whether `a` is (`strict`ly) (`signed`ly) greater than `b`.
|
||||
/// Assumes they are each `w`-bit bit-vectors.
|
||||
fn bv_cmp(&mut self, w: usize, signed: bool, strict: bool, a: &Term, b: &Term) -> Lc {
|
||||
let a = if signed {
|
||||
self.get_bv_signed_int(a)
|
||||
} else {
|
||||
self.get_bv_uint(a).clone()
|
||||
};
|
||||
let b = if signed {
|
||||
self.get_bv_signed_int(b)
|
||||
} else {
|
||||
self.get_bv_uint(b).clone()
|
||||
};
|
||||
// Use the fact: a > b <=> a - 1 >= b
|
||||
self.bv_ge(if strict { a - 1 } else { a }, &b, w)
|
||||
}
|
||||
|
||||
/// Shift `x` left by `2^y`, if bit-valued `c` is true.
|
||||
fn const_pow_shift_bv(&mut self, x: &Lc, y: usize, c: Lc) -> Lc {
|
||||
self.ite(c, x.clone() * (1 << y), x)
|
||||
}
|
||||
|
||||
/// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`.
|
||||
/// Returns an *oversized* number
|
||||
fn shift_bv(&mut self, x: Lc, y: Vec<Lc>, ext_bit: Option<Lc>) -> Lc {
|
||||
if let Some(b) = ext_bit {
|
||||
let left = self.shift_bv(x, y.clone(), None);
|
||||
let right = self.shift_bv(b.clone(), y, None) - 1;
|
||||
left + &self.mul(b, right)
|
||||
} else {
|
||||
y.into_iter()
|
||||
.enumerate()
|
||||
.fold(x, |x, (i, yi)| self.const_pow_shift_bv(&x, i, yi))
|
||||
}
|
||||
}
|
||||
|
||||
/// Shift `x` left by `y`, filling the blank spots with bit-valued `ext_bit`.
|
||||
/// Returns a bit sequence.
|
||||
fn shift_bv_bits(&mut self, x: Lc, y: Vec<Lc>, ext_bit: Option<Lc>, n: usize) -> Vec<Lc> {
|
||||
let s = self.shift_bv(x, y, ext_bit);
|
||||
let mut bits = self.bitify("shift", &s, 2 * n - 1, false);
|
||||
bits.truncate(n);
|
||||
bits
|
||||
}
|
||||
|
||||
fn embed_bv(&mut self, bv: Term) {
|
||||
//println!("Embed: {}", bv);
|
||||
if let Sort::BitVector(n) = check(bv.clone()).unwrap() {
|
||||
if !self.bvs.contains_key(&bv) {
|
||||
match &bv.op {
|
||||
@@ -313,47 +396,53 @@ impl ToR1cs {
|
||||
self.get_bv_bits(&bv);
|
||||
}
|
||||
}
|
||||
Op::Const(Value::BitVector(BitVector { uint, width })) => {
|
||||
let bit_lcs = (0..*width)
|
||||
.map(|i| self.r1cs.zero() + uint.get_bit(i as u32) as isize)
|
||||
Op::Const(Value::BitVector(b)) => {
|
||||
let bit_lcs = (0..b.width())
|
||||
.map(|i| self.r1cs.zero() + b.uint().get_bit(i as u32) as isize)
|
||||
.collect();
|
||||
self.set_bv_bits(bv, bit_lcs);
|
||||
}
|
||||
Op::Ite => {
|
||||
let c = self.get_bool(&bv.children[0]).clone();
|
||||
let t = self.get_bv_uint(&bv.children[1]).clone();
|
||||
let f = self.get_bv_uint(&bv.children[2]).clone();
|
||||
let c = self.get_bool(&bv.cs[0]).clone();
|
||||
let t = self.get_bv_uint(&bv.cs[1]).clone();
|
||||
let f = self.get_bv_uint(&bv.cs[2]).clone();
|
||||
let ite = self.ite(c, t, &f);
|
||||
self.set_bv_uint(bv, ite, n);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Not) => {
|
||||
let bits = self.get_bv_bits(&bv.children[0]).clone();
|
||||
let bits = self.get_bv_bits(&bv.cs[0]).clone();
|
||||
let not_bits = bits.iter().map(|bit| self.bool_not(bit)).collect();
|
||||
self.set_bv_bits(bv, not_bits);
|
||||
}
|
||||
Op::BvUnOp(BvUnOp::Neg) => {
|
||||
let x = self.r1cs.zero() + &Integer::from(2).pow(n as u32)
|
||||
- self.get_bv_uint(&bv.children[0]);
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
// Wrong for x == 0
|
||||
let almost_neg_x = self.r1cs.zero() + &Integer::from(2).pow(n as u32) - &x;
|
||||
let is_zero = self.is_zero(x);
|
||||
let neg_x = self.ite(is_zero, self.r1cs.zero(), &almost_neg_x);
|
||||
self.set_bv_uint(bv, neg_x, n);
|
||||
}
|
||||
Op::BvUext(_) => {
|
||||
// TODO: carry over bits if possible.
|
||||
let x = self.get_bv_uint(&bv.cs[0]).clone();
|
||||
self.set_bv_uint(bv, x, n);
|
||||
}
|
||||
Op::BvUext(extra_n) => {
|
||||
// TODO: carry over bits if possible.
|
||||
let x = self.get_bv_uint(&bv.children[0]).clone();
|
||||
let new_n = n + extra_n;
|
||||
self.set_bv_uint(bv, x, new_n);
|
||||
}
|
||||
Op::BvSext(extra_n) => {
|
||||
let mut bits = self.get_bv_bits(&bv.children[0]).clone().into_iter();
|
||||
let mut bits = self.get_bv_bits(&bv.cs[0]).clone().into_iter().rev();
|
||||
let ext_bits =
|
||||
std::iter::repeat(bits.next().expect("sign ext empty").clone())
|
||||
.take(extra_n + 1);
|
||||
|
||||
self.set_bv_bits(bv, ext_bits.chain(bits).collect());
|
||||
self.set_bv_bits(bv, bits.rev().chain(ext_bits).collect());
|
||||
}
|
||||
Op::BoolToBv => {
|
||||
let b = self.get_bool(&bv.cs[0]).clone();
|
||||
self.set_bv_bits(bv, vec![b]);
|
||||
}
|
||||
Op::BvNaryOp(o) => match o {
|
||||
BvNaryOp::Xor | BvNaryOp::Or | BvNaryOp::And => {
|
||||
let mut bits_by_bv = bv
|
||||
.children
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| self.get_bv_bits(c).clone())
|
||||
.collect::<Vec<_>>();
|
||||
@@ -363,6 +452,7 @@ impl ToR1cs {
|
||||
bits_by_bv.iter_mut().map(|bv| bv.pop().unwrap()).collect(),
|
||||
);
|
||||
}
|
||||
bits_bv_idx.reverse();
|
||||
let f = |v: Vec<Lc>| match o {
|
||||
BvNaryOp::And => self.nary_and(v.into_iter()),
|
||||
BvNaryOp::Or => self.nary_or(v.into_iter()),
|
||||
@@ -375,7 +465,7 @@ impl ToR1cs {
|
||||
BvNaryOp::Add | BvNaryOp::Mul => {
|
||||
let f_width = self.r1cs.modulus().significant_bits() as usize - 1;
|
||||
let values = bv
|
||||
.children
|
||||
.cs
|
||||
.iter()
|
||||
.map(|c| self.get_bv_uint(c).clone())
|
||||
.collect::<Vec<_>>();
|
||||
@@ -383,21 +473,21 @@ impl ToR1cs {
|
||||
BvNaryOp::Add => {
|
||||
let sum =
|
||||
values.into_iter().fold(self.r1cs.zero(), |s, v| s + &v);
|
||||
let extra_width = bitsize(bv.children.len()) - 1;
|
||||
let extra_width = bitsize(bv.cs.len().saturating_sub(1));
|
||||
(sum, n + extra_width)
|
||||
}
|
||||
BvNaryOp::Mul => {
|
||||
if bv.children.len() * n < f_width {
|
||||
if bv.cs.len() * n < f_width {
|
||||
let z = self.r1cs.zero() + 1;
|
||||
(
|
||||
values.into_iter().fold(z, |acc, v| self.mul(acc, v)),
|
||||
bv.children.len() * n,
|
||||
bv.cs.len() * n,
|
||||
)
|
||||
} else {
|
||||
let z = self.r1cs.zero() + 1;
|
||||
let p = values.into_iter().fold(z, |acc, v| {
|
||||
let p = self.mul(acc, v);
|
||||
let mut bits = self.bitify("binMul", &p, 2 * n);
|
||||
let mut bits = self.bitify("binMul", &p, 2 * n, false);
|
||||
bits.truncate(n);
|
||||
self.debitify(bits.into_iter(), false)
|
||||
});
|
||||
@@ -406,17 +496,94 @@ impl ToR1cs {
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut bits = self.bitify("arith", &res, width);
|
||||
let mut bits = self.bitify("arith", &res, width, false);
|
||||
bits.truncate(n);
|
||||
self.set_bv_bits(bv, bits);
|
||||
}
|
||||
},
|
||||
Op::BvBinOp(o) => unimplemented!(),
|
||||
_ => panic!("Non-boolean in embed_bool: {:?}", bv),
|
||||
Op::BvBinOp(o) => {
|
||||
let a = self.get_bv_uint(&bv.cs[0]);
|
||||
let b = self.get_bv_uint(&bv.cs[1]);
|
||||
match o {
|
||||
BvBinOp::Sub => {
|
||||
let sum = a.clone() + &(Integer::from(1) << n as u32) - b;
|
||||
let mut bits = self.bitify("sub", &sum, n + 1, false);
|
||||
bits.truncate(n);
|
||||
self.set_bv_bits(bv, bits);
|
||||
}
|
||||
BvBinOp::Udiv | BvBinOp::Urem => {
|
||||
let b = b.clone();
|
||||
let a = a.clone();
|
||||
let is_zero = self.is_zero(b.clone());
|
||||
let (q_v, r_v) = self
|
||||
.r1cs
|
||||
.eval(&a)
|
||||
.and_then(|a| {
|
||||
self.r1cs.eval(&b).map(|b| {
|
||||
if b == 0 {
|
||||
((Integer::from(1) << n as u32) - 1, a)
|
||||
} else {
|
||||
(a.clone() / &b, a % b)
|
||||
}
|
||||
})
|
||||
})
|
||||
.map(|(a, b)| (Some(a), Some(b)))
|
||||
.unwrap_or((None, None));
|
||||
let q = self.fresh_var("div_q", q_v);
|
||||
let r = self.fresh_var("div_q", r_v);
|
||||
let qb = self.bitify("div_q", &q, n, false);
|
||||
let rb = self.bitify("div_r", &q, n, false);
|
||||
self.r1cs.constraint(q.clone(), b.clone(), a - &r);
|
||||
let is_gt = self.bv_ge(b - 1, &r, n);
|
||||
let is_not_ge = self.bool_not(&is_gt);
|
||||
let is_not_zero = self.bool_not(&is_zero);
|
||||
self.r1cs
|
||||
.constraint(is_not_ge, is_not_zero, self.r1cs.zero());
|
||||
let bits = match o {
|
||||
BvBinOp::Udiv => qb,
|
||||
BvBinOp::Urem => rb,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.set_bv_bits(bv, bits);
|
||||
}
|
||||
// Shift cases
|
||||
_ => {
|
||||
let r = b.clone();
|
||||
let a = a.clone();
|
||||
let b = bitsize(n - 1);
|
||||
assert!(1 << b == n);
|
||||
let mut rb = self.get_bv_bits(&bv.cs[1]).clone();
|
||||
rb.truncate(b);
|
||||
let sum = self.debitify(rb.clone().into_iter(), false);
|
||||
self.assert_zero(sum - &r);
|
||||
let bits = match o {
|
||||
BvBinOp::Shl => self.shift_bv_bits(a, rb, None, n),
|
||||
BvBinOp::Lshr | BvBinOp::Ashr => {
|
||||
let mut lb = self.get_bv_bits(&bv.cs[0]).clone();
|
||||
lb.reverse();
|
||||
let ext_bit = match o {
|
||||
BvBinOp::Ashr => Some(lb.first().unwrap().clone()),
|
||||
_ => None,
|
||||
};
|
||||
let l = self.debitify(lb.into_iter(), false);
|
||||
let mut bits = self.shift_bv_bits(l, rb, ext_bit, n);
|
||||
bits.reverse();
|
||||
bits
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.set_bv_bits(bv, bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => panic!("Non-bv in embed_bv: {}", bv),
|
||||
}
|
||||
}
|
||||
// self.r1cs.eval(self.get_bv_uint(&bv2)).map(|v| {
|
||||
// println!("-> {:b}", v);
|
||||
// });
|
||||
} else {
|
||||
panic!("{:?} is not a bit-vector in embed_bv", bv);
|
||||
panic!("{} is not a bit-vector in embed_bv", bv);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -455,21 +622,28 @@ impl ToR1cs {
|
||||
&self.bvs.get(t).expect("Missing term").uint
|
||||
}
|
||||
|
||||
fn get_bv_signed_int(&mut self, t: &Term) -> Lc {
|
||||
let bits = self.get_bv_bits(t).clone();
|
||||
self.debitify(bits.into_iter(), true)
|
||||
}
|
||||
|
||||
fn get_bv_bits(&mut self, t: &Term) -> &Vec<Lc> {
|
||||
let mut bvs = std::mem::take(&mut self.bvs);
|
||||
let entry = bvs.get_mut(t).expect("Missing bit-vec");
|
||||
if entry.bits.len() == 0 {
|
||||
entry.bits = self.bitify("getbits", &entry.uint, entry.width);
|
||||
entry.bits = self.bitify("getbits", &entry.uint, entry.width, false);
|
||||
}
|
||||
self.bvs = bvs;
|
||||
&self.bvs.get(t).unwrap().bits
|
||||
}
|
||||
|
||||
fn assert_zero(&mut self, x: Lc) {
|
||||
self.r1cs.constraint(self.r1cs.zero(), self.r1cs.zero(), x);
|
||||
}
|
||||
fn assert(&mut self, t: Term) {
|
||||
self.embed(t.clone());
|
||||
let lc = self.get_bool(&t).clone();
|
||||
self.r1cs
|
||||
.constraint(self.r1cs.zero(), self.r1cs.zero(), lc - 1);
|
||||
self.assert_zero(lc - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -494,6 +668,7 @@ fn bitsize(mut n: usize) -> usize {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::ir::term::dist::*;
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
use quickcheck_macros::quickcheck;
|
||||
use rand::distributions::Distribution;
|
||||
@@ -526,7 +701,7 @@ mod test {
|
||||
impl Arbitrary for PureBool {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g));
|
||||
let t = BoolDist(g.size()).sample(&mut rng);
|
||||
let t = PureBoolDist(g.size()).sample(&mut rng);
|
||||
let values: HashMap<String, Value> = PostOrderIter::new(t.clone())
|
||||
.filter_map(|c| {
|
||||
if let Op::Var(n, _) = &c.op {
|
||||
@@ -553,7 +728,7 @@ mod test {
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn random_bool(PureBool(t, values): PureBool) {
|
||||
fn random_pure_bool(PureBool(t, values): PureBool) {
|
||||
let t = if eval(&t, &values).as_bool() {
|
||||
t
|
||||
} else {
|
||||
@@ -567,4 +742,76 @@ mod test {
|
||||
let r1cs = to_r1cs(cs, Integer::from(1014088787));
|
||||
r1cs.check_all();
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Bool(Term, HashMap<String, Value>);
|
||||
|
||||
impl Arbitrary for Bool {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g));
|
||||
let d = FixedSizeDist {
|
||||
bv_width: 8,
|
||||
size: g.size(),
|
||||
sort: Sort::Bool,
|
||||
};
|
||||
let t = d.sample(&mut rng);
|
||||
let values: HashMap<String, Value> = PostOrderIter::new(t.clone())
|
||||
.filter_map(|c| match &c.op {
|
||||
Op::Var(n, Sort::Bool) => Some((n.clone(), Value::Bool(bool::arbitrary(g)))),
|
||||
Op::Var(n, Sort::BitVector(w)) => Some((
|
||||
n.clone(),
|
||||
Value::BitVector(UniformBitVector(*w).sample(&mut rng)),
|
||||
)),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
Bool(t, values)
|
||||
}
|
||||
|
||||
fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
|
||||
let vs = self.1.clone();
|
||||
let ts = PostOrderIter::new(self.0.clone()).collect::<Vec<_>>();
|
||||
|
||||
Box::new(
|
||||
ts.into_iter()
|
||||
.rev()
|
||||
.skip(1)
|
||||
.map(move |t| Bool(t, vs.clone())),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[quickcheck]
|
||||
fn random_bool(Bool(t, values): Bool) {
|
||||
let v = eval(&t, &values);
|
||||
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
|
||||
let cs = Constraints {
|
||||
public_inputs: HashSet::new(),
|
||||
values: Some(values),
|
||||
assertions: vec![t],
|
||||
};
|
||||
let r1cs = to_r1cs(cs, Integer::from(1014088787));
|
||||
r1cs.check_all();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eq_test() {
|
||||
let cs = Constraints {
|
||||
public_inputs: vec!["a"].into_iter().map(|a| a.to_owned()).collect(),
|
||||
values: Some(
|
||||
vec![(
|
||||
"b".to_owned(),
|
||||
Value::BitVector(BitVector::new(Integer::from(152), 8)),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
assertions: vec![
|
||||
term![Op::Not; term![Op::Eq; leaf_term(Op::Const(Value::BitVector(BitVector::new(Integer::from(0b10110), 8)))),
|
||||
term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]],
|
||||
],
|
||||
};
|
||||
let r1cs = to_r1cs(cs, Integer::from(1014088787));
|
||||
r1cs.check_all();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
|
||||
#[macro_use]
|
||||
pub mod ir;
|
||||
pub mod util;
|
||||
pub mod target;
|
||||
pub mod util;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use rug::Integer;
|
||||
use rug::ops::{RemRounding, RemRoundingAssign};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Display;
|
||||
@@ -34,13 +35,13 @@ impl std::ops::AddAssign<&Lc> for Lc {
|
||||
fn add_assign(&mut self, other: &Lc) {
|
||||
assert_eq!(&self.modulus, &other.modulus);
|
||||
self.constant += &other.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (i, v) in &other.monomials {
|
||||
self.monomials
|
||||
.entry(*i)
|
||||
.and_modify(|u| {
|
||||
*u += v;
|
||||
*u %= &*other.modulus;
|
||||
u.rem_floor_assign(&*other.modulus);
|
||||
})
|
||||
.or_insert_with(|| v.clone());
|
||||
}
|
||||
@@ -58,7 +59,7 @@ impl std::ops::Add<&Integer> for Lc {
|
||||
impl std::ops::AddAssign<&Integer> for Lc {
|
||||
fn add_assign(&mut self, other: &Integer) {
|
||||
self.constant += other;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +74,7 @@ impl std::ops::Add<isize> for Lc {
|
||||
impl std::ops::AddAssign<isize> for Lc {
|
||||
fn add_assign(&mut self, other: isize) {
|
||||
self.constant += Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,13 +90,13 @@ impl std::ops::SubAssign<&Lc> for Lc {
|
||||
fn sub_assign(&mut self, other: &Lc) {
|
||||
assert_eq!(&self.modulus, &other.modulus);
|
||||
self.constant -= &other.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (i, v) in &other.monomials {
|
||||
self.monomials
|
||||
.entry(*i)
|
||||
.and_modify(|u| {
|
||||
*u -= v;
|
||||
*u %= &*other.modulus;
|
||||
u.rem_floor_assign(&*other.modulus);
|
||||
})
|
||||
.or_insert_with(|| -v.clone());
|
||||
}
|
||||
@@ -113,7 +114,7 @@ impl std::ops::Sub<&Integer> for Lc {
|
||||
impl std::ops::SubAssign<&Integer> for Lc {
|
||||
fn sub_assign(&mut self, other: &Integer) {
|
||||
self.constant -= other;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +129,7 @@ impl std::ops::Sub<isize> for Lc {
|
||||
impl std::ops::SubAssign<isize> for Lc {
|
||||
fn sub_assign(&mut self, other: isize) {
|
||||
self.constant -= Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,10 +137,10 @@ impl std::ops::Neg for Lc {
|
||||
type Output = Lc;
|
||||
fn neg(mut self) -> Lc {
|
||||
self.constant = -self.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= Integer::from(-1);
|
||||
*v %= &*self.modulus;
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
self
|
||||
}
|
||||
@@ -156,10 +157,10 @@ impl std::ops::Mul<&Integer> for Lc {
|
||||
impl std::ops::MulAssign<&Integer> for Lc {
|
||||
fn mul_assign(&mut self, other: &Integer) {
|
||||
self.constant *= other;
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= other;
|
||||
*v %= &*self.modulus;
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,10 +176,10 @@ impl std::ops::Mul<isize> for Lc {
|
||||
impl std::ops::MulAssign<isize> for Lc {
|
||||
fn mul_assign(&mut self, other: isize) {
|
||||
self.constant *= Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
self.constant.rem_floor_assign(&*self.modulus);
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= Integer::from(other);
|
||||
*v %= &*self.modulus;
|
||||
v.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,7 +254,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
let sign = |i: &Integer| if i < &half_m { "+" } else { "-" };
|
||||
let format_i = |i: &Integer| format!("{}{}", sign(i), abs(i));
|
||||
|
||||
s.extend(format_i(&a.constant).chars());
|
||||
s.extend(format_i(&Integer::from(&a.constant)).chars());
|
||||
for (idx, coeff) in &a.monomials {
|
||||
s.extend(
|
||||
format!(
|
||||
@@ -271,7 +272,8 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
let av = self.eval(a).unwrap();
|
||||
let bv = self.eval(b).unwrap();
|
||||
let cv = self.eval(c).unwrap();
|
||||
if &(av.clone() * &bv % &*self.modulus) != &cv {
|
||||
dbg!(&av, &bv, &cv, &self.modulus);
|
||||
if &((av.clone() * &bv).rem_floor(&*self.modulus)) != &cv {
|
||||
panic!(
|
||||
"Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})",
|
||||
self.format_lc(a),
|
||||
@@ -292,7 +294,7 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
.expect("Missing value in R1cs::eval")
|
||||
.clone();
|
||||
acc += val * coeff;
|
||||
acc %= &*self.modulus;
|
||||
acc.rem_floor_assign(&*self.modulus);
|
||||
}
|
||||
acc
|
||||
})
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
|
||||
Reference in New Issue
Block a user