mirror of
https://github.com/circify/circ.git
synced 2026-01-10 22:27:55 -05:00
Started IR->R1cs
This commit is contained in:
@@ -1 +1,2 @@
|
||||
pub mod term;
|
||||
pub mod to_r1cs;
|
||||
|
||||
127
src/ir/term.rs
127
src/ir/term.rs
@@ -10,7 +10,7 @@ pub enum Op {
|
||||
Eq,
|
||||
Let(String),
|
||||
Var(String, Sort),
|
||||
Const(Const),
|
||||
Const(Value),
|
||||
|
||||
BvBinOp(BvBinOp),
|
||||
BvBinPred(BvBinPred),
|
||||
@@ -210,7 +210,7 @@ pub struct BitVector {
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Const {
|
||||
pub enum Value {
|
||||
BitVector(BitVector),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
@@ -219,16 +219,16 @@ pub enum Const {
|
||||
Bool(bool),
|
||||
}
|
||||
|
||||
impl std::cmp::Eq for Const {}
|
||||
impl std::hash::Hash for Const {
|
||||
impl std::cmp::Eq for Value {}
|
||||
impl std::hash::Hash for Value {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Const::BitVector(bv) => bv.hash(state),
|
||||
Const::F32(bv) => bv.to_bits().hash(state),
|
||||
Const::F64(bv) => bv.to_bits().hash(state),
|
||||
Const::Int(bv) => bv.hash(state),
|
||||
Const::Field(bv) => bv.hash(state),
|
||||
Const::Bool(bv) => bv.hash(state),
|
||||
Value::BitVector(bv) => bv.hash(state),
|
||||
Value::F32(bv) => bv.to_bits().hash(state),
|
||||
Value::F64(bv) => bv.to_bits().hash(state),
|
||||
Value::Int(bv) => bv.hash(state),
|
||||
Value::Field(bv) => bv.hash(state),
|
||||
Value::Bool(bv) => bv.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,15 +256,15 @@ lazy_static! {
|
||||
static ref TERM_TYPES: RwLock<HashMap<TTerm, Sort>> = RwLock::new(HashMap::new());
|
||||
}
|
||||
|
||||
impl Const {
|
||||
impl Value {
|
||||
fn sort(&self) -> Sort {
|
||||
match &self {
|
||||
Const::Bool(_) => Sort::Bool,
|
||||
Const::Field(f) => Sort::Field(f.modulus.clone()),
|
||||
Const::Int(_) => Sort::Int,
|
||||
Const::F64(_) => Sort::F64,
|
||||
Const::F32(_) => Sort::F32,
|
||||
Const::BitVector(b) => Sort::BitVector(b.width),
|
||||
Value::Bool(_) => Sort::Bool,
|
||||
Value::Field(f) => Sort::Field(f.modulus.clone()),
|
||||
Value::Int(_) => Sort::Int,
|
||||
Value::F64(_) => Sort::F64,
|
||||
Value::F32(_) => Sort::F32,
|
||||
Value::BitVector(b) => Sort::BitVector(b.width),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,7 +529,6 @@ macro_rules! term {
|
||||
// All subterms are booleans.
|
||||
pub struct BoolDist(pub usize);
|
||||
|
||||
|
||||
// A distribution of n usizes that sum to this value.
|
||||
// (n, sum)
|
||||
pub struct FixedAdditionPartition(usize, usize);
|
||||
@@ -560,29 +559,71 @@ impl rand::distributions::Distribution<Term> for BoolDist {
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::seq::SliceRandom;
|
||||
let ops = &[
|
||||
Op::Const(Const::Bool(rng.gen())),
|
||||
Op::Const(Value::Bool(rng.gen())),
|
||||
Op::Var(Alphanumeric.sample(rng).to_string(), Sort::Bool),
|
||||
Op::BoolUnOp(BoolUnOp::Not),
|
||||
Op::BoolBinOp(BoolBinOp::Implies),
|
||||
Op::BoolNaryOp(BoolNaryOp::Or),
|
||||
Op::BoolNaryOp(BoolNaryOp::And),
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor)
|
||||
Op::BoolNaryOp(BoolNaryOp::Xor),
|
||||
];
|
||||
let o = match self.0 {
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
1 => ops[..2].choose(rng), // arity 0
|
||||
2 => ops[2..3].choose(rng), // arity 1
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}.unwrap().clone();
|
||||
_ => ops[2..].choose(rng), // others
|
||||
}
|
||||
.unwrap()
|
||||
.clone();
|
||||
// Now, self.0 is a least arity+1
|
||||
let a = o.arity().unwrap_or_else(|| rng.gen_range(2..self.0));
|
||||
let excess = self.0-1-a;
|
||||
let excess = self.0 - 1 - a;
|
||||
let ns = FixedAdditionPartition(a, excess).sample(rng);
|
||||
let subterms = ns.into_iter().map(|n| BoolDist(n+1).sample(rng)).collect::<Vec<_>>();
|
||||
let subterms = ns
|
||||
.into_iter()
|
||||
.map(|n| BoolDist(n + 1).sample(rng))
|
||||
.collect::<Vec<_>>();
|
||||
term(o, subterms)
|
||||
}
|
||||
}
|
||||
|
||||
pub type TermMap<T> = hashconsing::coll::HConMap<Term, T>;
|
||||
pub type TermSet = hashconsing::coll::HConSet<Term>;
|
||||
|
||||
pub struct PostOrderIter {
|
||||
// (children stacked, term)
|
||||
stack: Vec<(bool, Term)>,
|
||||
visited: TermSet,
|
||||
}
|
||||
|
||||
impl PostOrderIter {
|
||||
pub fn new(t: Term) -> Self {
|
||||
Self {
|
||||
stack: vec![(false, t)],
|
||||
visited: TermSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::iter::Iterator for PostOrderIter {
|
||||
type Item = Term;
|
||||
fn next(&mut self) -> Option<Term> {
|
||||
while let Some((children_pushed, t)) = self.stack.last() {
|
||||
if self.visited.contains(&t) {
|
||||
self.stack.pop();
|
||||
} else if !children_pushed {
|
||||
self.stack.last_mut().unwrap().0 = true;
|
||||
let cs = self.stack.last().unwrap().1.children.clone();
|
||||
self.stack.extend(cs.into_iter().map(|c| (false, c)));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.stack.pop().map(|(_, t)| {
|
||||
self.visited.insert(t.clone());
|
||||
t
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
@@ -601,21 +642,43 @@ mod test {
|
||||
mod type_ {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn vars() {
|
||||
let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool));
|
||||
assert_eq!(check(v), Ok(Sort::Bool));
|
||||
fn t() -> Term {
|
||||
let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4)));
|
||||
assert_eq!(check(v.clone()), Ok(Sort::BitVector(4)));
|
||||
let v = term![
|
||||
term![
|
||||
Op::BvBit(4);
|
||||
term![
|
||||
Op::BvConcat;
|
||||
v,
|
||||
term![Op::BoolToBv; leaf_term(Op::Var("c".to_owned(), Sort::Bool))]
|
||||
]
|
||||
];
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vars() {
|
||||
let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool));
|
||||
assert_eq!(check(v), Ok(Sort::Bool));
|
||||
let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4)));
|
||||
assert_eq!(check(v.clone()), Ok(Sort::BitVector(4)));
|
||||
let v = t();
|
||||
assert_eq!(check(v), Ok(Sort::Bool));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn traversal() {
|
||||
let tt = t();
|
||||
assert_eq!(
|
||||
vec![
|
||||
Op::Var("c".to_owned(), Sort::Bool),
|
||||
Op::BoolToBv,
|
||||
Op::Var("b".to_owned(), Sort::BitVector(4)),
|
||||
Op::BvConcat,
|
||||
Op::BvBit(4),
|
||||
],
|
||||
PostOrderIter::new(tt)
|
||||
.map(|t| t.op.clone())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
67
src/ir/to_r1cs.rs
Normal file
67
src/ir/to_r1cs.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use super::term::*;
|
||||
use crate::target::r1cs::*;
|
||||
|
||||
use rug::Integer;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
struct ToR1cs {
|
||||
r1cs: R1cs<String>,
|
||||
bools: TermMap<Lc>,
|
||||
values: Option<HashMap<String, Value>>,
|
||||
next_idx: usize,
|
||||
}
|
||||
|
||||
impl ToR1cs {
|
||||
fn new(modulus: Integer, values: Option<HashMap<String, Value>>) -> Self {
|
||||
Self {
|
||||
r1cs: R1cs::new(modulus, values.is_some()),
|
||||
bools: TermMap::new(),
|
||||
values,
|
||||
next_idx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn fresh_var<D: Display>(&mut self, ctx: &D) -> Lc {
|
||||
let n = format!("{}v{}", ctx, self.next_idx);
|
||||
self.next_idx += 1;
|
||||
self.r1cs.add_signal(n.clone());
|
||||
self.r1cs.signal_lc(&n)
|
||||
}
|
||||
|
||||
fn embed(&mut self, t: Term) {
|
||||
for c in PostOrderIter::new(t) {
|
||||
match check(c.clone()).expect("type-check error in embed") {
|
||||
Sort::Bool => {
|
||||
self.embed_bool(c);
|
||||
}
|
||||
s => panic!("Unsupported sort in embed: {:?}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_bit(&mut self, b: &Lc) {
|
||||
self.r1cs
|
||||
.constraint(b.clone(), b.clone() - 1, self.r1cs.zero());
|
||||
}
|
||||
|
||||
fn embed_bool(&mut self, t: Term) -> &Lc {
|
||||
debug_assert!(check(t.clone()) == Ok(Sort::Bool));
|
||||
// TODO: skip if already embedded
|
||||
for c in PostOrderIter::new(t.clone()) {
|
||||
if !self.bools.contains_key(&c) {
|
||||
let lc = match &c.op {
|
||||
Op::Var(name, Sort::Bool) => {
|
||||
let v = self.fresh_var(name);
|
||||
self.enforce_bit(&v);
|
||||
v
|
||||
}
|
||||
_ => panic!("Non-boolean in embed_bool: {:?}", c),
|
||||
};
|
||||
self.bools.insert(c, lc);
|
||||
}
|
||||
}
|
||||
self.bools.get(&t).unwrap()
|
||||
}
|
||||
}
|
||||
@@ -24,8 +24,16 @@ pub struct Lc {
|
||||
impl std::ops::Add<&Lc> for Lc {
|
||||
type Output = Lc;
|
||||
fn add(mut self, other: &Lc) -> Lc {
|
||||
self += other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::AddAssign<&Lc> for Lc {
|
||||
fn add_assign(&mut self, other: &Lc) {
|
||||
assert_eq!(&self.modulus, &other.modulus);
|
||||
self.constant = (self.constant + &other.constant) % &*self.modulus;
|
||||
self.constant += &other.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
for (i, v) in &other.monomials {
|
||||
self.monomials
|
||||
.entry(*i)
|
||||
@@ -35,14 +43,103 @@ impl std::ops::Add<&Lc> for Lc {
|
||||
})
|
||||
.or_insert_with(|| v.clone());
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<&Integer> for Lc {
|
||||
type Output = Lc;
|
||||
fn add(mut self, other: &Integer) -> Lc {
|
||||
self.constant = (self.constant + other) % &*self.modulus;
|
||||
self += other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::AddAssign<&Integer> for Lc {
|
||||
fn add_assign(&mut self, other: &Integer) {
|
||||
self.constant += other;
|
||||
self.constant %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<isize> for Lc {
|
||||
type Output = Lc;
|
||||
fn add(mut self, other: isize) -> Lc {
|
||||
self += other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::AddAssign<isize> for Lc {
|
||||
fn add_assign(&mut self, other: isize) {
|
||||
self.constant += Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<&Lc> for Lc {
|
||||
type Output = Lc;
|
||||
fn sub(mut self, other: &Lc) -> Lc {
|
||||
self -= other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::SubAssign<&Lc> for Lc {
|
||||
fn sub_assign(&mut self, other: &Lc) {
|
||||
assert_eq!(&self.modulus, &other.modulus);
|
||||
self.constant -= &other.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
for (i, v) in &other.monomials {
|
||||
self.monomials
|
||||
.entry(*i)
|
||||
.and_modify(|u| {
|
||||
*u -= v;
|
||||
*u %= &*other.modulus;
|
||||
})
|
||||
.or_insert_with(|| v.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<&Integer> for Lc {
|
||||
type Output = Lc;
|
||||
fn sub(mut self, other: &Integer) -> Lc {
|
||||
self -= other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::SubAssign<&Integer> for Lc {
|
||||
fn sub_assign(&mut self, other: &Integer) {
|
||||
self.constant -= other;
|
||||
self.constant %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub<isize> for Lc {
|
||||
type Output = Lc;
|
||||
fn sub(mut self, other: isize) -> Lc {
|
||||
self -= other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::SubAssign<isize> for Lc {
|
||||
fn sub_assign(&mut self, other: isize) {
|
||||
self.constant -= Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Neg for Lc {
|
||||
type Output = Lc;
|
||||
fn neg(mut self) -> Lc {
|
||||
self.constant = -self.constant;
|
||||
self.constant %= &*self.modulus;
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= Integer::from(-1);
|
||||
*v %= &*self.modulus;
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -50,15 +147,41 @@ impl std::ops::Add<&Integer> for Lc {
|
||||
impl std::ops::Mul<&Integer> for Lc {
|
||||
type Output = Lc;
|
||||
fn mul(mut self, other: &Integer) -> Lc {
|
||||
self.constant = (self.constant * other) % &*self.modulus;
|
||||
self *= other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::MulAssign<&Integer> for Lc {
|
||||
fn mul_assign(&mut self, other: &Integer) {
|
||||
self.constant *= other;
|
||||
self.constant %= &*self.modulus;
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= other;
|
||||
*v %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Mul<isize> for Lc {
|
||||
type Output = Lc;
|
||||
fn mul(mut self, other: isize) -> Lc {
|
||||
self *= other;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::MulAssign<isize> for Lc {
|
||||
fn mul_assign(&mut self, other: isize) {
|
||||
self.constant *= Integer::from(other);
|
||||
self.constant %= &*self.modulus;
|
||||
for (_, v) in &mut self.monomials {
|
||||
*v *= Integer::from(other);
|
||||
*v %= &*self.modulus;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Clone + Hash + Eq> R1cs<S> {
|
||||
pub fn new(modulus: Integer, values: bool) -> Self {
|
||||
R1cs {
|
||||
@@ -117,10 +240,7 @@ impl<S: Clone + Hash + Eq> R1cs<S> {
|
||||
.signal_idxs
|
||||
.get(s)
|
||||
.expect("Missing signal in signal_lc");
|
||||
self.values
|
||||
.as_mut()
|
||||
.expect("Missing values")
|
||||
.insert(idx, v);
|
||||
self.values.as_mut().expect("Missing values").insert(idx, v);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user