Started IR->R1cs

This commit is contained in:
Alex Ozdemir
2021-02-04 02:04:57 -08:00
parent fe7280b10e
commit 711f97503a
4 changed files with 291 additions and 40 deletions

View File

@@ -1 +1,2 @@
pub mod term;
pub mod to_r1cs;

View File

@@ -10,7 +10,7 @@ pub enum Op {
Eq,
Let(String),
Var(String, Sort),
Const(Const),
Const(Value),
BvBinOp(BvBinOp),
BvBinPred(BvBinPred),
@@ -210,7 +210,7 @@ pub struct BitVector {
}
#[derive(Clone, PartialEq, Debug)]
pub enum Const {
pub enum Value {
BitVector(BitVector),
F32(f32),
F64(f64),
@@ -219,16 +219,16 @@ pub enum Const {
Bool(bool),
}
impl std::cmp::Eq for Const {}
impl std::hash::Hash for Const {
impl std::cmp::Eq for Value {}
impl std::hash::Hash for Value {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Const::BitVector(bv) => bv.hash(state),
Const::F32(bv) => bv.to_bits().hash(state),
Const::F64(bv) => bv.to_bits().hash(state),
Const::Int(bv) => bv.hash(state),
Const::Field(bv) => bv.hash(state),
Const::Bool(bv) => bv.hash(state),
Value::BitVector(bv) => bv.hash(state),
Value::F32(bv) => bv.to_bits().hash(state),
Value::F64(bv) => bv.to_bits().hash(state),
Value::Int(bv) => bv.hash(state),
Value::Field(bv) => bv.hash(state),
Value::Bool(bv) => bv.hash(state),
}
}
}
@@ -256,15 +256,15 @@ lazy_static! {
static ref TERM_TYPES: RwLock<HashMap<TTerm, Sort>> = RwLock::new(HashMap::new());
}
impl Const {
impl Value {
fn sort(&self) -> Sort {
match &self {
Const::Bool(_) => Sort::Bool,
Const::Field(f) => Sort::Field(f.modulus.clone()),
Const::Int(_) => Sort::Int,
Const::F64(_) => Sort::F64,
Const::F32(_) => Sort::F32,
Const::BitVector(b) => Sort::BitVector(b.width),
Value::Bool(_) => Sort::Bool,
Value::Field(f) => Sort::Field(f.modulus.clone()),
Value::Int(_) => Sort::Int,
Value::F64(_) => Sort::F64,
Value::F32(_) => Sort::F32,
Value::BitVector(b) => Sort::BitVector(b.width),
}
}
}
@@ -529,7 +529,6 @@ macro_rules! term {
// All subterms are booleans.
pub struct BoolDist(pub usize);
// A distribution of n usizes that sum to this value.
// (n, sum)
pub struct FixedAdditionPartition(usize, usize);
@@ -560,29 +559,71 @@ impl rand::distributions::Distribution<Term> for BoolDist {
use rand::distributions::Alphanumeric;
use rand::seq::SliceRandom;
let ops = &[
Op::Const(Const::Bool(rng.gen())),
Op::Const(Value::Bool(rng.gen())),
Op::Var(Alphanumeric.sample(rng).to_string(), Sort::Bool),
Op::BoolUnOp(BoolUnOp::Not),
Op::BoolBinOp(BoolBinOp::Implies),
Op::BoolNaryOp(BoolNaryOp::Or),
Op::BoolNaryOp(BoolNaryOp::And),
Op::BoolNaryOp(BoolNaryOp::Xor)
Op::BoolNaryOp(BoolNaryOp::Xor),
];
let o = match self.0 {
1 => ops[..2].choose(rng), // arity 0
1 => ops[..2].choose(rng), // arity 0
2 => ops[2..3].choose(rng), // arity 1
_ => ops[2..].choose(rng), // others
}.unwrap().clone();
_ => ops[2..].choose(rng), // others
}
.unwrap()
.clone();
// Now, self.0 is a least arity+1
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
let excess = self.0-1-a;
let excess = self.0 - 1 - a;
let ns = FixedAdditionPartition(a, excess).sample(rng);
let subterms = ns.into_iter().map(|n| BoolDist(n+1).sample(rng)).collect::<Vec<_>>();
let subterms = ns
.into_iter()
.map(|n| BoolDist(n + 1).sample(rng))
.collect::<Vec<_>>();
term(o, subterms)
}
}
pub type TermMap<T> = hashconsing::coll::HConMap<Term, T>;
pub type TermSet = hashconsing::coll::HConSet<Term>;
pub struct PostOrderIter {
// (children stacked, term)
stack: Vec<(bool, Term)>,
visited: TermSet,
}
impl PostOrderIter {
pub fn new(t: Term) -> Self {
Self {
stack: vec![(false, t)],
visited: TermSet::new(),
}
}
}
impl std::iter::Iterator for PostOrderIter {
type Item = Term;
fn next(&mut self) -> Option<Term> {
while let Some((children_pushed, t)) = self.stack.last() {
if self.visited.contains(&t) {
self.stack.pop();
} else if !children_pushed {
self.stack.last_mut().unwrap().0 = true;
let cs = self.stack.last().unwrap().1.children.clone();
self.stack.extend(cs.into_iter().map(|c| (false, c)));
} else {
break;
}
}
self.stack.pop().map(|(_, t)| {
self.visited.insert(t.clone());
t
})
}
}
#[cfg(test)]
mod test {
@@ -601,21 +642,43 @@ mod test {
mod type_ {
use super::*;
#[test]
fn vars() {
let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool));
assert_eq!(check(v), Ok(Sort::Bool));
fn t() -> Term {
let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4)));
assert_eq!(check(v.clone()), Ok(Sort::BitVector(4)));
let v = term![
term![
Op::BvBit(4);
term![
Op::BvConcat;
v,
term![Op::BoolToBv; leaf_term(Op::Var("c".to_owned(), Sort::Bool))]
]
];
]
}
#[test]
fn vars() {
let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool));
assert_eq!(check(v), Ok(Sort::Bool));
let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4)));
assert_eq!(check(v.clone()), Ok(Sort::BitVector(4)));
let v = t();
assert_eq!(check(v), Ok(Sort::Bool));
}
#[test]
fn traversal() {
let tt = t();
assert_eq!(
vec![
Op::Var("c".to_owned(), Sort::Bool),
Op::BoolToBv,
Op::Var("b".to_owned(), Sort::BitVector(4)),
Op::BvConcat,
Op::BvBit(4),
],
PostOrderIter::new(tt)
.map(|t| t.op.clone())
.collect::<Vec<_>>()
);
}
}
}

67
src/ir/to_r1cs.rs Normal file
View File

@@ -0,0 +1,67 @@
use super::term::*;
use crate::target::r1cs::*;
use rug::Integer;
use std::collections::HashMap;
use std::fmt::Display;
struct ToR1cs {
r1cs: R1cs<String>,
bools: TermMap<Lc>,
values: Option<HashMap<String, Value>>,
next_idx: usize,
}
impl ToR1cs {
fn new(modulus: Integer, values: Option<HashMap<String, Value>>) -> Self {
Self {
r1cs: R1cs::new(modulus, values.is_some()),
bools: TermMap::new(),
values,
next_idx: 0,
}
}
fn fresh_var<D: Display>(&mut self, ctx: &D) -> Lc {
let n = format!("{}v{}", ctx, self.next_idx);
self.next_idx += 1;
self.r1cs.add_signal(n.clone());
self.r1cs.signal_lc(&n)
}
fn embed(&mut self, t: Term) {
for c in PostOrderIter::new(t) {
match check(c.clone()).expect("type-check error in embed") {
Sort::Bool => {
self.embed_bool(c);
}
s => panic!("Unsupported sort in embed: {:?}", s),
}
}
}
fn enforce_bit(&mut self, b: &Lc) {
self.r1cs
.constraint(b.clone(), b.clone() - 1, self.r1cs.zero());
}
fn embed_bool(&mut self, t: Term) -> &Lc {
debug_assert!(check(t.clone()) == Ok(Sort::Bool));
// TODO: skip if already embedded
for c in PostOrderIter::new(t.clone()) {
if !self.bools.contains_key(&c) {
let lc = match &c.op {
Op::Var(name, Sort::Bool) => {
let v = self.fresh_var(name);
self.enforce_bit(&v);
v
}
_ => panic!("Non-boolean in embed_bool: {:?}", c),
};
self.bools.insert(c, lc);
}
}
self.bools.get(&t).unwrap()
}
}

View File

@@ -24,8 +24,16 @@ pub struct Lc {
impl std::ops::Add<&Lc> for Lc {
type Output = Lc;
fn add(mut self, other: &Lc) -> Lc {
self += other;
self
}
}
impl std::ops::AddAssign<&Lc> for Lc {
fn add_assign(&mut self, other: &Lc) {
assert_eq!(&self.modulus, &other.modulus);
self.constant = (self.constant + &other.constant) % &*self.modulus;
self.constant += &other.constant;
self.constant %= &*self.modulus;
for (i, v) in &other.monomials {
self.monomials
.entry(*i)
@@ -35,14 +43,103 @@ impl std::ops::Add<&Lc> for Lc {
})
.or_insert_with(|| v.clone());
}
self
}
}
impl std::ops::Add<&Integer> for Lc {
type Output = Lc;
fn add(mut self, other: &Integer) -> Lc {
self.constant = (self.constant + other) % &*self.modulus;
self += other;
self
}
}
impl std::ops::AddAssign<&Integer> for Lc {
fn add_assign(&mut self, other: &Integer) {
self.constant += other;
self.constant %= &*self.modulus;
}
}
impl std::ops::Add<isize> for Lc {
type Output = Lc;
fn add(mut self, other: isize) -> Lc {
self += other;
self
}
}
impl std::ops::AddAssign<isize> for Lc {
fn add_assign(&mut self, other: isize) {
self.constant += Integer::from(other);
self.constant %= &*self.modulus;
}
}
impl std::ops::Sub<&Lc> for Lc {
type Output = Lc;
fn sub(mut self, other: &Lc) -> Lc {
self -= other;
self
}
}
impl std::ops::SubAssign<&Lc> for Lc {
fn sub_assign(&mut self, other: &Lc) {
assert_eq!(&self.modulus, &other.modulus);
self.constant -= &other.constant;
self.constant %= &*self.modulus;
for (i, v) in &other.monomials {
self.monomials
.entry(*i)
.and_modify(|u| {
*u -= v;
*u %= &*other.modulus;
})
.or_insert_with(|| v.clone());
}
}
}
impl std::ops::Sub<&Integer> for Lc {
type Output = Lc;
fn sub(mut self, other: &Integer) -> Lc {
self -= other;
self
}
}
impl std::ops::SubAssign<&Integer> for Lc {
fn sub_assign(&mut self, other: &Integer) {
self.constant -= other;
self.constant %= &*self.modulus;
}
}
impl std::ops::Sub<isize> for Lc {
type Output = Lc;
fn sub(mut self, other: isize) -> Lc {
self -= other;
self
}
}
impl std::ops::SubAssign<isize> for Lc {
fn sub_assign(&mut self, other: isize) {
self.constant -= Integer::from(other);
self.constant %= &*self.modulus;
}
}
impl std::ops::Neg for Lc {
type Output = Lc;
fn neg(mut self) -> Lc {
self.constant = -self.constant;
self.constant %= &*self.modulus;
for (_, v) in &mut self.monomials {
*v *= Integer::from(-1);
*v %= &*self.modulus;
}
self
}
}
@@ -50,15 +147,41 @@ impl std::ops::Add<&Integer> for Lc {
impl std::ops::Mul<&Integer> for Lc {
type Output = Lc;
fn mul(mut self, other: &Integer) -> Lc {
self.constant = (self.constant * other) % &*self.modulus;
self *= other;
self
}
}
impl std::ops::MulAssign<&Integer> for Lc {
fn mul_assign(&mut self, other: &Integer) {
self.constant *= other;
self.constant %= &*self.modulus;
for (_, v) in &mut self.monomials {
*v *= other;
*v %= &*self.modulus;
}
}
}
impl std::ops::Mul<isize> for Lc {
type Output = Lc;
fn mul(mut self, other: isize) -> Lc {
self *= other;
self
}
}
impl std::ops::MulAssign<isize> for Lc {
fn mul_assign(&mut self, other: isize) {
self.constant *= Integer::from(other);
self.constant %= &*self.modulus;
for (_, v) in &mut self.monomials {
*v *= Integer::from(other);
*v %= &*self.modulus;
}
}
}
impl<S: Clone + Hash + Eq> R1cs<S> {
pub fn new(modulus: Integer, values: bool) -> Self {
R1cs {
@@ -117,10 +240,7 @@ impl<S: Clone + Hash + Eq> R1cs<S> {
.signal_idxs
.get(s)
.expect("Missing signal in signal_lc");
self.values
.as_mut()
.expect("Missing values")
.insert(idx, v);
self.values.as_mut().expect("Missing values").insert(idx, v);
}
}