mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
Oblivious array elim
This commit is contained in:
@@ -1,3 +1,14 @@
|
||||
//! Linear Memory implementation.
|
||||
//!
|
||||
//! The idea is to replace each array with a term sequence and use ITEs to linearly scan the array
|
||||
//! when needed. A SELECT produces an ITE reduce chain, a STORE produces an ITE map over the
|
||||
//! sequence.
|
||||
//!
|
||||
//! E.g., for length-3 arrays.
|
||||
//!
|
||||
//! (select A k) => (ite (= k 2) A2 (ite (= k 1) A1 A0))
|
||||
//! (store A k v) => (ite (= k 0) v A0), (ite (= k 1) v A1), (ite (= k 2))
|
||||
|
||||
use super::visit::MemVisitor;
|
||||
use crate::ir::term::*;
|
||||
|
||||
@@ -88,12 +99,14 @@ impl MemVisitor for ArrayLinearizer {
|
||||
}
|
||||
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
|
||||
if let Sort::Array(_k, v, size) = s {
|
||||
self.sequences.insert(
|
||||
orig.clone(),
|
||||
(0..*size)
|
||||
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
|
||||
.collect(),
|
||||
);
|
||||
if *size <= self.size_thresh {
|
||||
self.sequences.insert(
|
||||
orig.clone(),
|
||||
(0..*size)
|
||||
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unreachable!("should only visit array vars")
|
||||
}
|
||||
@@ -108,15 +121,10 @@ pub fn linearize(t: &Term, size_thresh: usize) -> Term {
|
||||
pass.traverse(t)
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
fn v_bv(n: &str, w: usize) -> Term {
|
||||
leaf_term(Op::Var(n.to_owned(), Sort::BitVector(w)))
|
||||
}
|
||||
|
||||
fn bv(u: usize, w: usize) -> Term {
|
||||
leaf_term(Op::Const(Value::BitVector(BitVector::new(
|
||||
Integer::from(u),
|
||||
@@ -127,18 +135,20 @@ mod test {
|
||||
fn array_free(t: &Term) -> bool {
|
||||
for c in PostOrderIter::new(t.clone()) {
|
||||
if let Sort::Array(..) = check(c).unwrap() {
|
||||
return false
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn count_ites(t: &Term) -> usize {
|
||||
PostOrderIter::new(t.clone()).filter(|t| &t.op == &Op::Ite).count()
|
||||
PostOrderIter::new(t.clone())
|
||||
.filter(|t| &t.op == &Op::Ite)
|
||||
.count()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doesnt_crash() {
|
||||
fn select_ite_stores() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv(0, 4)];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
mod visit;
|
||||
pub mod lin;
|
||||
pub mod obliv;
|
||||
|
||||
284
src/ir/opt/mem/obliv.rs
Normal file
284
src/ir/opt/mem/obliv.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
//! Oblivious Array Elimination
|
||||
//!
|
||||
//! This module attempts to identify *oblivious* arrays: those that are only accessed at constant
|
||||
//! indices[1]. These arrays can be replaced with normal terms.
|
||||
//!
|
||||
//! It operates in two passes:
|
||||
//!
|
||||
//! 1. determine which arrays are oblivious
|
||||
//! 2. replace oblivious arrays with (haskell) lists of terms.
|
||||
//!
|
||||
//!
|
||||
//! ## Pass 1: Identifying oblivious arrays
|
||||
//!
|
||||
//! We maintain a set of non-oblivious arrays, initially empty. We traverse the whole SMT
|
||||
//! constraint system, performing the following inferences:
|
||||
//!
|
||||
//! * If `a[i]` for non-constant `i`, then `a` is not oblivious
|
||||
//! * If `a[i\v]` for non-constant `i`, then neither `a[i\v]` nor `a` are oblivious
|
||||
//! * If `a[i\v]`, then `a[i\v]` and `a` are equi-oblivious
|
||||
//! * If `ite(c,a,b)`, then `ite(c,a,b)`, `a`, and `b` are equi-oblivious
|
||||
//! * If `a=b`, then `a` and `b` are equi-oblivious
|
||||
//!
|
||||
//! This procedure is iterated to fixpoint.
|
||||
//!
|
||||
//! ## Pass 2: Replacing oblivious arrays with term lists.
|
||||
//!
|
||||
//! In this pass, the goal is to
|
||||
//!
|
||||
//! * map array terms to haskell lists of value terms
|
||||
//! * map array selections to specific value terms
|
||||
//!
|
||||
//! The pass maintains:
|
||||
//!
|
||||
//! * a map from array terms to lists of values
|
||||
//!
|
||||
//! It then does a bottom-up formula traversal, performing the following
|
||||
//! transformations:
|
||||
//!
|
||||
//! * oblivious array variables are mapped to a list of (derivative) variables
|
||||
//! * oblivious constant arrays are mapped to a list that replicates the constant
|
||||
//! * accesses to oblivious arrays are mapped to the appropriate term from the
|
||||
//! value list of the array
|
||||
//! * stores to oblivious arrays are mapped to updated value lists
|
||||
//! * equalities between oblivious arrays are mapped to conjunctions of equalities
|
||||
//!
|
||||
//! [1]: Our use of "oblivious" is inspired by *oblivious turing
|
||||
//! machines* <https://en.wikipedia.org/wiki/Turing_machine_equivalents#Oblivious_Turing_machines>
|
||||
//! which access their tape in a data-indepedant way.
|
||||
|
||||
use super::visit::*;
|
||||
use crate::ir::term::*;
|
||||
|
||||
use std::iter::repeat;
|
||||
|
||||
struct NonOblivComputer {
|
||||
not_obliv: TermSet,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
impl NonOblivComputer {
|
||||
fn mark(&mut self, a: &Term) {
|
||||
if !self.not_obliv.contains(a) {
|
||||
self.not_obliv.insert(a.clone());
|
||||
self.progress = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn bi_implicate(&mut self, a: &Term, b: &Term) {
|
||||
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
|
||||
(false, true) => {
|
||||
self.not_obliv.insert(a.clone());
|
||||
self.progress = true;
|
||||
}
|
||||
(true, false) => {
|
||||
self.not_obliv.insert(b.clone());
|
||||
self.progress = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
progress: true,
|
||||
not_obliv: TermSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MemVisitor for NonOblivComputer {
|
||||
fn visit_eq(&mut self, _orig: &Term, a: &Term, b: &Term) -> Option<Term> {
|
||||
self.bi_implicate(a, b);
|
||||
None
|
||||
}
|
||||
fn visit_ite(&mut self, orig: &Term, _c: &Term, t: &Term, f: &Term) {
|
||||
self.bi_implicate(orig, t);
|
||||
self.bi_implicate(t, f);
|
||||
self.bi_implicate(orig, f);
|
||||
}
|
||||
fn visit_store(&mut self, orig: &Term, a: &Term, k: &Term, _v: &Term) {
|
||||
if let Op::Const(_) = k.op {
|
||||
self.bi_implicate(orig, a);
|
||||
} else {
|
||||
self.mark(a);
|
||||
self.mark(orig);
|
||||
}
|
||||
}
|
||||
fn visit_select(&mut self, _orig: &Term, a: &Term, k: &Term) -> Option<Term> {
|
||||
if let Op::Const(_) = k.op {
|
||||
} else {
|
||||
self.mark(a);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl ProgressVisitor for NonOblivComputer {
|
||||
fn check_progress(&self) -> bool {
|
||||
self.progress
|
||||
}
|
||||
fn reset_progress(&mut self) {
|
||||
self.progress = false;
|
||||
}
|
||||
}
|
||||
|
||||
struct Replacer {
|
||||
/// A map from (original) replaced terms, to what they were replaced with.
|
||||
sequences: TermMap<Vec<Term>>,
|
||||
/// The maximum size of arrays that will be replaced.
|
||||
not_obliv: TermSet,
|
||||
}
|
||||
|
||||
impl Replacer {
|
||||
fn should_replace(&self, a: &Term) -> bool {
|
||||
!self.not_obliv.contains(a)
|
||||
}
|
||||
}
|
||||
|
||||
impl MemVisitor for Replacer {
|
||||
fn visit_const_array(&mut self, orig: &Term, _key_sort: &Sort, val: &Term, size: usize) {
|
||||
if self.should_replace(orig) {
|
||||
self.sequences
|
||||
.insert(orig.clone(), repeat(val).cloned().take(size).collect());
|
||||
}
|
||||
}
|
||||
fn visit_eq(&mut self, orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let b_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent eq");
|
||||
let eqs: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Eq; a.clone(), b.clone()])
|
||||
.collect();
|
||||
Some(term(Op::BoolNaryOp(BoolNaryOp::And), eqs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn visit_ite(&mut self, orig: &Term, c: &Term, _t: &Term, _f: &Term) {
|
||||
if self.should_replace(orig) {
|
||||
let a_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent ite");
|
||||
let b_seq = self.sequences.get(&orig.cs[2]).expect("inconsistent ite");
|
||||
let ites: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Ite; c.clone(), a.clone(), b.clone()])
|
||||
.collect();
|
||||
self.sequences.insert(orig.clone(), ites);
|
||||
}
|
||||
}
|
||||
fn visit_store(&mut self, orig: &Term, _a: &Term, k: &Term, v: &Term) {
|
||||
if self.should_replace(orig) {
|
||||
let mut a_seq = self
|
||||
.sequences
|
||||
.get(&orig.cs[0])
|
||||
.expect("inconsistent store")
|
||||
.clone();
|
||||
let k_const = k
|
||||
.as_bv_opt()
|
||||
.expect("not obliv!")
|
||||
.uint()
|
||||
.to_usize()
|
||||
.expect("oversize index");
|
||||
a_seq[k_const] = v.clone();
|
||||
self.sequences.insert(orig.clone(), a_seq);
|
||||
}
|
||||
}
|
||||
fn visit_select(&mut self, orig: &Term, _a: &Term, k: &Term) -> Option<Term> {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let k_const = k
|
||||
.as_bv_opt()
|
||||
.expect("not obliv!")
|
||||
.uint()
|
||||
.to_usize()
|
||||
.expect("oversize index");
|
||||
if k_const < a_seq.len() {
|
||||
Some(a_seq[k_const].clone())
|
||||
} else {
|
||||
panic!("Oversize index: {}", k_const)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
|
||||
if let Sort::Array(_k, v, size) = s {
|
||||
if self.should_replace(orig) {
|
||||
self.sequences.insert(
|
||||
orig.clone(),
|
||||
(0..*size)
|
||||
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
unreachable!("should only visit array vars")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn elim_obliv(t: &Term) -> Term {
|
||||
let mut prop_pass = NonOblivComputer::new();
|
||||
prop_pass.traverse_to_fixpoint(t);
|
||||
let mut replace_pass = Replacer {
|
||||
not_obliv: prop_pass.not_obliv,
|
||||
sequences: TermMap::new(),
|
||||
};
|
||||
replace_pass.traverse(t)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use rug::Integer;
|
||||
|
||||
fn bv(u: usize, w: usize) -> Term {
|
||||
leaf_term(Op::Const(Value::BitVector(BitVector::new(
|
||||
Integer::from(u),
|
||||
w,
|
||||
))))
|
||||
}
|
||||
|
||||
fn v_bv(n: &str, w: usize) -> Term {
|
||||
leaf_term(Op::Var(n.to_owned(), Sort::BitVector(w)))
|
||||
}
|
||||
|
||||
fn array_free(t: &Term) -> bool {
|
||||
for c in PostOrderIter::new(t.clone()) {
|
||||
if let Sort::Array(..) = check(c).unwrap() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn obliv() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv(0, 4)];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
term![Op::Store; z.clone(), bv(3, 4), bv(1, 4)],
|
||||
term![Op::Store; z.clone(), bv(2, 4), bv(1, 4)]
|
||||
],
|
||||
bv(3, 4)
|
||||
];
|
||||
let tt = elim_obliv(&t);
|
||||
assert!(array_free(&tt));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_obliv() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv(0, 4)];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
term![Op::Store; z.clone(), v_bv("a", 4), bv(1, 4)],
|
||||
term![Op::Store; z.clone(), bv(2, 4), bv(1, 4)]
|
||||
],
|
||||
bv(3, 4)
|
||||
];
|
||||
let tt = elim_obliv(&t);
|
||||
assert!(!array_free(&tt));
|
||||
}
|
||||
}
|
||||
@@ -83,3 +83,16 @@ pub trait MemVisitor {
|
||||
cache.remove(node).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProgressVisitor: MemVisitor {
|
||||
fn reset_progress(&mut self);
|
||||
fn check_progress(&self) -> bool;
|
||||
|
||||
fn traverse_to_fixpoint(&mut self, a: &Term) {
|
||||
self.traverse(a);
|
||||
while self.check_progress() {
|
||||
self.reset_progress();
|
||||
self.traverse(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user