IR-based Zokrates front-end (#33)

The ZoKrates front-end now represents ZoK arrays as IR arrays, and ZoK structures as (type-tagged) IR tuples.

During this change, I discovered that IR support for eliminating tuples and arrays was not complete.

Thus the change list is:

    The ZoK front-end uses IR arrays and tuples
    Improve IR passes for array and tuple elimination
    Enforce cargo fmt in CI
    Bugfix: handle ZoK accessors in L-values in the correct order
    Bugfix: add array evaluation to the IR

This PR does not:

    implement an array flattening pass
    implement permutation-based memory-checking

Benefits:

    The ZoK->R1CS compiler is now ~5.88x faster (as defined by the time it takes to run the tests in master's scripts/zokrates_test.zsh script: this goes from 8.59s to 1.46s)
        For benchmarks with multi-dimensional arrays, the ZoK->R1CS compiler can now compile them with reasonable speed. Before it it would time out on even tiny examples.
    The ZoK->R1CS compiler will be able to benefit from future memory-checking improvements
    IR support for arrays and tuples is complete now, making those parts of the IR more accessible to future front-ends.

alex-ozdemir added 21 commits 11 days ago
This commit is contained in:
Alex Ozdemir
2022-01-01 11:44:56 -08:00
committed by GitHub
parent aadd6b7c2d
commit f2744e0c06
56 changed files with 1787 additions and 1106 deletions

View File

@@ -25,8 +25,10 @@ jobs:
- uses: Swatinem/rust-cache@v1
- name: Initialize submodules
run: make init
- name: Check
- name: Typecheck
run: cargo check --verbose
- name: Check format
run: cargo fmt -- --check
- name: Build
run: cargo build --verbose && make build
- name: Run tests

7
.gitignore vendored
View File

@@ -1,10 +1,11 @@
/target
/pf
assignment.txt
params
/assignment.txt
/params
__pycache__
/P
/V
/pi
/x
perf.data*
/perf.data*
/.gdb_history

12
Cargo.lock generated
View File

@@ -217,6 +217,7 @@ dependencies = [
"good_lp",
"hashconsing",
"ieee754",
"itertools 0.10.3",
"lazy_static",
"log",
"lp-solvers",
@@ -531,6 +532,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.8"
@@ -657,7 +667,7 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fbf404899169771dd6a32c84248b83cd67a26cc7cc957aac87661490e1227e4"
dependencies = [
"itertools",
"itertools 0.7.11",
"proc-macro2 0.4.30",
"quote 0.6.13",
"single",

View File

@@ -32,6 +32,7 @@ pest = "2.1"
pest_derive = "2.1"
pest-ast = "0.3"
from-pest = "0.3"
itertools = "0.10"
[dev-dependencies]
quickcheck = "1"

View File

@@ -4,7 +4,7 @@ build: init
cargo build --release --example circ && ./scripts/build_mpc_zokrates_test.zsh && ./scripts/build_aby.zsh
test:
cargo test && ./scripts/zokrates_test.zsh && python3 ./scripts/test_aby.py && ./scripts/test_zok_to_ilp.zsh && ./scripts/test_zok_to_ilp_pf.zsh ./scripts/test_atalog.zsh
cargo test && ./scripts/zokrates_test.zsh && python3 ./scripts/test_aby.py && ./scripts/test_zok_to_ilp.zsh && ./scripts/test_zok_to_ilp_pf.zsh && ./scripts/test_datalog.zsh
init:
git submodule update --init
@@ -18,3 +18,6 @@ clean:
touch ./third_party/ABY/src/examples/2pc_* && rm -r -- ./third_party/ABY/src/examples/2pc_*
sed '/add_subdirectory.*2pc.*/d' -i ./third_party/ABY/src/examples/CMakeLists.txt
rm -rf P V pi perf.data perf.data.old flamegraph.svg
format:
cargo fmt --all

6
TODO.md Normal file
View File

@@ -0,0 +1,6 @@
Passes to write:
[ ] shrink bit-vectors using range analysis.
[ ] common sub-expression grouping
* for commutative/associative ops?
* after flattening

View File

@@ -0,0 +1,14 @@
struct Pt {
field x
field y
}
struct Pts {
Pt[2] pts
}
def main(private field y) -> field:
Pt p1 = Pt {x: 2, y: y}
Pt p2 = Pt {x: y, y: 2}
Pts[1] pts = [Pts { pts: [p1, p2] }]
return pts[0].pts[0].y * pts[0].pts[1].x

View File

@@ -0,0 +1 @@
y 4

View File

@@ -0,0 +1 @@
16

View File

@@ -0,0 +1,3 @@
// Making sure we get input order right
def main(public u16 a, public u16 b, public u16 c, public u16 d) -> u16:
return a ^ b ^ c ^ d

View File

@@ -0,0 +1,4 @@
a 1
b 2
c 3
d 4

View File

@@ -0,0 +1,5 @@
1
2
3
4
4

View File

@@ -0,0 +1,12 @@
def main(private field[2][2] A, private field[2][2] B) -> field[2][2]:
field [2][2] AB = [[0; 2]; 2]
for field i in 0..2 do
for field j in 0..2 do
for field k in 0..2 do
AB[i][j] = AB[i][j] + A[i][k] * B[k][j]
endfor
endfor
endfor
return AB

View File

@@ -0,0 +1,8 @@
A.0.0 1
A.0.1 0
A.1.0 0
A.1.1 1
B.0.0 1
B.0.1 0
B.1.0 0
B.1.1 1

View File

@@ -0,0 +1,4 @@
1
0
0
1

View File

@@ -0,0 +1,2 @@
def main(private field x, private field y)-> field:
return x * y

View File

@@ -0,0 +1,2 @@
x 4
y 5

View File

@@ -0,0 +1 @@
20

View File

@@ -0,0 +1,12 @@
struct Pt {
field x
field y
}
struct Pts {
Pt[2] pts
}
def main(field y) -> field:
Pt p = Pt {x: 2, y: y}
Pts pts = Pts { pts: [p, p] }
return pts.pts[0].y + pts.pts[1].x

View File

@@ -0,0 +1 @@
y 6

View File

@@ -0,0 +1,2 @@
6
8

View File

@@ -0,0 +1,10 @@
struct Pt {
field x
field y
}
struct PtWr {
Pt p
}
def main(field x, field y) -> field:
PtWr p = PtWr { p: Pt { x: x, y: y } }
return p.p.x * p.p.y

View File

@@ -0,0 +1,2 @@
x 5
y 6

View File

@@ -0,0 +1,3 @@
5
6
30

View File

@@ -0,0 +1,13 @@
struct Pt {
field x
field y
}
struct Pts {
Pt[2] pts
}
def main(private field y, private field i, private field j, private field k) -> field:
Pt p = Pt {x: y, y: y}
Pts[1] pts = [Pts { pts: [p, p] }]
return pts[i].pts[j].y * pts[i].pts[j].x

View File

@@ -0,0 +1,4 @@
y 6
i 0
j 0
k 1

View File

@@ -0,0 +1 @@
36

View File

@@ -177,24 +177,31 @@ fn main() {
}
};
let cs = match mode {
Mode::Opt => opt(cs, vec![Opt::ConstantFold]),
Mode::Opt => opt(cs, vec![Opt::ScalarizeVars, Opt::ConstantFold]),
Mode::Mpc(_) => opt(
cs,
vec![],
vec![Opt::ScalarizeVars],
// vec![Opt::Sha, Opt::ConstantFold, Opt::Mem, Opt::ConstantFold],
),
Mode::Proof | Mode::ProofOfHighValue(_) => opt(
cs,
vec![
Opt::ScalarizeVars,
Opt::Flatten,
Opt::Sha,
Opt::ConstantFold,
Opt::Flatten,
//Opt::FlattenAssertions,
Opt::Inline,
Opt::Mem,
// Tuples must be eliminated before oblivious array elim
Opt::Tuple,
Opt::ConstantFold,
Opt::Obliv,
// The obliv elim pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::LinearScan,
// The linear scan pass produces more tuples, that must be eliminated
Opt::Tuple,
Opt::Flatten,
//Opt::FlattenAssertions,
Opt::ConstantFold,
Opt::Inline,
],
@@ -215,12 +222,12 @@ fn main() {
let r1cs = to_r1cs(cs, circ::front::zokrates::ZOKRATES_MODULUS.clone());
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
let r1cs = reduce_linearities(r1cs);
println!("Final R1cs size: {}", r1cs.constraints().len());
match action {
ProofAction::Count => {
println!("Final R1cs size: {}", r1cs.constraints().len());
}
ProofAction::Count => (),
ProofAction::Prove => {
println!("Proving");
r1cs.check_all();
let rng = &mut rand::thread_rng();
let mut pk_file = File::open(prover_key).unwrap();
let pk = Parameters::<Bls12>::read(&mut pk_file, false).unwrap();

View File

@@ -2,6 +2,6 @@
set -ex
cargo flamegraph --help || (echo "Please install the rust 'flamegraph' binary with 'cargo install flamegraph'" && exit 1)
(cargo flamegraph --help > /dev/null) || (echo "Please install the rust 'flamegraph' binary with 'cargo install flamegraph'" && exit 1)
cargo flamegraph --example circ third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok r1cs --action count

View File

@@ -15,10 +15,14 @@ $BIN --language datalog ./examples/datalog/arr.pl r1cs --action count || true
# Small R1cs b/c too little recursion.
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 4 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
[ "$size" -lt 10 ]
# Big R1cs b/c enough recursion
($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 5 r1cs --action count || true) | egrep "Final R1cs size: 356"
($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 10 r1cs --action count || true) | egrep "Final R1cs size: 356"
($BIN --language datalog ./examples/datalog/dec.pl -r 2 r1cs --action count || true) | egrep "Final R1cs size: 356"
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 5 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 10 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
[ "$size" -gt 250 ]
size=$(($BIN --language datalog ./examples/datalog/dec.pl -r 2 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
[ "$size" -gt 250 ]
# Test prim-rec test
$BIN --language datalog ./examples/datalog/dec.pl --lint-prim-rec smt

View File

@@ -5,7 +5,9 @@ set -ex
disable -r time
cargo build --release --example circ
#cargo build --example circ
#BIN=./target/debug/examples/circ
BIN=./target/release/examples/circ
case "$OSTYPE" in
@@ -22,6 +24,15 @@ function r1cs_test {
measure_time $BIN $zpath r1cs --action count
}
# Test prove workflow, given an example name
function pf_test {
ex_name=$1
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
$BIN --inputs examples/ZoKrates/pf/$ex_name.zok.in examples/ZoKrates/pf/$ex_name.zok r1cs --action prove
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --instance examples/ZoKrates/pf/$ex_name.zok.x --action verify
rm -rf P V pi
}
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok
@@ -36,15 +47,12 @@ r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zo
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok
# Test prove workflow, given an example name
function pf_test {
ex_name=$1
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
$BIN --inputs examples/ZoKrates/pf/$ex_name.zok.in examples/ZoKrates/pf/$ex_name.zok r1cs --action prove
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --instance examples/ZoKrates/pf/$ex_name.zok.x --action verify
rm -rf P V pi
}
pf_test 3_plus
pf_test xor
pf_test mul
pf_test many_pub
pf_test str_str
pf_test str_arr_str
pf_test arr_str_arr_str
pf_test var_idx_arr_str_arr_str
pf_test mm

View File

@@ -127,17 +127,12 @@ impl MemManager {
///
/// Returns a (concrete) allocation identifier which can be used to access this allocation.
pub fn zero_allocate(&mut self, size: usize, addr_width: usize, val_width: usize) -> AllocId {
let sort = Sort::Array(
Box::new(Sort::BitVector(addr_width)),
Box::new(Sort::BitVector(val_width)),
size,
);
let array = Value::Array(
sort,
let array = Value::Array(Array::new(
Sort::BitVector(addr_width),
Box::new(Value::BitVector(BitVector::zeros(val_width))),
BTreeMap::new(),
size,
);
));
self.allocate(leaf_term(Op::Const(array)))
}

View File

@@ -2,11 +2,11 @@
use thiserror::Error;
use super::term::T;
use super::parser::ast::Span;
use super::term::T;
use std::fmt::{Display, Formatter, self};
use std::convert::From;
use std::fmt::{self, Display, Formatter};
#[derive(Error, Debug)]
/// An error in circuit translation

View File

@@ -1,8 +1,8 @@
//! Datalog parser
#![allow(missing_docs)]
use pest::Parser;
use pest::error::Error;
use pest::Parser;
use pest_derive::Parser;
// Issue with the proc macro
@@ -193,7 +193,6 @@ pub mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct BinaryExpression<'ast> {
pub op: BinaryOperator,
@@ -314,17 +313,20 @@ pub mod ast {
match next.as_rule() {
// this happens when we have an expression in parentheses: it needs to be processed as another sequence of terms and operators
Rule::paren_expr => Expression::Paren(
Box::new(Expression::from_pest(
Box::new(
Expression::from_pest(
&mut pair.into_inner().next().unwrap().into_inner(),
).unwrap()),
)
.unwrap(),
),
next.as_span(),
),
Rule::literal => {
Expression::Literal(Literal::from_pest(&mut pair.into_inner()).unwrap())
}
Rule::identifier => Expression::Identifier(
Ident::from_pest(&mut pair.into_inner()).unwrap(),
),
Rule::identifier => {
Expression::Identifier(Ident::from_pest(&mut pair.into_inner()).unwrap())
}
Rule::unary_expression => {
let span = next.as_span();
let mut inner = next.into_inner();
@@ -345,9 +347,9 @@ pub mod ast {
span,
})
}
Rule::call_expr => Expression::Call(
CallExpression::from_pest(&mut pair.into_inner()).unwrap(),
),
Rule::call_expr => {
Expression::Call(CallExpression::from_pest(&mut pair.into_inner()).unwrap())
}
Rule::access_expr => Expression::Access(
AccessExpression::from_pest(&mut pair.into_inner()).unwrap(),
),

View File

@@ -9,7 +9,7 @@ use super::error::ErrorKind;
use super::ty::Ty;
use crate::circify::{CirCtx, Embeddable};
use crate::front::zokrates::ZOKRATES_MODULUS_ARC;
use crate::front::zokrates::{ZOKRATES_MODULUS_ARC, ZOK_FIELD_SORT};
use crate::ir::term::*;
/// A term
@@ -81,19 +81,22 @@ pub fn uint_lit(v: u64, w: u8) -> T {
}
impl Ty {
fn default(&self) -> T {
T::new(self.default_ir(), self.clone())
}
fn default_ir(&self) -> Term {
fn sort(&self) -> Sort {
match self {
Self::Bool => leaf_term(Op::Const(Value::Bool(false))),
Self::Uint(w) => bv_lit(0, *w as usize),
Self::Field => pf_ir_lit(0),
Self::Array(l, t) => {
term![Op::ConstArray(Sort::Field(ZOKRATES_MODULUS_ARC.clone()), *l); t.default_ir()]
Self::Bool => Sort::Bool,
Self::Uint(w) => Sort::BitVector(*w as usize),
Self::Field => ZOK_FIELD_SORT.clone(),
Self::Array(n, b) => {
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
}
}
}
fn default_ir_term(&self) -> Term {
self.sort().default_term()
}
fn default(&self) -> T {
T::new(self.default_ir_term(), self.clone())
}
}
impl Display for T {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
@@ -397,7 +400,7 @@ impl Embeddable for Datalog {
})
.enumerate()
.fold(
ty.default_ir(),
ty.default_ir_term(),
|arr, (i, v)| term![Op::Store; arr, pf_ir_lit(i), v.ir],
),
ty.clone(),

View File

@@ -1,6 +1,6 @@
//! Types in our datalog variant
use std::fmt::{Display, Formatter, self};
use std::fmt::{self, Display, Formatter};
/// A type
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]

View File

@@ -6,6 +6,7 @@ mod term;
use super::FrontEnd;
use crate::circify::{Circify, Loc, Val};
use crate::ir::proof::{self, ConstraintMetadata};
use crate::ir::term::extras::Letified;
use crate::ir::term::*;
use log::debug;
use rug::Integer;
@@ -21,6 +22,8 @@ use term::*;
pub use term::ZOKRATES_MODULUS;
/// The modulus for the ZoKrates language.
pub use term::ZOKRATES_MODULUS_ARC;
/// The modulus for the ZoKrates language.
pub use term::ZOK_FIELD_SORT;
/// The prover visibility
pub const PROVER_VIS: Option<PartyId> = Some(proof::PROVER_ID);
@@ -99,18 +102,28 @@ struct ZGen<'ast> {
mode: Mode,
}
enum ZLoc {
Var(Loc),
Member(Box<ZLoc>, String),
Idx(Box<ZLoc>, T),
struct ZLoc {
var: Loc,
accesses: Vec<ZAccess>,
}
impl ZLoc {
fn loc(&self) -> &Loc {
match self {
ZLoc::Var(l) => l,
ZLoc::Member(i, _) => i.loc(),
ZLoc::Idx(i, _) => i.loc(),
enum ZAccess {
Member(String),
Idx(T),
}
fn loc_store(struct_: T, loc: &[ZAccess], val: T) -> Result<T, String> {
match loc.first() {
None => Ok(val),
Some(ZAccess::Member(field)) => {
let inner = field_select(&struct_, &field)?;
let new_inner = loc_store(inner, &loc[1..], val)?;
field_store(struct_, &field, new_inner)
}
Some(ZAccess::Idx(idx)) => {
let old_inner = array_select(struct_.clone(), idx.clone())?;
let new_inner = loc_store(old_inner, &loc[1..], val)?;
array_store(struct_, idx.clone(), new_inner)
}
}
}
@@ -191,17 +204,16 @@ impl<'ast> ZGen<'ast> {
self.unwrap(decl_res, &i.index.span);
for j in s..e {
self.circ.enter_scope();
let ass_res = self
.circ
.assign(Loc::local(v_name.clone()), Val::Term(
match ty {
Ty::Uint(8) => T::Uint(8, bv_lit(j, 8)),
Ty::Uint(16) => T::Uint(16, bv_lit(j, 16)),
Ty::Uint(32) => T::Uint(32, bv_lit(j, 32)),
Ty::Field => T::Field(pf_lit(j)),
_ => panic!("Unexpected type for iteration: {:?}", ty),
}
));
let ass_res = self.circ.assign(
Loc::local(v_name.clone()),
Val::Term(match ty {
Ty::Uint(8) => uint_lit(j, 8),
Ty::Uint(16) => uint_lit(j, 16),
Ty::Uint(32) => uint_lit(j, 32),
Ty::Field => field_lit(j),
_ => panic!("Unexpected type for iteration: {:?}", ty),
}),
);
self.unwrap(ass_res, &i.index.span);
for s in &i.statements {
self.stmt(s);
@@ -217,7 +229,7 @@ impl<'ast> ZGen<'ast> {
let ty = e.type_();
if let Some(t) = l.ty.as_ref() {
let decl_ty = self.type_(t);
if decl_ty != ty {
if &decl_ty != ty {
self.err(
format!(
"Assignment type mismatch: {} annotated vs {} actual",
@@ -242,79 +254,60 @@ impl<'ast> ZGen<'ast> {
}
}
fn apply_lval_mod(&mut self, base: T, loc: ZLoc, val: T) -> Result<T, String> {
match loc {
ZLoc::Var(_) => Ok(val),
ZLoc::Member(inner_loc, field) => {
let old_inner = field_select(&base, &field)?;
let new_inner = self.apply_lval_mod(old_inner, *inner_loc, val)?;
field_store(base, &field, new_inner)
}
ZLoc::Idx(inner_loc, idx) => {
let old_inner = array_select(base.clone(), idx.clone())?;
let new_inner = self.apply_lval_mod(old_inner, *inner_loc, val)?;
array_store(base, idx, new_inner)
}
}
}
fn mod_lval(&mut self, l: ZLoc, t: T) -> Result<(), String> {
let var = l.loc().clone();
fn mod_lval(&mut self, loc: ZLoc, val: T) -> Result<(), String> {
let old = self
.circ
.get_value(var.clone())
.get_value(loc.var.clone())
.map_err(|e| format!("{}", e))?
.unwrap_term();
let new = self.apply_lval_mod(old, l, t)?;
let new = loc_store(old, &loc.accesses, val)?;
debug!("Assign: {:?} = {}", loc.var, Letified(new.term.clone()));
self.circ
.assign(var, Val::Term(new))
.assign(loc.var, Val::Term(new))
.map_err(|e| format!("{}", e))
.map(|_| ())
}
fn lval(&mut self, l: &ast::Assignee<'ast>) -> ZLoc {
l.accesses.iter().fold(
ZLoc::Var(Loc::local(l.id.value.clone())),
|inner, acc| match acc {
ast::AssigneeAccess::Member(m) => ZLoc::Member(Box::new(inner), m.id.value.clone()),
ast::AssigneeAccess::Select(m) => {
let i = if let ast::RangeOrExpression::Expression(e) = &m.expression {
let mut loc = ZLoc {
var: Loc::local(l.id.value.clone()),
accesses: vec![],
};
for acc in &l.accesses {
loc.accesses.push(match acc {
ast::AssigneeAccess::Member(m) => ZAccess::Member(m.id.value.clone()),
ast::AssigneeAccess::Select(m) => ZAccess::Idx(
if let ast::RangeOrExpression::Expression(e) = &m.expression {
self.expr(&e)
} else {
panic!("Cannot assign to slice")
};
ZLoc::Idx(Box::new(inner), i)
}
},
)
},
),
})
}
loc
}
fn const_(&mut self, e: &ast::ConstantExpression<'ast>) -> T {
match e {
ast::ConstantExpression::U8(u) => {
T::Uint(8, bv_lit(u8::from_str_radix(&u.value[2..], 16).unwrap(), 8))
uint_lit(u8::from_str_radix(&u.value[2..], 16).unwrap(), 8)
}
ast::ConstantExpression::U16(u) => {
uint_lit(u16::from_str_radix(&u.value[2..], 16).unwrap(), 16)
}
ast::ConstantExpression::U32(u) => {
uint_lit(u32::from_str_radix(&u.value[2..], 16).unwrap(), 32)
}
ast::ConstantExpression::U16(u) => T::Uint(
16,
bv_lit(u16::from_str_radix(&u.value[2..], 16).unwrap(), 16),
),
ast::ConstantExpression::U32(u) => T::Uint(
32,
bv_lit(u32::from_str_radix(&u.value[2..], 16).unwrap(), 32),
),
ast::ConstantExpression::DecimalNumber(u) => {
T::Field(pf_lit(Integer::from_str_radix(&u.value, 10).unwrap()))
field_lit(Integer::from_str_radix(&u.value, 10).unwrap())
}
ast::ConstantExpression::BooleanLiteral(u) => {
Self::const_bool(bool::from_str(&u.value).unwrap())
z_bool_lit(bool::from_str(&u.value).unwrap())
}
}
}
fn const_bool(b: bool) -> T {
T::Bool(leaf_term(Op::Const(Value::Bool(b))))
}
fn bin_op(&self, o: &ast::BinaryOperator) -> fn(T, T) -> Result<T, String> {
match o {
ast::BinaryOperator::BitXor => bitxor,
@@ -364,7 +357,7 @@ impl<'ast> ZGen<'ast> {
.flat_map(|x| self.array_lit_elem(x))
.collect(),
),
ast::Expression::InlineStruct(u) => Ok(T::Struct(
ast::Expression::InlineStruct(u) => Ok(T::new_struct(
u.ty.value.clone(),
u.members
.iter()
@@ -373,12 +366,11 @@ impl<'ast> ZGen<'ast> {
)),
ast::Expression::ArrayInitializer(a) => {
let v = self.expr(&a.value);
let ty = v.type_();
let n = const_int(self.const_(&a.count))
.unwrap()
.to_usize()
.unwrap();
Ok(T::Array(ty, vec![v; n]))
array(vec![v; n])
}
ast::Expression::Postfix(p) => {
// Assume no functions in arrays, etc.
@@ -417,7 +409,7 @@ impl<'ast> ZGen<'ast> {
.circ
.exit_fn()
.map(|a| a.unwrap_term())
.unwrap_or_else(|| Self::const_bool(false));
.unwrap_or_else(|| z_bool_lit(false));
self.file_stack.pop();
ret
};
@@ -510,9 +502,7 @@ impl<'ast> ZGen<'ast> {
let t = ret_terms.into_iter().next().unwrap();
match check(&t) {
Sort::BitVector(_) => {}
s => {
panic!("Cannot maximize output of type {}", s)
}
s => panic!("Cannot maximize output of type {}", s),
}
self.circ.cir_ctx().cs.borrow_mut().outputs.push(t);
}
@@ -525,12 +515,8 @@ impl<'ast> ZGen<'ast> {
);
let t = ret_terms.into_iter().next().unwrap();
let cmp = match check(&t) {
Sort::BitVector(w) => {
term![BV_UGE; t, bv_lit(v, w)]
}
s => {
panic!("Cannot maximize output of type {}", s)
}
Sort::BitVector(w) => term![BV_UGE; t, bv_lit(v, w)],
s => panic!("Cannot maximize output of type {}", s),
};
self.circ.cir_ctx().cs.borrow_mut().outputs.push(cmp);
}
@@ -669,13 +655,12 @@ impl<'ast> ZGen<'ast> {
);
}
for s in &f.structs {
let ty = Ty::Struct(
let ty = Ty::new_struct(
s.id.value.clone(),
s.fields
.clone()
.iter()
.map(|f| (f.id.value.clone(), self.type_(&f.ty)))
.collect(),
.map(|f| (f.id.value.clone(), self.type_(&f.ty))),
);
debug!("struct {}", s.id.value);
self.circ.def_type(&s.id.value, ty);

View File

@@ -19,6 +19,8 @@ lazy_static! {
.unwrap();
/// The modulus for ZoKrates, as an ARC
pub static ref ZOKRATES_MODULUS_ARC: Arc<Integer> = Arc::new(ZOKRATES_MODULUS.clone());
/// The modulus for ZoKrates, as an IR sort
pub static ref ZOK_FIELD_SORT: Sort = Sort::Field(ZOKRATES_MODULUS_ARC.clone());
}
#[derive(Clone, PartialEq, Eq)]
@@ -26,17 +28,57 @@ pub enum Ty {
Uint(usize),
Bool,
Field,
Struct(String, BTreeMap<String, Ty>),
Struct(String, FieldList<Ty>),
Array(usize, Box<Ty>),
}
pub use field_list::FieldList;
/// This module contains [FieldList].
///
/// It gets its own module so that its member can be private.
mod field_list {
#[derive(Clone, PartialEq, Eq)]
pub struct FieldList<T> {
// must be kept in sorted order
list: Vec<(String, T)>,
}
impl<T> FieldList<T> {
pub fn new(mut list: Vec<(String, T)>) -> Self {
list.sort_by_cached_key(|p| p.0.clone());
FieldList { list }
}
pub fn search(&self, key: &str) -> Option<(usize, &T)> {
let idx = self
.list
.binary_search_by_key(&key, |p| p.0.as_str())
.ok()?;
Some((idx, &self.list[idx].1))
}
pub fn get(&self, idx: usize) -> (&str, &T) {
(&self.list[idx].0, &self.list[idx].1)
}
pub fn fields(&self) -> impl Iterator<Item = &(String, T)> {
self.list.iter()
}
}
}
impl Display for Ty {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Ty::Bool => write!(f, "bool"),
Ty::Uint(w) => write!(f, "u{}", w),
Ty::Field => write!(f, "field"),
Ty::Struct(n, _) => write!(f, "{}", n),
Ty::Struct(n, fields) => {
let mut o = f.debug_struct(n);
for (f_name, f_ty) in fields.fields() {
o.field(f_name, f_ty);
}
o.finish()
}
Ty::Array(n, b) => write!(f, "{}[{}]", b, n),
}
}
@@ -49,88 +91,113 @@ impl fmt::Debug for Ty {
}
impl Ty {
fn default(&self) -> T {
fn sort(&self) -> Sort {
match self {
Self::Bool => T::Bool(leaf_term(Op::Const(Value::Bool(false)))),
Self::Uint(w) => T::Uint(*w, bv_lit(0, *w)),
Self::Field => T::Field(pf_lit(0)),
Self::Array(n, b) => T::Array((**b).clone(), vec![b.default(); *n]),
Self::Struct(n, fs) => T::Struct(
n.clone(),
fs.iter()
.map(|(f_name, f_ty)| (f_name.to_owned(), f_ty.default()))
.collect(),
),
Self::Bool => Sort::Bool,
Self::Uint(w) => Sort::BitVector(*w),
Self::Field => ZOK_FIELD_SORT.clone(),
Self::Array(n, b) => {
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
}
Self::Struct(_name, fs) => {
Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect())
}
}
}
fn default_ir_term(&self) -> Term {
self.sort().default_term()
}
fn default(&self) -> T {
T {
term: self.default_ir_term(),
ty: self.clone(),
}
}
/// Creates a new structure type, sorting the keys.
pub fn new_struct<I: IntoIterator<Item = (String, Ty)>>(name: String, fields: I) -> Self {
Self::Struct(name, FieldList::new(fields.into_iter().collect()))
}
}
#[derive(Clone)]
pub enum T {
Uint(usize, Term),
Bool(Term),
Field(Term),
/// TODO: special case primitive arrays with Vec<T>.
Array(Ty, Vec<T>),
Struct(String, BTreeMap<String, T>),
#[derive(Clone, Debug)]
pub struct T {
pub ty: Ty,
pub term: Term,
}
impl T {
pub fn type_(&self) -> Ty {
match self {
T::Uint(w, _) => Ty::Uint(*w),
T::Bool(_) => Ty::Bool,
T::Field(_) => Ty::Field,
T::Array(b, v) => Ty::Array(v.len(), Box::new(b.clone())),
T::Struct(name, map) => Ty::Struct(
name.clone(),
map.iter()
.map(|(f_name, f_term)| (f_name.clone(), f_term.type_()))
.collect(),
),
}
pub fn new(ty: Ty, term: Term) -> Self {
Self { ty, term }
}
pub fn type_(&self) -> &Ty {
&self.ty
}
/// Get all IR terms inside this value, as a list.
pub fn terms(&self) -> Vec<Term> {
let mut output: Vec<Term> = Vec::new();
fn terms_tail(term: &T, output: &mut Vec<Term>) {
match term {
T::Bool(b) => output.push(b.clone()),
T::Uint(_, b) => output.push(b.clone()),
T::Field(b) => output.push(b.clone()),
T::Array(_, v) => v.iter().for_each(|v| terms_tail(v, output)),
T::Struct(_, map) => map.iter().for_each(|(_, v)| terms_tail(v, output)),
fn terms_tail(term: &Term, output: &mut Vec<Term>) {
match check(term) {
Sort::Bool | Sort::BitVector(_) | Sort::Field(_) => output.push(term.clone()),
Sort::Array(_k, _v, size) => {
for i in 0..size {
terms_tail(&term![Op::Select; term.clone(), pf_lit_ir(i)], output)
}
}
Sort::Tuple(sorts) => {
for i in 0..sorts.len() {
terms_tail(&term![Op::Field(i); term.clone()], output)
}
}
s => unreachable!("Unreachable IR sort {} in ZoK", s),
}
}
terms_tail(self, &mut output);
terms_tail(&self.term, &mut output);
output
}
fn unwrap_array_ir(self) -> Result<Vec<Term>, String> {
match &self.ty {
Ty::Array(size, _sort) => Ok((0..*size)
.map(|i| term![Op::Select; self.term.clone(), pf_lit_ir(i)])
.collect()),
s => Err(format!("Not an array: {}", s)),
}
}
pub fn unwrap_array(self) -> Result<Vec<T>, String> {
match self {
T::Array(_, v) => Ok(v),
match &self.ty {
Ty::Array(_size, sort) => {
let sort = (**sort).clone();
Ok(self
.unwrap_array_ir()?
.into_iter()
.map(|t| T::new(sort.clone(), t))
.collect())
}
s => Err(format!("Not an array: {}", s)),
}
}
pub fn new_array(v: Vec<T>) -> Result<T, String> {
array(v)
}
pub fn new_struct(name: String, fields: Vec<(String, T)>) -> T {
let (field_tys, ir_terms): (Vec<_>, Vec<_>) = fields
.into_iter()
.map(|(name, t)| ((name.clone(), t.ty), (name, t.term)))
.unzip();
let field_ty_list = FieldList::new(field_tys);
let ir_term = term(Op::Tuple, {
let with_indices: BTreeMap<usize, Term> = ir_terms
.into_iter()
.map(|(name, t)| (field_ty_list.search(&name).unwrap().0, t))
.collect();
with_indices.into_iter().map(|(_i, t)| t).collect()
});
T::new(Ty::Struct(name, field_ty_list), ir_term)
}
}
impl Display for T {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
T::Bool(x) => write!(f, "Bool({})", x),
T::Uint(_, x) => write!(f, "Uint({})", x),
T::Field(x) => write!(f, "Field({})", x),
T::Struct(_, _) => write!(f, "struct"),
T::Array(_, _) => write!(f, "array"),
}
}
}
impl fmt::Debug for T {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self)
write!(f, "{}", self.term)
}
}
@@ -142,10 +209,16 @@ fn wrap_bin_op(
a: T,
b: T,
) -> Result<T, String> {
match (a, b, fu, ff, fb) {
(T::Uint(na, a), T::Uint(nb, b), Some(fu), _, _) if na == nb => Ok(T::Uint(na, fu(a, b))),
(T::Bool(a), T::Bool(b), _, _, Some(fb)) => Ok(T::Bool(fb(a, b))),
(T::Field(a), T::Field(b), _, Some(ff), _) => Ok(T::Field(ff(a, b))),
match (&a.ty, &b.ty, fu, ff, fb) {
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _) if na == nb => {
Ok(T::new(Ty::Uint(*na), fu(a.term.clone(), b.term.clone())))
}
(Ty::Bool, Ty::Bool, _, _, Some(fb)) => {
Ok(T::new(Ty::Bool, fb(a.term.clone(), b.term.clone())))
}
(Ty::Field, Ty::Field, _, Some(ff), _) => {
Ok(T::new(Ty::Field, ff(a.term.clone(), b.term.clone())))
}
(x, y, _, _, _) => Err(format!("Cannot perform op '{}' on {} and {}", name, x, y)),
}
}
@@ -158,10 +231,16 @@ fn wrap_bin_pred(
a: T,
b: T,
) -> Result<T, String> {
match (a, b, fu, ff, fb) {
(T::Uint(na, a), T::Uint(nb, b), Some(fu), _, _) if na == nb => Ok(T::Bool(fu(a, b))),
(T::Bool(a), T::Bool(b), _, _, Some(fb)) => Ok(T::Bool(fb(a, b))),
(T::Field(a), T::Field(b), _, Some(ff), _) => Ok(T::Bool(ff(a, b))),
match (&a.ty, &b.ty, fu, ff, fb) {
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _) if na == nb => {
Ok(T::new(Ty::Bool, fu(a.term.clone(), b.term.clone())))
}
(Ty::Bool, Ty::Bool, _, _, Some(fb)) => {
Ok(T::new(Ty::Bool, fb(a.term.clone(), b.term.clone())))
}
(Ty::Field, Ty::Field, _, Some(ff), _) => {
Ok(T::new(Ty::Bool, ff(a.term.clone(), b.term.clone())))
}
(x, y, _, _, _) => Err(format!("Cannot perform op '{}' on {} and {}", name, x, y)),
}
}
@@ -317,10 +396,10 @@ fn wrap_un_op(
fb: Option<fn(Term) -> Term>,
a: T,
) -> Result<T, String> {
match (a, fu, ff, fb) {
(T::Uint(na, a), Some(fu), _, _) => Ok(T::Uint(na, fu(a))),
(T::Bool(a), _, _, Some(fb)) => Ok(T::Bool(fb(a))),
(T::Field(a), _, Some(ff), _) => Ok(T::Field(ff(a))),
match (&a.ty, fu, ff, fb) {
(Ty::Uint(_), Some(fu), _, _) => Ok(T::new(a.ty.clone(), fu(a.term.clone()))),
(Ty::Bool, _, _, Some(fb)) => Ok(T::new(Ty::Bool, fb(a.term.clone()))),
(Ty::Field, _, Some(ff), _) => Ok(T::new(Ty::Field, ff(a.term.clone()))),
(x, _, _, _) => Err(format!("Cannot perform op '{}' on {}", name, x)),
}
}
@@ -352,31 +431,25 @@ pub fn not(a: T) -> Result<T, String> {
}
pub fn const_int(a: T) -> Result<Integer, String> {
let s = match &a {
T::Field(b) => match &b.op {
Op::Const(Value::Field(f)) => Some(f.i().clone()),
_ => None,
},
T::Uint(_, i) => match &i.op {
Op::Const(Value::BitVector(f)) => Some(f.uint().clone()),
_ => None,
},
match &a.term.op {
Op::Const(Value::Field(f)) => Some(f.i().clone()),
Op::Const(Value::BitVector(f)) => Some(f.uint().clone()),
_ => None,
};
s.ok_or_else(|| format!("{} is not a constant integer", a))
}
.ok_or_else(|| format!("{} is not a constant integer", a))
}
pub fn bool(a: T) -> Result<Term, String> {
match a {
T::Bool(b) => Ok(b),
match &a.ty {
Ty::Bool => Ok(a.term),
a => Err(format!("{} is not a boolean", a)),
}
}
fn wrap_shift(name: &str, op: BvBinOp, a: T, b: T) -> Result<T, String> {
let bc = const_int(b)?;
match a {
T::Uint(na, a) => Ok(T::Uint(na, term![Op::BvBinOp(op); a, bv_lit(bc, na)])),
match &a.ty {
&Ty::Uint(na) => Ok(T::new(a.ty, term![Op::BvBinOp(op); a.term, bv_lit(bc, na)])),
x => Err(format!("Cannot perform op '{}' on {} and {}", name, x, bc)),
}
}
@@ -390,30 +463,10 @@ pub fn shr(a: T, b: T) -> Result<T, String> {
}
fn ite(c: Term, a: T, b: T) -> Result<T, String> {
match (a, b) {
(T::Uint(na, a), T::Uint(nb, b)) if na == nb => Ok(T::Uint(na, term![Op::Ite; c, a, b])),
(T::Bool(a), T::Bool(b)) => Ok(T::Bool(term![Op::Ite; c, a, b])),
(T::Field(a), T::Field(b)) => Ok(T::Field(term![Op::Ite; c, a, b])),
(T::Array(ta, a), T::Array(tb, b)) if a.len() == b.len() && ta == tb => Ok(T::Array(
ta,
a.into_iter()
.zip(b.into_iter())
.map(|(a_i, b_i)| ite(c.clone(), a_i, b_i))
.collect::<Result<Vec<_>, _>>()?,
)),
(T::Struct(na, a), T::Struct(nb, b)) if na == nb => Ok(T::Struct(na.clone(), {
a.into_iter()
.zip(b.into_iter())
.map(|((af, av), (bf, bv))| {
if af == bf {
Ok((af, ite(c.clone(), av, bv)?))
} else {
Err(format!("Field mismatch: {} vs {}", af, bf))
}
})
.collect::<Result<BTreeMap<_, _>, String>>()?
})),
(x, y) => Err(format!("Cannot perform ITE on {} and {}", x, y)),
if &a.ty != &b.ty {
Err(format!("Cannot perform ITE on {} and {}", a, b))
} else {
Ok(T::new(a.ty.clone(), term![Op::Ite; c, a.term, b.term]))
}
}
@@ -421,7 +474,7 @@ pub fn cond(c: T, a: T, b: T) -> Result<T, String> {
ite(bool(c)?, a, b)
}
pub fn pf_lit<I>(i: I) -> Term
pub fn pf_lit_ir<I>(i: I) -> Term
where
Integer: From<I>,
{
@@ -431,76 +484,111 @@ where
))))
}
pub fn slice(array: T, start: Option<usize>, end: Option<usize>) -> Result<T, String> {
match array {
T::Array(b, mut list) => {
pub fn field_lit<I>(i: I) -> T
where
Integer: From<I>,
{
T::new(Ty::Field, pf_lit_ir(i))
}
pub fn z_bool_lit(v: bool) -> T {
T::new(Ty::Bool, leaf_term(Op::Const(Value::Bool(v))))
}
pub fn uint_lit<I>(v: I, bits: usize) -> T
where
Integer: From<I>,
{
T::new(Ty::Uint(bits), bv_lit(v, bits))
}
pub fn slice(arr: T, start: Option<usize>, end: Option<usize>) -> Result<T, String> {
match &arr.ty {
Ty::Array(size, _) => {
let start = start.unwrap_or(0);
let end = end.unwrap_or(list.len());
Ok(T::Array(b, list.drain(start..end).collect()))
let end = end.unwrap_or(*size);
array(arr.unwrap_array()?.drain(start..end))
}
a => Err(format!("Cannot slice {}", a)),
}
}
pub fn field_select(struct_: &T, field: &str) -> Result<T, String> {
match struct_ {
T::Struct(_, map) => map
.get(field)
.cloned()
.ok_or_else(|| format!("No field '{}'", field)),
match &struct_.ty {
Ty::Struct(_, map) => {
if let Some((idx, ty)) = map.search(field) {
Ok(T::new(
ty.clone(),
term![Op::Field(idx); struct_.term.clone()],
))
} else {
Err(format!("No field '{}'", field))
}
}
a => Err(format!("{} is not a struct", a)),
}
}
pub fn field_store(struct_: T, field: &str, val: T) -> Result<T, String> {
match struct_ {
T::Struct(name, mut map) => Ok(T::Struct(name, {
if map.insert(field.to_owned(), val).is_some() {
map
match &struct_.ty {
Ty::Struct(_, map) => {
if let Some((idx, ty)) = map.search(field) {
if ty == &val.ty {
Ok(T::new(
struct_.ty.clone(),
term![Op::Update(idx); struct_.term.clone(), val.term],
))
} else {
Err(format!(
"term {} assigned to field {} of type {}",
val,
field,
map.get(idx).1
))
}
} else {
return Err(format!("No '{}' field", field));
Err(format!("No field '{}'", field))
}
})),
}
a => Err(format!("{} is not a struct", a)),
}
}
pub fn array_select(array: T, idx: T) -> Result<T, String> {
match (array, idx) {
(T::Array(_, list), T::Field(idx)) => {
let mut it = list.into_iter().enumerate();
let first = it
.next()
.ok_or_else(|| format!("Cannot index empty array"))?;
it.fold(Ok(first.1), |acc, (i, elem)| {
ite(term![Op::Eq; pf_lit(i), idx.clone()], elem, acc?)
})
match (array.ty, idx.ty) {
(Ty::Array(_size, elem_ty), Ty::Field) => {
Ok(T::new(*elem_ty, term![Op::Select; array.term, idx.term]))
}
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
}
}
pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> {
match (array, idx) {
(T::Array(ty, list), T::Field(idx)) => Ok(T::Array(
ty,
list.into_iter()
.enumerate()
.map(|(i, elem)| ite(term![Op::Eq; pf_lit(i), idx.clone()], val.clone(), elem))
.collect::<Result<Vec<_>, _>>()?,
match (&array.ty, idx.ty) {
(Ty::Array(_, _), Ty::Field) => Ok(T::new(
array.ty,
term![Op::Store; array.term, idx.term, val.term],
)),
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
}
}
fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
fn ir_array<I: IntoIterator<Item = Term>>(sort: Sort, elems: I) -> Term {
make_array(ZOK_FIELD_SORT.clone(), sort, elems.into_iter().collect())
}
pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
let v: Vec<T> = elems.into_iter().collect();
if let Some(e) = v.first() {
let ty = e.type_();
if v.iter().skip(1).any(|a| a.type_() != ty) {
Err(format!("Inconsistent types in array"))
} else {
Ok(T::Array(ty, v))
let sort = check(&e.term);
Ok(T::new(
Ty::Array(v.len(), Box::new(ty.clone())),
ir_array(sort, v.into_iter().map(|t| t.term)),
))
}
} else {
Err(format!("Empty array"))
@@ -508,27 +596,29 @@ fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
}
pub fn uint_to_bits(u: T) -> Result<T, String> {
match u {
T::Uint(n, t) => Ok(T::Array(
Ty::Bool,
(0..n)
.map(|i| T::Bool(term![Op::BvBit(i); t.clone()]))
.collect(),
match &u.ty {
Ty::Uint(n) => Ok(T::new(
Ty::Array(*n, Box::new(Ty::Bool)),
ir_array(
Sort::Bool,
(0..*n).map(|i| term![Op::BvBit(i); u.term.clone()]),
),
)),
u => Err(format!("Cannot do uint-to-bits on {}", u)),
}
}
pub fn uint_from_bits(u: T) -> Result<T, String> {
match u {
T::Array(Ty::Bool, list) => match list.len() {
8 | 16 | 32 => Ok(T::Uint(
list.len(),
match &u.ty {
Ty::Array(bits, elem_ty) if &**elem_ty == &Ty::Bool => match bits {
8 | 16 | 32 => Ok(T::new(
Ty::Uint(*bits),
term(
Op::BvConcat,
list.into_iter()
.map(|z: T| -> Result<Term, String> { Ok(term![Op::BoolToBv; bool(z)?]) })
.collect::<Result<Vec<_>, _>>()?,
u.unwrap_array_ir()?
.into_iter()
.map(|z: Term| -> Term { term![Op::BoolToBv; z] })
.collect(),
),
)),
l => Err(format!("Cannot do uint-from-bits on len {} array", l,)),
@@ -538,17 +628,9 @@ pub fn uint_from_bits(u: T) -> Result<T, String> {
}
pub fn field_to_bits(f: T) -> Result<T, String> {
match f {
T::Field(t) => {
let u = term![Op::PfToBv(254); t];
Ok(T::Array(
Ty::Bool,
(0..254)
.map(|i| T::Bool(term![Op::BvBit(i); u.clone()]))
.collect(),
))
}
u => Err(format!("Cannot do field-to-bits on {}", u)),
match &f.ty {
Ty::Field => uint_to_bits(T::new(Ty::Uint(254), term![Op::PfToBv(254); f.term])),
u => Err(format!("Cannot do uint-to-bits on {}", u)),
}
}
@@ -598,20 +680,26 @@ impl Embeddable for ZoKrates {
.unwrap_or_else(|| Integer::from(0))
};
match ty {
Ty::Bool => T::Bool(ctx.cs.borrow_mut().new_var(
&raw_name,
Sort::Bool,
|| Value::Bool(get_int_val() != 0),
visibility,
)),
Ty::Field => T::Field(ctx.cs.borrow_mut().new_var(
&raw_name,
Sort::Field(self.modulus.clone()),
|| Value::Field(FieldElem::new(get_int_val(), self.modulus.clone())),
visibility,
)),
Ty::Uint(w) => T::Uint(
*w,
Ty::Bool => T::new(
Ty::Bool,
ctx.cs.borrow_mut().new_var(
&raw_name,
Sort::Bool,
|| Value::Bool(get_int_val() != 0),
visibility,
),
),
Ty::Field => T::new(
Ty::Field,
ctx.cs.borrow_mut().new_var(
&raw_name,
Sort::Field(self.modulus.clone()),
|| Value::Field(FieldElem::new(get_int_val(), self.modulus.clone())),
visibility,
),
),
Ty::Uint(w) => T::new(
Ty::Uint(*w),
ctx.cs.borrow_mut().new_var(
&raw_name,
Sort::BitVector(*w),
@@ -619,23 +707,19 @@ impl Embeddable for ZoKrates {
visibility,
),
),
Ty::Array(n, ty) => T::Array(
(**ty).clone(),
(0..*n)
.map(|i| {
self.declare(
ctx,
&*ty,
idx_name(&raw_name, i),
user_name.as_ref().map(|u| idx_name(u, i)),
visibility.clone(),
)
})
.collect(),
),
Ty::Struct(n, fs) => T::Struct(
Ty::Array(n, ty) => array((0..*n).map(|i| {
self.declare(
ctx,
&*ty,
idx_name(&raw_name, i),
user_name.as_ref().map(|u| idx_name(u, i)),
visibility.clone(),
)
}))
.unwrap(),
Ty::Struct(n, fs) => T::new_struct(
n.clone(),
fs.iter()
fs.fields()
.map(|(f_name, f_ty)| {
(
f_name.clone(),
@@ -652,33 +736,8 @@ impl Embeddable for ZoKrates {
),
}
}
fn ite(&self, ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
match (t, f) {
(T::Bool(a), T::Bool(b)) => T::Bool(term![Op::Ite; cond, a, b]),
(T::Uint(wa, a), T::Uint(wb, b)) if wa == wb => T::Uint(wa, term![Op::Ite; cond, a, b]),
(T::Field(a), T::Field(b)) => T::Field(term![Op::Ite; cond, a, b]),
(T::Array(a_ty, a), T::Array(b_ty, b)) if a_ty == b_ty => T::Array(
a_ty,
a.into_iter()
.zip(b.into_iter())
.map(|(a_i, b_i)| self.ite(ctx, cond.clone(), a_i, b_i))
.collect(),
),
(T::Struct(a_nm, a), T::Struct(b_nm, b)) if a_nm == b_nm => T::Struct(
a_nm,
a.into_iter()
.zip(b.into_iter())
.map(|((a_f, a_i), (b_f, b_i))| {
if a_f == b_f {
(a_f, self.ite(ctx, cond.clone(), a_i, b_i))
} else {
panic!("Field mismatch: '{}' vs '{}'", a_f, b_f)
}
})
.collect(),
),
(t, f) => panic!("Cannot ITE {} and {}", t, f),
}
fn ite(&self, _ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
ite(cond, t, f).unwrap()
}
fn assign(
&self,
@@ -688,47 +747,15 @@ impl Embeddable for ZoKrates {
t: Self::T,
visibility: Option<PartyId>,
) -> Self::T {
assert!(&t.type_() == ty);
match (ty, t) {
(_, T::Bool(b)) => T::Bool(ctx.cs.borrow_mut().assign(&name, b, visibility)),
(_, T::Field(b)) => T::Field(ctx.cs.borrow_mut().assign(&name, b, visibility)),
(_, T::Uint(w, b)) => T::Uint(w, ctx.cs.borrow_mut().assign(&name, b, visibility)),
(_, T::Array(ety, list)) => T::Array(
ety.clone(),
list.into_iter()
.enumerate()
.map(|(i, elem)| {
self.assign(ctx, &ety, idx_name(&name, i), elem, visibility.clone())
})
.collect(),
),
(Ty::Struct(_, tys), T::Struct(s_name, list)) => T::Struct(
s_name,
list.into_iter()
.zip(tys.into_iter())
.map(|((f_name, elem), (_, f_ty))| {
(
f_name.clone(),
self.assign(
ctx,
&f_ty,
field_name(&name, &f_name),
elem,
visibility.clone(),
),
)
})
.collect(),
),
_ => unimplemented!(),
}
assert!(t.type_() == ty);
T::new(t.ty, ctx.cs.borrow_mut().assign(&name, t.term, visibility))
}
fn values(&self) -> bool {
self.values.is_some()
}
fn type_of(&self, term: &Self::T) -> Self::Ty {
term.type_()
term.type_().clone()
}
fn initialize_return(&self, ty: &Self::Ty, _ssa_name: &String) -> Self::T {

View File

@@ -1,110 +1,95 @@
//! Linear Memory implementation.
//!
//! The idea is to replace each array with a term sequence and use ITEs to linearly scan the array
//! when needed. A SELECT produces an ITE reduce chain, a STORE produces an ITE map over the
//! sequence.
//!
//! E.g., for length-3 arrays.
//!
//! (select A k) => (ite (= k 2) A2 (ite (= k 1) A1 A0))
//! (store A k v) => (ite (= k 0) v A0), (ite (= k 1) v A1), (ite (= k 2))
use super::visit::MemVisitor;
//! The idea is to replace each array with a tuple, and use ITEs to account for variable indexing.
use super::super::visit::RewritePass;
use crate::ir::term::*;
use std::iter::repeat;
struct Linearizer;
struct ArrayLinearizer {
/// A map from (original) replaced terms, to what they were replaced with.
sequences: TermMap<Vec<Term>>,
/// The maximum size of arrays that will be replaced.
size_thresh: usize,
fn arr_val_to_tup(v: &Value) -> Value {
match v {
Value::Array(Array {
default, map, size, ..
}) => Value::Tuple({
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
for (i, v) in map {
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
}
vec
}),
v => v.clone(),
}
}
impl MemVisitor for ArrayLinearizer {
fn visit_const_array(&mut self, orig: &Term, _key_sort: &Sort, val: &Term, size: usize) {
if size <= self.size_thresh {
self.sequences
.insert(orig.clone(), repeat(val).cloned().take(size).collect());
}
fn arr_sort_to_tup(v: &Sort) -> Sort {
match v {
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
v => v.clone(),
}
fn visit_eq(&mut self, orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
// don't map b/c self borrow lifetime & NLL
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
let b_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent eq");
let eqs: Vec<Term> = a_seq
.iter()
.zip(b_seq.iter())
.map(|(a, b)| term![Op::Eq; a.clone(), b.clone()])
.collect();
Some(term(Op::BoolNaryOp(BoolNaryOp::And), eqs))
} else {
None
}
}
fn visit_ite(&mut self, orig: &Term, c: &Term, _t: &Term, _f: &Term) {
if let Some(a_seq) = self.sequences.get(&orig.cs[1]) {
let b_seq = self.sequences.get(&orig.cs[2]).expect("inconsistent ite");
let ites: Vec<Term> = a_seq
.iter()
.zip(b_seq.iter())
.map(|(a, b)| term![Op::Ite; c.clone(), a.clone(), b.clone()])
.collect();
self.sequences.insert(orig.clone(), ites);
}
}
fn visit_store(&mut self, orig: &Term, _a: &Term, k: &Term, v: &Term) {
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
let key_sort = check(k);
let ites: Vec<Term> = a_seq
.iter()
.zip(key_sort.elems_iter())
.map(|(a_i, key_i)| {
let eq_idx = term![Op::Eq; key_i, k.clone()];
term![Op::Ite; eq_idx, v.clone(), a_i.clone()]
})
.collect();
self.sequences.insert(orig.clone(), ites);
}
}
fn visit_select(&mut self, orig: &Term, _a: &Term, k: &Term) -> Option<Term> {
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
let key_sort = check(k);
let first = a_seq.first().expect("empty array in visit_select").clone();
Some(a_seq.iter().zip(key_sort.elems_iter()).skip(1).fold(
first,
|acc, (a_i, key_i)| {
let eq_idx = term![Op::Eq; key_i, k.clone()];
term![Op::Ite; eq_idx, a_i.clone(), acc]
},
))
} else {
None
}
}
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
if let Sort::Array(_k, v, size) = s {
if *size <= self.size_thresh {
self.sequences.insert(
orig.clone(),
(0..*size)
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
.collect(),
);
}
impl RewritePass for Linearizer {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term> {
match &orig.op {
Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))),
Op::Var(name, sort @ Sort::Array(_k, _v, _size)) => {
let new_value = computation
.values
.as_ref()
.map(|vs| arr_val_to_tup(vs.get(name).unwrap()));
let vis = computation.metadata.get_input_visibility(name);
let new_sort = arr_sort_to_tup(sort);
let new_var_info = vec![(name.clone(), new_sort.clone(), new_value, vis)];
computation.replace_input(orig.clone(), new_var_info);
Some(leaf_term(Op::Var(name.clone(), new_sort)))
}
} else {
unreachable!("should only visit array vars")
Op::Select => {
let cs = rewritten_children();
let idx = &cs[1];
let tup = &cs[0];
if let Sort::Array(key_sort, _, size) = check(&orig.cs[0]) {
assert!(size > 0);
let mut fields = (0..size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(key_sort.elems_iter().take(size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
} else {
unreachable!()
}
}
Op::Store => {
let cs = rewritten_children();
let tup = &cs[0];
let idx = &cs[1];
let val = &cs[2];
if let Sort::Array(key_sort, _, size) = check(&orig.cs[0]) {
assert!(size > 0);
let mut updates =
(0..size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
let first = updates.next().unwrap();
Some(key_sort.elems_iter().take(size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], update, acc]
}))
} else {
unreachable!()
}
}
// ITEs and EQs are correctly rewritten by default.
_ => None,
}
}
}
/// Eliminate arrays using linear scans. See module documentation.
pub fn linearize(t: &Term, size_thresh: usize) -> Term {
let mut pass = ArrayLinearizer {
size_thresh,
sequences: TermMap::new(),
};
pass.traverse(t)
pub fn linearize(c: &mut Computation) {
let mut pass = Linearizer;
pass.traverse(c);
}
#[cfg(test)]
@@ -138,7 +123,12 @@ mod test {
#[test]
fn select_ite_stores() {
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
let z = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let t = term![Op::Select;
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
@@ -147,14 +137,21 @@ mod test {
],
bv_lit(3, 4)
];
let tt = linearize(&t, 6);
assert!(array_free(&tt));
assert_eq!(6 + 6 + 6 + 5, count_ites(&tt));
let mut c = Computation::default();
c.outputs.push(t);
linearize(&mut c);
assert!(array_free(&c.outputs[0]));
assert_eq!(5 + 5 + 1 + 5, count_ites(&c.outputs[0]));
}
#[test]
fn select_ite_stores_field() {
let z = term![Op::ConstArray(Sort::Field(Arc::new(Integer::from(TEST_FIELD))), 6); bv_lit(0, 4)];
let z = term![Op::Const(Value::Array(Array::new(
Sort::Field(Arc::new(Integer::from(TEST_FIELD))),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let t = term![Op::Select;
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
@@ -163,8 +160,10 @@ mod test {
],
field_lit(3)
];
let tt = linearize(&t, 6);
assert!(array_free(&tt));
assert_eq!(6 + 6 + 6 + 5, count_ites(&tt));
let mut c = Computation::default();
c.outputs.push(t);
linearize(&mut c);
assert!(array_free(&c.outputs[0]));
assert_eq!(5 + 5 + 1 + 5, count_ites(&c.outputs[0]));
}
}

View File

@@ -1,12 +1,10 @@
//! Memory optimizations
/// Oblivious array elimination.
///
/// Replace arrays with tuples, using ITEs to handle variable indexing.
pub mod lin;
/// Oblivious array elimination.
///
/// Replace arrays that are accessed at constant indices with tuples.
pub mod obliv;
mod visit;
use crate::ir::term::*;
/// Eliminates arrays, first oblivious ones, and then all arrays.
pub fn array_elim(t: &Term) -> Term {
lin::linearize(&obliv::elim_obliv(t), usize::MAX)
}

View File

@@ -1,12 +1,12 @@
//! Oblivious Array Elimination
//!
//! This module attempts to identify *oblivious* arrays: those that are only accessed at constant
//! indices. These arrays can be replaced with normal terms.
//! indices. These arrays can be replaced with tuples. Then, a tuple elimination pass can be run.
//!
//! It operates in two passes:
//!
//! 1. determine which arrays are oblivious
//! 2. replace oblivious arrays with (haskell) lists of terms.
//! 2. replace oblivious arrays with tuples
//!
//!
//! ## Pass 1: Identifying oblivious arrays
@@ -26,104 +26,116 @@
//!
//! In this pass, the goal is to
//!
//! * map array terms to haskell lists of value terms
//! * map array selections to specific value terms
//!
//! The pass maintains:
//!
//! * a map from array terms to lists of values
//!
//! It then does a bottom-up formula traversal, performing the following
//! transformations:
//!
//! * oblivious array variables are mapped to a list of (derivative) variables
//! * oblivious constant arrays are mapped to a list that replicates the constant
//! * accesses to oblivious arrays are mapped to the appropriate term from the
//! value list of the array
//! * stores to oblivious arrays are mapped to updated value lists
//! * equalities between oblivious arrays are mapped to conjunctions of equalities
//! * map array terms to tuple terms
//! * map array selections to tuple field gets
use super::visit::*;
use crate::ir::term::*;
use super::super::visit::*;
use crate::ir::term::extras::as_uint_constant;
use crate::ir::term::*;
use log::debug;
use std::iter::repeat;
struct NonOblivComputer {
not_obliv: TermSet,
progress: bool,
}
impl NonOblivComputer {
fn mark(&mut self, a: &Term) {
if !self.not_obliv.contains(a) {
self.not_obliv.insert(a.clone());
self.progress = true;
fn mark(&mut self, a: &Term) -> bool {
if !a.is_const() && self.not_obliv.insert(a.clone()) {
debug!("Not obliv: {}", a);
true
} else {
false
}
}
fn bi_implicate(&mut self, a: &Term, b: &Term) {
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
(false, true) => {
self.not_obliv.insert(a.clone());
self.progress = true;
fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool {
if !a.is_const() && !b.is_const() {
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
(false, true) => {
self.not_obliv.insert(a.clone());
true
}
(true, false) => {
self.not_obliv.insert(b.clone());
true
}
_ => false,
}
(true, false) => {
self.not_obliv.insert(b.clone());
self.progress = true;
}
_ => {}
} else {
false
}
}
fn new() -> Self {
Self {
progress: true,
not_obliv: TermSet::new(),
}
}
}
impl MemVisitor for NonOblivComputer {
fn visit_eq(&mut self, _orig: &Term, a: &Term, b: &Term) -> Option<Term> {
self.bi_implicate(a, b);
None
}
fn visit_ite(&mut self, orig: &Term, _c: &Term, t: &Term, f: &Term) {
self.bi_implicate(orig, t);
self.bi_implicate(t, f);
self.bi_implicate(orig, f);
}
fn visit_store(&mut self, orig: &Term, a: &Term, k: &Term, _v: &Term) {
if let Op::Const(_) = k.op {
self.bi_implicate(orig, a);
} else {
self.mark(a);
self.mark(orig);
impl ProgressAnalysisPass for NonOblivComputer {
fn visit(&mut self, term: &Term) -> bool {
match &term.op {
Op::Store => {
let a = &term.cs[0];
let i = &term.cs[1];
let v = &term.cs[2];
let mut progress = false;
if let Sort::Array(..) = check(v) {
// Imprecisely, mark v as non-obliv iff the array is.
progress = self.bi_implicate(term, v) || progress;
}
if let Op::Const(_) = i.op {
progress = self.bi_implicate(term, a) || progress;
} else {
progress = self.mark(a) || progress;
progress = self.mark(term) || progress;
}
if let Sort::Array(..) = check(v) {
// Imprecisely, mark v as non-obliv iff the array is.
progress = self.bi_implicate(term, v) || progress;
}
progress
}
Op::Select => {
let a = &term.cs[0];
let i = &term.cs[1];
if let Op::Const(_) = i.op {
false
} else {
self.mark(a)
}
}
Op::Ite => {
let t = &term.cs[1];
let f = &term.cs[2];
if let Sort::Array(..) = check(t) {
let mut progress = self.bi_implicate(term, t);
progress = self.bi_implicate(t, f) || progress;
progress = self.bi_implicate(term, f) || progress;
progress
} else {
false
}
}
Op::Eq => {
let a = &term.cs[0];
let b = &term.cs[1];
if let Sort::Array(..) = check(a) {
self.bi_implicate(a, b)
} else {
false
}
}
Op::Tuple => {
panic!("Tuple in obliv")
}
_ => false,
}
}
fn visit_select(&mut self, _orig: &Term, a: &Term, k: &Term) -> Option<Term> {
if let Op::Const(_) = k.op {
} else {
self.mark(a);
}
None
}
}
impl ProgressVisitor for NonOblivComputer {
fn check_progress(&self) -> bool {
self.progress
}
fn reset_progress(&mut self) {
self.progress = false;
}
}
struct Replacer {
/// A map from (original) replaced terms, to what they were replaced with.
sequences: TermMap<Vec<Term>>,
/// The maximum size of arrays that will be replaced.
not_obliv: TermSet,
}
@@ -133,96 +145,120 @@ impl Replacer {
!self.not_obliv.contains(a)
}
}
fn arr_val_to_tup(v: &Value) -> Value {
match v {
Value::Array(Array {
default, map, size, ..
}) => Value::Tuple({
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
for (i, v) in map {
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
}
vec
}),
v => v.clone(),
}
}
impl MemVisitor for Replacer {
fn visit_const_array(&mut self, orig: &Term, _key_sort: &Sort, val: &Term, size: usize) {
if self.should_replace(orig) {
debug!("Will replace constant: {}", orig);
self.sequences
.insert(orig.clone(), repeat(val).cloned().take(size).collect());
}
fn term_arr_val_to_tup(a: Term) -> Term {
match &a.op {
Op::Const(v @ Value::Array(..)) => leaf_term(Op::Const(arr_val_to_tup(v))),
_ => a,
}
fn visit_eq(&mut self, orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
let b_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent eq");
let eqs: Vec<Term> = a_seq
.iter()
.zip(b_seq.iter())
.map(|(a, b)| term![Op::Eq; a.clone(), b.clone()])
.collect();
Some(term(Op::BoolNaryOp(BoolNaryOp::And), eqs))
} else {
None
}
}
fn arr_sort_to_tup(v: &Sort) -> Sort {
match v {
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
v => v.clone(),
}
fn visit_ite(&mut self, orig: &Term, c: &Term, _t: &Term, _f: &Term) {
if self.should_replace(orig) {
let a_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent ite");
let b_seq = self.sequences.get(&orig.cs[2]).expect("inconsistent ite");
let ites: Vec<Term> = a_seq
.iter()
.zip(b_seq.iter())
.map(|(a, b)| term![Op::Ite; c.clone(), a.clone(), b.clone()])
.collect();
self.sequences.insert(orig.clone(), ites);
}
}
fn visit_store(&mut self, orig: &Term, _a: &Term, k: &Term, v: &Term) {
if self.should_replace(orig) {
let mut a_seq = self
.sequences
.get(&orig.cs[0])
.expect("inconsistent store")
.clone();
let k_const = as_uint_constant(k)
.expect("not obliv!")
.to_usize()
.expect("oversize index");
a_seq[k_const] = v.clone();
self.sequences.insert(orig.clone(), a_seq);
}
}
fn visit_select(&mut self, orig: &Term, _a: &Term, k: &Term) -> Option<Term> {
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
debug!("Will replace select: {}", orig);
let k_const = as_uint_constant(k)
.expect("not obliv!")
.to_usize()
.expect("oversize index");
if k_const < a_seq.len() {
Some(a_seq[k_const].clone())
} else {
panic!("Oversize index: {}", k_const)
}
#[track_caller]
fn get_const(t: &Term) -> usize {
as_uint_constant(t)
.unwrap_or_else(|| panic!("non-const {}", t))
.to_usize()
.expect("oversize")
}
impl RewritePass for Replacer {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term> {
//debug!("Visit {}", extras::Letified(orig.clone()));
let get_cs = || -> Vec<Term> {
rewritten_children()
.into_iter()
.map(term_arr_val_to_tup)
.collect()
};
match &orig.op {
Op::Var(name, sort @ Sort::Array(_k, _v, _size)) => {
if self.should_replace(orig) {
let new_value = computation
.values
.as_ref()
.map(|vs| arr_val_to_tup(vs.get(name).unwrap()));
let vis = computation.metadata.get_input_visibility(name);
let new_sort = arr_sort_to_tup(sort);
let new_var_info = vec![(name.clone(), new_sort.clone(), new_value, vis)];
computation.replace_input(orig.clone(), new_var_info);
Some(leaf_term(Op::Var(name.clone(), new_sort)))
} else {
None
}
}
} else {
None
}
}
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
if let Sort::Array(_k, v, size) = s {
if self.should_replace(orig) {
self.sequences.insert(
orig.clone(),
(0..*size)
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
.collect(),
);
Op::Select => {
if self.should_replace(&orig.cs[0]) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 2);
let k_const = get_const(&cs.pop().unwrap());
Some(term(Op::Field(k_const), cs))
} else {
None
}
}
} else {
unreachable!("should only visit array vars")
Op::Store => {
if self.should_replace(&orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 3);
let k_const = get_const(&cs.remove(1));
Some(term(Op::Update(k_const), cs))
} else {
None
}
}
Op::Ite => {
if self.should_replace(&orig) {
Some(term(Op::Ite, get_cs()))
} else {
None
}
}
Op::Eq => {
if self.should_replace(&orig.cs[0]) {
Some(term(Op::Eq, get_cs()))
} else {
None
}
}
_ => None,
}
}
}
/// Eliminate oblivious arrays. See module documentation.
pub fn elim_obliv(t: &Term) -> Term {
pub fn elim_obliv(t: &mut Computation) {
let mut prop_pass = NonOblivComputer::new();
prop_pass.traverse_to_fixpoint(t);
prop_pass.traverse(t);
let mut replace_pass = Replacer {
not_obliv: prop_pass.not_obliv,
sequences: TermMap::new(),
};
replace_pass.traverse(t)
<Replacer as RewritePass>::traverse(&mut replace_pass, t)
}
#[cfg(test)]
@@ -244,7 +280,12 @@ mod test {
#[test]
fn obliv() {
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
let z = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let t = term![Op::Select;
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
@@ -253,13 +294,20 @@ mod test {
],
bv_lit(3, 4)
];
let tt = elim_obliv(&t);
assert!(array_free(&tt));
let mut c = Computation::default();
c.outputs.push(t);
elim_obliv(&mut c);
assert!(array_free(&c.outputs[0]));
}
#[test]
fn not_obliv() {
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
let z = term![Op::Const(Value::Array(Array::new(
Sort::BitVector(4),
Box::new(Sort::BitVector(4).default_value()),
Default::default(),
6
)))];
let t = term![Op::Select;
term![Op::Ite;
leaf_term(Op::Const(Value::Bool(true))),
@@ -268,7 +316,9 @@ mod test {
],
bv_lit(3, 4)
];
let tt = elim_obliv(&t);
assert!(!array_free(&tt));
let mut c = Computation::default();
c.outputs.push(t);
elim_obliv(&mut c);
assert!(!array_free(&c.outputs[0]));
}
}

View File

@@ -1,107 +0,0 @@
use crate::ir::term::*;
use log::debug;
/// A visitor for traversing terms, and visiting the array-related parts.
///
/// Visits:
/// * EQs over arrays
/// * Constant arrays
/// * ITEs over arrays
/// * array variables
/// * STOREs
/// * SELECTs
///
/// For the EQs and SELECTs, you have the ability to (optionally) return a replacement for the
/// term. This can be used to "cut" the array out of the term, since EQs and SELECTs are how
/// information leaves an array.
///
/// All visitors receive the original term and the rewritten children.
pub trait MemVisitor {
/// Visit a const array
fn visit_const_array(&mut self, _orig: &Term, _key_sort: &Sort, _val: &Term, _size: usize) {}
/// Visit an equality, whose children are `a` and `b`.
fn visit_eq(&mut self, _orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
None
}
/// Visit an array-valued ITE
fn visit_ite(&mut self, _orig: &Term, _c: &Term, _t: &Term, _f: &Term) {}
/// Visit a STORE
fn visit_store(&mut self, _orig: &Term, _a: &Term, _k: &Term, _v: &Term) {}
/// Visit a SELECT
fn visit_select(&mut self, _orig: &Term, _a: &Term, _k: &Term) -> Option<Term> {
None
}
fn visit_var(&mut self, _orig: &Term, _name: &String, _s: &Sort) {}
/// Traverse a node, visiting memory-related terms.
///
/// Can be used to remove memory-related terms by replacing the EQs and SELECTs which extract
/// other terms from them.
///
/// Returns the transformed term.
fn traverse(&mut self, node: &Term) -> Term {
let mut cache = TermMap::<Term>::new();
for t in PostOrderIter::new(node.clone()) {
let c_get = |x: &Term| cache.get(x).unwrap();
let get = |i: usize| c_get(&t.cs[i]);
let new_t_opt = {
let s = check(&t);
match &s {
Sort::Array(_, _, _) => {
match &t.op {
Op::Var(name, s) => {
self.visit_var(&t, name, s);
}
Op::Ite => {
self.visit_ite(&t, get(0), get(1), get(2));
}
Op::Store => {
self.visit_store(&t, get(0), get(1), get(2));
}
Op::ConstArray(s, n) => {
self.visit_const_array(&t, s, get(0), *n);
}
_ => {}
};
None
}
_ => match &t.op {
Op::Eq => {
let a = get(0);
if let Sort::Array(_, _, _) = check(&a) {
self.visit_eq(&t, a, get(1))
} else {
None
}
}
Op::Select => self.visit_select(&t, get(0), get(1)),
_ => None,
},
}
};
let new_t = new_t_opt.unwrap_or_else(|| {
term(
t.op.clone(),
t.cs.iter().map(|c| c_get(c).clone()).collect(),
)
});
debug!("rebuild: {}", new_t);
cache.insert(t.clone(), new_t);
}
cache.remove(node).unwrap()
}
}
pub trait ProgressVisitor: MemVisitor {
fn reset_progress(&mut self);
fn check_progress(&self) -> bool;
fn traverse_to_fixpoint(&mut self, a: &Term) {
self.traverse(a);
while self.check_progress() {
self.reset_progress();
self.traverse(a);
}
}
}

View File

@@ -3,8 +3,10 @@ pub mod cfold;
pub mod flat;
pub mod inline;
pub mod mem;
pub mod scalarize_vars;
pub mod sha;
pub mod tuple;
mod visit;
use super::term::*;
use log::debug;
@@ -12,14 +14,19 @@ use log::debug;
#[derive(Debug)]
/// An optimization pass
pub enum Opt {
/// Convert non-scalar (tuple, array) inputs to scalar ones
/// The scalar variable names are suffixed with .N, where N indicates the array/tuple position
ScalarizeVars,
/// Fold constants
ConstantFold,
/// Flatten n-ary operators
Flatten,
/// SHA-2 peephole optimizations
Sha,
/// Memory elimination
Mem,
/// Replace oblivious arrays with tuples
Obliv,
/// Replace arrays with linear scans
LinearScan,
/// Extract top-level ANDs as distinct outputs
FlattenAssertions,
/// Find outputs like `(= variable term)`, and substitute out `variable`
@@ -33,6 +40,9 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
for i in optimizations {
debug!("Applying: {:?}", i);
match i {
Opt::ScalarizeVars => {
scalarize_vars::scalarize_inputs(&mut cs);
}
Opt::ConstantFold => {
let mut cache = TermMap::new();
for a in &mut cs.outputs {
@@ -44,10 +54,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
*a = sha::sha_rewrites(a);
}
}
Opt::Mem => {
for a in &mut cs.outputs {
*a = mem::array_elim(a);
}
Opt::Obliv => {
mem::obliv::elim_obliv(&mut cs);
}
Opt::LinearScan => {
mem::lin::linearize(&mut cs);
}
Opt::FlattenAssertions => {
let mut new_outputs = Vec::new();
@@ -68,14 +79,19 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
}
}
Opt::Inline => {
let public_inputs = cs.metadata.public_inputs().map(ToOwned::to_owned).collect();
let public_inputs = cs
.metadata
.public_input_names()
.map(ToOwned::to_owned)
.collect();
inline::inline(&mut cs.outputs, &public_inputs);
}
Opt::Tuple => {
cs = tuple::eliminate_tuples(cs);
tuple::eliminate_tuples(&mut cs);
}
}
debug!("After {:?}: {} outputs", i, cs.outputs.len());
//debug!("After {:?}: {}", i, Letified(cs.outputs[0].clone()));
debug!("After {:?}: {} terms", i, cs.terms());
}
garbage_collect();

View File

@@ -0,0 +1,132 @@
//! Replacing array and tuple variables with scalars.
use crate::ir::opt::visit::RewritePass;
use crate::ir::term::*;
struct Pass;
fn create_vars(
prefix: &str,
sort: &Sort,
value: Option<Value>,
party: Option<PartyId>,
new_var_requests: &mut Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
) -> Term {
match sort {
Sort::Tuple(sorts) => {
let mut values = value.map(|v| match v {
Value::Tuple(t) => t,
_ => panic!(),
});
term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
create_vars(
&format!("{}.{}", prefix, i),
sort,
values
.as_mut()
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
party,
new_var_requests,
)
})
.collect(),
)
}
Sort::Array(key_s, val_s, size) => {
let mut values = value.map(|v| match v {
Value::Array(Array {
default, map, size, ..
}) => {
let mut vals = vec![*default; size];
for (key_val, val_val) in map.into_iter() {
let idx = key_val.as_usize().unwrap();
vals[idx] = val_val;
}
vals
}
_ => panic!(),
});
make_array(
(**key_s).clone(),
(**val_s).clone(),
(0..*size)
.map(|i| {
create_vars(
&format!("{}.{}", prefix, i),
val_s,
values
.as_mut()
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
party,
new_var_requests,
)
})
.collect(),
)
}
_ => {
new_var_requests.push((prefix.into(), sort.clone(), value, party));
leaf_term(Op::Var(prefix.into(), sort.clone()))
}
}
}
impl RewritePass for Pass {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
orig: &Term,
_rewritten_children: F,
) -> Option<Term> {
if let Op::Var(name, sort) = &orig.op {
let party_visibility = computation.metadata.get_input_visibility(name);
let mut new_var_reqs = Vec::new();
let new = create_vars(
name,
sort,
computation
.values
.as_ref()
.map(|v| v.get(name).unwrap().clone()),
party_visibility,
&mut new_var_reqs,
);
if new_var_reqs.len() > 0 {
computation.replace_input(orig.clone(), new_var_reqs);
}
Some(new)
} else {
None
}
}
}
/// Run the tuple elimination pass.
pub fn scalarize_inputs(cs: &mut Computation) {
let mut pass = Pass;
pass.traverse(cs);
#[cfg(debug_assertions)]
assert_all_vars_are_scalars(cs);
}
/// Check that every variables is a scalar.
pub fn assert_all_vars_are_scalars(cs: &Computation) {
for t in cs
.terms_postorder()
.into_iter()
.chain(cs.metadata.inputs.iter().cloned())
{
if let Op::Var(_name, sort) = &t.op {
match sort {
Sort::Array(..) | Sort::Tuple(..) => {
panic!("Variable {} is non-scalar", t);
}
_ => {}
}
}
}
}

View File

@@ -170,9 +170,9 @@ mod test {
#[test]
fn undo() {
let a = bv_lit(0, 1);
let b = bv_lit(0, 1);
let c = bv_lit(0, 1);
let a = bool_lit(false);
let b = bool_lit(false);
let c = bool_lit(false);
let t = term![Op::BoolMaj; a.clone(), b.clone(),c.clone()];
let tt = term![OR; term![AND; a.clone(), b.clone()], term![AND; b.clone(), c.clone()], term![AND; c.clone(), a.clone()]];
assert_eq!(tt, sha_maj_elim(&t));

View File

@@ -2,225 +2,268 @@
//!
//! Elimates tuple-related terms.
//!
//! The idea is to do a bottom-up pass mapping all terms to tuple-free trees of terms.
//! The idea is to do a bottom-up pass, in which all tuple's are lift to the top level, and then
//! then removed.
//!
//! * Tuple variables are suffixed. `x: (bool, bool)` becomes `x.0: bool, x.1: bool`.
//! * Tuple constants are replaced with trees
//! * Tuple constructors make trees
//! * Tuple accesses open up trees
//! * Tuple ITEs yield trees of ITEs
//! * Tuple EQs yield conjunctions of EQs
use std::rc::Rc;
//! ## Phase 1
//!
//! Phase 1 (lifting tuples) is defined by the following big-step rewrites.
//!
//! Notational conventions:
//! * lowercase letters are used to match sorts/terms before rewriting.
//! * uppercase letters denote their (big-step) rewritten counterparts.
//! * () denote AST structure
//! * [] denote repeated structures, i.e. like var-args.
//! * `f(x, *)` denotes a partial application of `f`, i.e. the function that sends `y` to `f(x,y)`.
//! * In the results of rewriting we often have terms which are tuples at the top-level. I.e. their
//! sort contains no tuple sort that is not the child of a tuple sort. Similarly, the terms are
//! tuples at the top-level: no field or update operators are present, and the only tuple
//! operators are the children of other tuple operators.
//! * Such sorts/terms can be viewed as structured collections of non-tuple elements, i.e., as
//! functors whose pure elements are non-tuples.
//! * The `map`, `bimap`, and `list` functions apply to those functors.
//! * i.e., `map f (tuple x tuple y z))` is `(tuple (f x) (tuple (f y) (f z)))`.
//! * i.e., `list (tuple x tuple y z))` is `(x y z)`
//!
//! Assumptions:
//! * We assume that array keys are of scalar sort
//! * We assume that variables are of scalar sort. See [super::scalarize_vars].
//! * We *do not* describe the pass as applied to constant values. That part of the pass is
//! entirely analagous to terms.
//!
//! Sort rewrites:
//! * `(tuple [t_i]_i) -> (tuple [T_i]_i)`
//! * `(array k t) -> map (array k *) T
//!
//! Term rewrites:
//! * `(ite c t f) -> bimap (ite C * *) T F`
//! * `(eq c t f) -> (and (list (bimap (= * *) T F)))`
//! * `(tuple [t_i]_i) -> (tuple [T_i]_i)`
//! * `(field_j t) -> get T j`
//! * `(update_j t) -> update T j V`
//! * `(select a i) -> map (select * I) A`
//! * `(store a i v) -> bimap (store * I *) A V`
//! * `(OTHER [t_i]_i) -> (OTHER [T_i]_i)`
//! * constants: *omitted*
//!
//! The result of this phase is a computation whose only tuple-terms are at the top of the
//! computation graph
//!
//! ## Phase 2
//!
//! We replace each output `t` with the sequence of outputs `(list T)`.
use crate::ir::opt::visit::RewritePass;
use crate::ir::term::{
check, leaf_term, term, BoolNaryOp, Computation, Op, PartyId, PostOrderIter, Sort, Term,
TermMap, Value,
check, extras, leaf_term, term, Array, Computation, Op, PostOrderIter, Sort, Term, Value, AND,
};
use std::collections::BTreeMap;
type Tree = Rc<TreeData>;
use itertools::zip_eq;
#[derive(Clone, Debug)]
enum TreeData {
Leaf(Term),
Tuple(Vec<Tree>),
}
#[derive(Clone, PartialEq, Eq, Debug)]
struct TupleTree(Term);
impl TreeData {
fn unwrap_leaf(&self) -> &Term {
match self {
TreeData::Leaf(l) => l,
d => panic!("expected leaf, got {:?}", d),
}
}
fn unwrap_tuple(&self) -> &Vec<Tree> {
match self {
TreeData::Tuple(l) => l,
d => panic!("expected tuple, got {:?}", d),
}
}
fn unfold_tuple_into(&self, terms: &mut Vec<Term>) {
match self {
TreeData::Leaf(l) => terms.push(l.clone()),
TreeData::Tuple(l) => l.iter().for_each(|x| x.unfold_tuple_into(terms)),
}
}
fn unfold_tuple(&self) -> Vec<Term> {
let mut terms = Vec::new();
self.unfold_tuple_into(&mut terms);
terms
}
fn from_value(v: Value) -> TreeData {
match v {
Value::Tuple(vs) => {
TreeData::Tuple(vs.into_iter().map(Self::from_value).map(Rc::new).collect())
impl TupleTree {
fn flatten(&self) -> impl Iterator<Item = Term> {
let mut out = Vec::new();
fn rec_unroll_into(t: &Term, out: &mut Vec<Term>) {
if &t.op == &Op::Tuple {
for c in &t.cs {
rec_unroll_into(c, out);
}
} else {
out.push(t.clone());
}
v => TreeData::Leaf(leaf_term(Op::Const(v))),
}
rec_unroll_into(&self.0, &mut out);
out.into_iter()
}
fn structure(&self, flattened: impl IntoIterator<Item = Term>) -> Self {
fn term_structure(t: &Term, iter: &mut impl Iterator<Item = Term>) -> Term {
if &t.op == &Op::Tuple {
term(
Op::Tuple,
t.cs.iter().map(|c| term_structure(c, iter)).collect(),
)
} else {
iter.next().expect("bad structure")
}
}
Self(term_structure(&self.0, &mut flattened.into_iter()))
}
fn well_formed(&self) -> bool {
for t in PostOrderIter::new(self.0.clone()) {
if &t.op != &Op::Tuple {
for c in &t.cs {
if &c.op == &Op::Tuple {
return false;
}
}
}
}
true
}
#[allow(dead_code)]
fn assert_well_formed(&self) {
assert!(
self.well_formed(),
"The following is not a well-formed tuple tree {}",
extras::Letified(self.0.clone())
);
}
fn map(&self, f: impl FnMut(Term) -> Term) -> Self {
self.structure(self.flatten().map(f))
}
fn bimap(&self, mut f: impl FnMut(Term, Term) -> Term, other: &Self) -> Self {
self.structure(itertools::zip_eq(self.flatten(), other.flatten()).map(|(a, b)| f(a, b)))
}
fn get(&self, i: usize) -> Self {
assert_eq!(&self.0.op, &Op::Tuple);
assert!(i < self.0.cs.len());
Self(self.0.cs[i].clone())
}
fn update(&self, i: usize, v: &Term) -> Self {
assert_eq!(&self.0.op, &Op::Tuple);
assert!(i < self.0.cs.len());
let mut cs = self.0.cs.clone();
cs[i] = v.clone();
Self(term(Op::Tuple, cs))
}
}
fn restructure(sort: &Sort, items: &mut impl Iterator<Item = Term>) -> Tree {
if let Sort::Tuple(sorts) = sort {
Rc::new(TreeData::Tuple(
sorts.iter().map(|c| restructure(c, items)).collect(),
))
#[derive(Clone, PartialEq, Eq, Debug)]
struct ValueTupleTree(Value);
impl ValueTupleTree {
fn flatten(&self) -> Vec<Value> {
let mut out = Vec::new();
fn rec_unroll_into(t: &Value, out: &mut Vec<Value>) {
match t {
Value::Tuple(vs) => {
for c in vs {
rec_unroll_into(c, out);
}
}
_ => out.push(t.clone()),
}
}
rec_unroll_into(&self.0, &mut out);
out
}
fn structure(&self, flattened: impl IntoIterator<Item = Value>) -> Self {
fn term_structure(t: &Value, iter: &mut impl Iterator<Item = Value>) -> Value {
match t {
Value::Tuple(vs) => {
Value::Tuple(vs.iter().map(|c| term_structure(c, iter)).collect())
}
_ => iter.next().expect("bad structure"),
}
}
Self(term_structure(&self.0, &mut flattened.into_iter()))
}
}
fn termify_val_tuples(v: Value) -> Term {
if let Value::Tuple(vs) = v {
term(Op::Tuple, vs.into_iter().map(termify_val_tuples).collect())
} else {
Rc::new(TreeData::Leaf(
items
.next()
.unwrap_or_else(|| panic!("No term when building {}", sort)),
))
leaf_term(Op::Const(v))
}
}
struct Pass {
map: TermMap<Tree>,
cs: Computation,
}
impl Pass {
fn new(cs: Computation) -> Self {
Self {
map: TermMap::new(),
cs,
fn untuple_value(v: &Value) -> Value {
match v {
Value::Tuple(xs) => Value::Tuple(xs.iter().map(untuple_value).collect()),
Value::Array(Array {
key_sort,
default,
map,
size,
}) => {
let def = untuple_value(default);
let flat_def = ValueTupleTree(def.clone()).flatten();
let mut map: BTreeMap<_, _> = map
.iter()
.map(|(k, v)| (k, ValueTupleTree(untuple_value(v)).flatten()))
.collect();
let mut flat_out: Vec<Value> = flat_def
.into_iter()
.rev()
.map(|d| {
let mut submap: BTreeMap<Value, Value> = BTreeMap::new();
for (k, v) in &mut map {
submap.insert((**k).clone(), v.pop().unwrap());
}
Value::Array(Array::new(key_sort.clone(), Box::new(d), submap, *size))
})
.collect();
flat_out.reverse();
ValueTupleTree(def).structure(flat_out).0
}
_ => v.clone(),
}
}
#[track_caller]
fn get_tree(&self, t: &Term) -> &Tree {
self.map
.get(t)
.unwrap_or_else(|| panic!("missing tree for term: {} in {:?}", t, self.map))
}
struct TupleLifter;
fn create_vars(
impl RewritePass for TupleLifter {
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
prefix: &str,
sort: &Sort,
value: Option<Value>,
party: Option<PartyId>,
) -> Tree {
match sort {
Sort::Tuple(sorts) => {
let mut values = value.map(|v| match v {
Value::Tuple(t) => t,
_ => panic!(),
});
Rc::new(TreeData::Tuple(
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
self.create_vars(
&format!("{}.{}", prefix, i),
sort,
values
.as_mut()
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
party,
)
})
.collect(),
))
_computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term> {
match &orig.op {
Op::Const(v) => Some(termify_val_tuples(untuple_value(v))),
Op::Ite => {
let mut cs = rewritten_children();
let f = TupleTree(cs.pop().unwrap());
let t = TupleTree(cs.pop().unwrap());
let c = cs.pop().unwrap();
debug_assert!(cs.is_empty());
Some(t.bimap(|a, b| term![Op::Ite; c.clone(), a, b], &f).0)
}
_ => Rc::new(TreeData::Leaf(
if self.cs.metadata.is_input(prefix)
&& self.cs.metadata.is_input_public(prefix)
&& self
.cs
.values
.as_ref()
.map(|v| v.contains_key(prefix))
.unwrap_or(false)
{
leaf_term(Op::Var(prefix.into(), sort.clone()))
} else {
self.cs
.new_var(prefix, sort.clone(), || value.unwrap().clone(), party)
},
)),
}
}
fn embed_step(&mut self, t: &Term) {
let s = check(t);
let tree = if let Sort::Tuple(_) = s {
match &t.op {
Op::Const(v) => Rc::new(TreeData::from_value(v.clone())),
Op::Var(name, sort) => {
let party_visibility = self.cs.metadata.get_input_visibility(name);
self.create_vars(
name,
sort,
self.cs
.values
.as_ref()
.map(|v| v.get(name).unwrap().clone()),
party_visibility,
)
}
Op::Tuple => Rc::new(TreeData::Tuple(
t.cs.iter().map(|c| self.get_tree(c).clone()).collect(),
)),
Op::Ite => {
let cond = self.get_tree(&t.cs[0]).unwrap_leaf();
let trues = self.get_tree(&t.cs[1]).unfold_tuple();
let falses = self.get_tree(&t.cs[2]).unfold_tuple();
restructure(
&s,
&mut trues
.into_iter()
.zip(falses.into_iter())
.map(|(a, b)| term![Op::Ite; cond.clone(), a, b]),
)
}
Op::Field(i) => self.get_tree(&t.cs[0]).unwrap_tuple()[*i].clone(),
o => panic!("Bad tuple operator: {}", o),
Op::Eq => {
let mut cs = rewritten_children();
let b = TupleTree(cs.pop().unwrap());
let a = TupleTree(cs.pop().unwrap());
debug_assert!(cs.is_empty());
let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]);
Some(term(AND, eqs.collect()))
}
} else {
match &t.op {
Op::Field(i) => self.get_tree(&t.cs[0]).unwrap_tuple()[*i].clone(),
Op::Eq => {
let c_sort = check(&t.cs[0]);
Rc::new(TreeData::Leaf(if let Sort::Tuple(_) = c_sort {
let xs = self.get_tree(&t.cs[0]).unfold_tuple();
let ys = self.get_tree(&t.cs[1]).unfold_tuple();
term(
Op::BoolNaryOp(BoolNaryOp::And),
xs.into_iter()
.zip(ys.into_iter())
.map(|(x, y)| term![Op::Eq; x, y])
.collect(),
)
} else {
term(
t.op.clone(),
t.cs.iter()
.map(|c| self.get_tree(c).unwrap_leaf().clone())
.collect(),
)
}))
}
_ => Rc::new(TreeData::Leaf(term(
t.op.clone(),
t.cs.iter()
.map(|c| self.get_tree(c).unwrap_leaf().clone())
.collect(),
))),
Op::Store => {
let mut cs = rewritten_children();
let v = TupleTree(cs.pop().unwrap());
let i = cs.pop().unwrap();
let a = TupleTree(cs.pop().unwrap());
debug_assert!(cs.is_empty());
Some(a.bimap(|a, v| term![Op::Store; a, i.clone(), v], &v).0)
}
};
if let TreeData::Leaf(t) = &*tree {
debug_assert!(tuple_free(t.clone()), "Tuple in {:?}", tree);
Op::Select => {
let mut cs = rewritten_children();
let i = cs.pop().unwrap();
let a = TupleTree(cs.pop().unwrap());
debug_assert!(cs.is_empty());
Some(a.map(|a| term![Op::Select; a, i.clone()]).0)
}
Op::Field(i) => {
let mut cs = rewritten_children();
let t = TupleTree(cs.pop().unwrap());
debug_assert!(cs.is_empty());
Some(t.get(*i).0)
}
Op::Update(i) => {
let mut cs = rewritten_children();
let v = cs.pop().unwrap();
let t = TupleTree(cs.pop().unwrap());
debug_assert!(cs.is_empty());
Some(t.update(*i, &v).0)
}
// The default rewrite is correct here.
Op::Tuple => None,
_ => None,
}
self.map.insert(t.clone(), tree);
}
fn embed(&mut self, t: &Term) -> Tree {
for c in PostOrderIter::new(t.clone()) {
self.embed_step(&c);
}
self.get_tree(t).clone()
}
}
@@ -232,16 +275,15 @@ fn tuple_free(t: Term) -> bool {
}
/// Run the tuple elimination pass.
pub fn eliminate_tuples(mut cs: Computation) -> Computation {
let outputs = std::mem::take(&mut cs.outputs);
let mut pass = Pass::new(cs);
for output in &outputs {
pass.embed(&output);
}
let new_ouputs: Vec<Term> = outputs
.iter()
.map(|c| pass.get_tree(c).unwrap_leaf().clone())
pub fn eliminate_tuples(cs: &mut Computation) {
let mut pass = TupleLifter;
pass.traverse(cs);
cs.outputs = std::mem::take(&mut cs.outputs)
.into_iter()
.flat_map(|o| TupleTree(o).flatten())
.collect();
pass.cs.outputs = new_ouputs;
pass.cs
#[cfg(debug_assertions)]
for o in &cs.outputs {
assert!(tuple_free(o.clone()));
}
}

76
src/ir/opt/visit.rs Normal file
View File

@@ -0,0 +1,76 @@
use crate::ir::term::*;
/// A rewriting pass.
pub trait RewritePass {
/// Visit (and possibly rewrite) a term.
/// Given the original term and a function to get its rewritten childen.
/// Returns a term if a rewrite happens.
fn visit<F: Fn() -> Vec<Term>>(
&mut self,
computation: &mut Computation,
orig: &Term,
rewritten_children: F,
) -> Option<Term>;
fn traverse(&mut self, computation: &mut Computation) {
let mut cache = TermMap::<Term>::new();
let mut children_added = TermSet::new();
let mut stack = Vec::new();
stack.extend(computation.outputs.iter().cloned());
while let Some(top) = stack.pop() {
if !cache.contains_key(&top) {
// was it missing?
if children_added.insert(top.clone()) {
stack.push(top.clone());
stack.extend(top.cs.iter().filter(|c| !cache.contains_key(c)).cloned());
} else {
let get_children = || -> Vec<Term> {
top.cs
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect()
};
let new_t_opt = self.visit(computation, &top, get_children);
let new_t = new_t_opt.unwrap_or_else(|| term(top.op.clone(), get_children()));
cache.insert(top.clone(), new_t);
}
}
}
computation.outputs = computation
.outputs
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
}
}
/// An analysis pass that repeated sweeps all terms, visiting them, untill a pass makes no more
/// progress.
pub trait ProgressAnalysisPass {
/// The visit function. Returns whether progress was made.
fn visit(&mut self, term: &Term) -> bool;
/// Repeatedly sweep till progress is no longer made.
fn traverse(&mut self, computation: &Computation) {
let mut progress = true;
let mut order = Vec::new();
let mut visited = TermSet::new();
let mut stack = Vec::new();
stack.extend(computation.outputs.iter().cloned());
while let Some(top) = stack.pop() {
stack.extend(top.cs.iter().filter(|c| !visited.contains(c)).cloned());
// was it missing?
if visited.insert(top.clone()) {
order.push(top);
}
}
while progress {
progress = false;
for t in &order {
progress = self.visit(t) || progress;
}
for t in order.iter().rev() {
progress = self.visit(t) || progress;
}
}
}
}

View File

@@ -119,16 +119,6 @@ pub enum Op {
/// Takes the modulus.
UbvToPf(Arc<Integer>),
// key sort, size
/// A unary operator.
///
/// Make an array from keys of the given sort, which is equal to the provided argument at all
/// places.
///
/// Has space for the provided number of elements. Note that this assumes an order and starting
/// point for keys.
ConstArray(Sort, usize),
/// Binary operator, with arguments (array, index).
///
/// Gets the value at index in array.
@@ -142,6 +132,8 @@ pub enum Op {
Tuple,
/// Get the n'th element of a tuple
Field(usize),
/// Update (tuple, element)
Update(usize),
}
/// Boolean AND
@@ -247,11 +239,11 @@ impl Op {
Op::PfUnOp(_) => Some(1),
Op::PfNaryOp(_) => None,
Op::UbvToPf(_) => Some(1),
Op::ConstArray(_, _) => Some(1),
Op::Select => Some(2),
Op::Store => Some(3),
Op::Tuple => None,
Op::Field(_) => Some(1),
Op::Update(_) => Some(2),
}
}
}
@@ -289,11 +281,11 @@ impl Display for Op {
Op::PfUnOp(a) => write!(f, "{}", a),
Op::PfNaryOp(a) => write!(f, "{}", a),
Op::UbvToPf(a) => write!(f, "bv2pf {}", a),
Op::ConstArray(_, s) => write!(f, "const-array {}", s),
Op::Select => write!(f, "select"),
Op::Store => write!(f, "store"),
Op::Tuple => write!(f, "tuple"),
Op::Field(i) => write!(f, "field{}", i),
Op::Update(i) => write!(f, "update{}", i),
}
}
}
@@ -629,11 +621,61 @@ pub enum Value {
/// Boolean
Bool(bool),
/// Array
Array(Sort, Box<Value>, BTreeMap<Value, Value>, usize),
Array(Array),
/// Tuple
Tuple(Vec<Value>),
}
#[derive(Clone, PartialEq, Debug, PartialOrd, Hash)]
/// An IR array value.
///
/// A sized, space array.
pub struct Array {
/// Key sort
pub key_sort: Sort,
/// Default (fill) value. What is stored when a key is missing from the next member
pub default: Box<Value>,
/// Key-> Value map
pub map: BTreeMap<Value, Value>,
/// Size of array. There are this many valid keys.
pub size: usize,
}
impl Array {
/// Create a new [Array] from components
pub fn new(
key_sort: Sort,
default: Box<Value>,
map: BTreeMap<Value, Value>,
size: usize,
) -> Self {
Self {
key_sort,
default,
map,
size,
}
}
/// Create a new, default-initialized [Array]
pub fn default(key_sort: Sort, val_sort: &Sort, size: usize) -> Self {
Self::new(
key_sort,
Box::new(val_sort.default_value()),
Default::default(),
size,
)
}
/// Store
pub fn store(mut self, idx: Value, val: Value) -> Self {
self.map.insert(idx, val);
self
}
/// Select
pub fn select(&self, idx: &Value) -> Value {
self.map.get(idx).unwrap_or(&*self.default).clone()
}
}
impl Display for Value {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
@@ -650,13 +692,21 @@ impl Display for Value {
}
write!(f, ")")
}
Value::Array(_s, d, map, size) => {
write!(f, "(map default:{} size:{} {:?})", d, size, map)
}
Value::Array(a) => write!(f, "{}", a),
}
}
}
impl Display for Array {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
"(map default:{} size:{} {:?})",
self.default, self.size, self.map
)
}
}
impl std::cmp::Eq for Value {}
impl std::cmp::Ord for Value {
fn cmp(&self, o: &Self) -> std::cmp::Ordering {
@@ -672,12 +722,7 @@ impl std::hash::Hash for Value {
Value::Int(bv) => bv.hash(state),
Value::Field(bv) => bv.hash(state),
Value::Bool(bv) => bv.hash(state),
Value::Array(s, d, a, size) => {
s.hash(state);
d.hash(state);
a.hash(state);
size.hash(state);
}
Value::Array(a) => a.hash(state),
Value::Tuple(s) => {
s.hash(state);
}
@@ -784,6 +829,37 @@ impl Sort {
_ => panic!("Cannot iterate over {}", self),
}
}
/// Compute the default term for this sort.
///
/// * booleans: false
/// * bit-vectors: zero
/// * field elements: zero
/// * floats: zero
/// * tuples/arrays: recursively default
pub fn default_term(&self) -> Term {
leaf_term(Op::Const(self.default_value()))
}
/// Compute the default value for this sort.
///
/// * booleans: false
/// * bit-vectors: zero
/// * field elements: zero
/// * floats: zero
/// * tuples/arrays: recursively default
pub fn default_value(&self) -> Value {
match self {
Sort::Bool => Value::Bool(false),
Sort::BitVector(w) => Value::BitVector(BitVector::new(0.into(), *w)),
Sort::Field(m) => Value::Field(FieldElem::new(Integer::from(0), m.clone())),
Sort::Int => Value::Int(0.into()),
Sort::F32 => Value::F32(0.0f32),
Sort::F64 => Value::F64(0.0),
Sort::Tuple(t) => Value::Tuple(t.iter().map(Sort::default_value).collect()),
Sort::Array(k, v, n) => Value::Array(Array::default((**k).clone(), v, *n)),
}
}
}
impl Display for Sort {
@@ -939,6 +1015,14 @@ impl TermData {
false
}
}
/// Is this a value
pub fn is_const(&self) -> bool {
if let Op::Const(..) = &self.op {
true
} else {
false
}
}
}
impl Value {
@@ -951,7 +1035,12 @@ impl Value {
Value::F64(_) => Sort::F64,
Value::F32(_) => Sort::F32,
Value::BitVector(b) => Sort::BitVector(b.width()),
Value::Array(s, _, _, _) => s.clone(),
Value::Array(Array {
key_sort,
default,
size,
..
}) => Sort::Array(Box::new(key_sort.clone()), Box::new(default.sort()), *size),
Value::Tuple(v) => Sort::Tuple(v.iter().map(Value::sort).collect()),
}
}
@@ -992,6 +1081,16 @@ impl Value {
}
}
#[track_caller]
/// Unwrap the constituent value of this array, panicking otherwise.
pub fn as_array(&self) -> &Array {
if let Value::Array(w) = self {
&w
} else {
panic!("{} is not an aray", self)
}
}
/// Get the underlying boolean constant, if possible.
pub fn as_bool_opt(&self) -> Option<bool> {
if let Value::Bool(b) = self {
@@ -1008,6 +1107,16 @@ impl Value {
None
}
}
/// Compute the sort of this value
pub fn as_usize(&self) -> Option<usize> {
match &self {
Value::Bool(b) => Some(*b as usize),
Value::Field(f) => f.i().to_usize(),
Value::Int(i) => i.to_usize(),
Value::BitVector(b) => b.uint().to_usize(),
_ => None,
}
}
}
/// Evaluate the term `t`, using variable values in `h`.
@@ -1148,12 +1257,32 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
field::FieldElem::new(a.uint().clone(), m.clone())
}),
// tuple
Op::Tuple => Value::Tuple(c.cs.iter().map(|c| vs.get(c).unwrap().clone()).collect()),
Op::Field(i) => {
let t = vs.get(&c.cs[0]).unwrap().as_tuple();
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs[0]);
t[*i].clone()
}
Op::Update(i) => {
let mut t = vs.get(&c.cs[0]).unwrap().as_tuple().clone();
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs[0]);
let e = vs.get(&c.cs[1]).unwrap().clone();
t[*i] = e;
Value::Tuple(t)
}
// array
Op::Store => {
let a = vs.get(&c.cs[0]).unwrap().as_array().clone();
let i = vs.get(&c.cs[1]).unwrap().clone();
let v = vs.get(&c.cs[2]).unwrap().clone();
Value::Array(a.clone().store(i, v))
}
Op::Select => {
let a = vs.get(&c.cs[0]).unwrap().as_array().clone();
let i = vs.get(&c.cs[1]).unwrap();
a.clone().select(i)
}
o => unimplemented!("eval: {:?}", o),
};
//println!("Eval {}\nAs {}", c, v);
@@ -1162,14 +1291,31 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
vs.get(t).unwrap().clone()
}
/// Make an array from a sequence of terms.
///
/// Requires
///
/// * a key sort, as all arrays do. This sort must be iterable (i.e., bool, int, bit-vector, or field).
/// * a value sort, for the array's default
pub fn make_array(key_sort: Sort, value_sort: Sort, i: Vec<Term>) -> Term {
let d = Sort::Array(Box::new(key_sort.clone()), Box::new(value_sort), i.len()).default_term();
i.into_iter()
.zip(key_sort.elems_iter())
.fold(d, |arr, (val, idx)| term(Op::Store, vec![arr, idx, val]))
}
/// Make a term with no arguments, just an operator.
pub fn leaf_term(op: Op) -> Term {
term(op, Vec::new())
}
/// Make a term with arguments.
#[track_caller]
pub fn term(op: Op, cs: Vec<Term>) -> Term {
mk(TermData { op, cs })
let t = mk(TermData { op, cs });
#[cfg(debug_assertions)]
check_rec(&t);
t
}
/// Make a bit-vector constant term.
@@ -1183,6 +1329,11 @@ where
))))
}
/// Make a bit-vector constant term.
pub fn bool_lit(b: bool) -> Term {
leaf_term(Op::Const(Value::Bool(b)))
}
#[macro_export]
/// Make a term.
///
@@ -1280,6 +1431,46 @@ impl ComputationMetadata {
self.input_vis.insert(input_name.clone(), party);
self.inputs.push(term);
}
/// Replace the `original` computation input with `new`, in the order given.
///
/// If the old input order was
///
/// w x y z x1 x2 x3
///
/// and `x` was replaced with `x1`, `x2`, `x3`, then the new input order is
///
/// w x1 x2 x3 y z
///
/// and other metadata associated with `x` is removed.
///
/// This is probably called after making the new inputs with `new_input`.
pub fn replace_input(
&mut self,
original: Term,
new: Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
) {
let mut i = self.inputs.iter().position(|t| t == &original).unwrap();
self.inputs.remove(i);
let name = if let Op::Var(n, _) = &original.op {
n.to_string()
} else {
unreachable!()
};
self.input_vis.remove(&name).unwrap();
for (input_name, sort, _, party) in new {
let term = leaf_term(Op::Var(input_name.clone(), sort));
debug_assert!(
!self.input_vis.contains_key(&input_name),
"Tried to create input {} (visibility {:?}), but it already existed (visibility {:?})",
input_name,
party,
self.input_vis.get(&input_name).unwrap()
);
self.input_vis.insert(input_name.clone(), party);
self.inputs.insert(i, term);
i += 1;
}
}
/// Returns None if the value is public. Otherwise, the unique party that knows it.
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
self.input_vis
@@ -1289,15 +1480,14 @@ impl ComputationMetadata {
}
/// Is this input public?
pub fn is_input(&self, input_name: &str) -> bool {
self.input_vis
.contains_key(input_name)
self.input_vis.contains_key(input_name)
}
/// Is this input public?
pub fn is_input_public(&self, input_name: &str) -> bool {
self.get_input_visibility(input_name).is_none()
}
/// Get all public inputs.
pub fn public_inputs(&self) -> impl Iterator<Item = &str> {
pub fn public_input_names(&self) -> impl Iterator<Item = &str> {
self.input_vis.iter().filter_map(|(name, party)| {
if party.is_none() {
Some(name.as_str())
@@ -1306,6 +1496,21 @@ impl ComputationMetadata {
}
})
}
/// Get all public inputs.
pub fn public_inputs<'a>(&'a self) -> impl Iterator<Item = Term> + 'a {
self.inputs.iter().filter_map(move |input| {
if let Op::Var(name, _) = &input.op {
let party = self.get_input_visibility(name);
if party.is_none() {
Some(input.clone())
} else {
None
}
} else {
unreachable!()
}
})
}
}
#[derive(Clone, Debug)]
@@ -1352,6 +1557,44 @@ impl Computation {
}
leaf_term(Op::Var(name.to_owned(), s))
}
/// Replace the `original` computation input with `new`, in the order given.
///
/// If the old input order was
///
/// w x y z
///
/// and `x` was replaced with `x1`, `x2`, `x3`, then the new input order is
///
/// w x1 x2 x3 y z
///
/// and other metadata associated with `x` is removed.
///
/// This is called in place of `new_var` during transformations.
pub fn replace_input(
&mut self,
original: Term,
mut new: Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
) {
if let Some(vs) = self.values.as_mut() {
if let Op::Var(name, _) = &original.op {
vs.remove(name);
for (name, _, val_opt, _) in &mut new {
vs.insert(name.clone(), std::mem::take(val_opt).unwrap());
}
}
}
self.metadata.replace_input(original, new);
}
/// Change the value associated with an input
pub fn map_value(&mut self, name: &str, f: impl FnOnce(Value) -> Value) {
if let Some(vs) = self.values.as_mut() {
let loc = vs.get_mut(name).unwrap();
let v = std::mem::replace(loc, Value::Bool(false));
*loc = f(v);
}
}
/// Create a new variable, `name` in the constraint system, and set it equal to `term`.
/// `public` indicates whether this variable is public in the constraint system.
pub fn assign(&mut self, name: &str, term: Term, party: Option<PartyId>) -> Term {
@@ -1364,7 +1607,7 @@ impl Computation {
/// Assert `s` in the system.
pub fn assert(&mut self, s: Term) {
assert!(check(&s) == Sort::Bool);
debug!("Assert: {}", s);
debug!("Assert: {}", extras::Letified(s.clone()));
self.outputs.push(s);
}
/// If tracking values, evaluate `term`, and set the result to `name`.
@@ -1383,41 +1626,21 @@ impl Computation {
Self {
outputs: Vec::new(),
metadata: ComputationMetadata::default(),
values: if values { Some(FxHashMap::default()) } else { None },
values: if values {
Some(FxHashMap::default())
} else {
None
},
}
}
// TODO: rm
// /// Make `s` a public input.
// pub fn publicize(&mut self, s: String) {
// self.public_inputs.insert(s);
// }
/// Get the outputs of the computation.
///
/// For proof systems, these are the assertions that must hold.
pub fn outputs(&self) -> &Vec<Term> {
&self.outputs
}
// TODO: rm
// /// Consume this system, yielding its parts: (assertions, public inputs, values)
// pub fn consume(self) -> (Vec<Term>, ComputationMetadata, Option<FxHashMap<String, Value>>) {
// (self.assertions, self.metadata, self.values)
// }
// /// Build a system from its parts: (assertions, public inputs, values)
// pub fn from_parts(
// assertions: Vec<Term>,
// public_inputs: FxHashSet<String>,
// values: Option<FxHashMap<String, Value>>,
// ) -> Self {
// Self {
// assertions,
// public_inputs,
// values,
// }
// }
// /// Get the term, (AND all assertions)
// pub fn assertions_as_term(&self) -> Term {
// term(Op::BoolNaryOp(BoolNaryOp::And), self.assertions.clone())
// }
/// How many total (unique) terms are there?
pub fn terms(&self) -> usize {
let mut terms = FxHashSet::<Term>::default();
@@ -1428,6 +1651,14 @@ impl Computation {
}
terms.len()
}
/// An iterator that visits each term in the computation, once.
pub fn terms_postorder(&self) -> impl Iterator<Item = Term> {
let mut terms: Vec<_> = PostOrderIter::new(term(Op::Tuple, self.outputs.clone())).collect();
// drop the top-level tuple term.
terms.pop();
terms.into_iter()
}
}
#[cfg(test)]

View File

@@ -85,11 +85,6 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
Op::PfUnOp(_) => Ok(check_raw(&t.cs[0])?),
Op::PfNaryOp(_) => Ok(check_raw(&t.cs[0])?),
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
Op::ConstArray(s, n) => Ok(Sort::Array(
Box::new(s.clone()),
Box::new(check_raw(&t.cs[0])?),
*n,
)),
Op::Select => array_or(&check_raw(&t.cs[0])?, "select").map(|(_, v)| v.clone()),
Op::Store => Ok(check_raw(&t.cs[0])?),
Op::Tuple => Ok(Sort::Tuple(
@@ -109,6 +104,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
)))
}
}
Op::Update(_i) => Ok(check_raw(&t.cs[0])?),
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
};
let mut term_tys = TERM_TYPES.write().unwrap();
@@ -260,13 +256,8 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
.and_then(|t| pf_or(t, ctx))
.map(|a| a.clone())
}
(Op::UbvToPf(m), &[a]) => {
bv_or(a, "sbv-to-fp").map(|_| Sort::Field(m.clone()))
}
(Op::UbvToPf(m), &[a]) => bv_or(a, "sbv-to-fp").map(|_| Sort::Field(m.clone())),
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
(Op::ConstArray(s, n), &[a]) => {
Ok(Sort::Array(Box::new(s.clone()), Box::new(a.clone()), *n))
}
(Op::Select, &[Sort::Array(k, v, _), a]) => {
eq_or(k, a, "select").map(|_| (**v).clone())
}
@@ -286,6 +277,17 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
)))
}
}),
(Op::Update(i), &[a, b]) => tuple_or(a, "tuple field update").and_then(|t| {
if i < &t.len() {
eq_or(&t[*i], b, "tuple update")?;
Ok(a.clone())
} else {
Err(TypeErrorReason::OutOfBounds(format!(
"index {} in tuple of sort {}",
i, a
)))
}
}),
(_, _) => Err(TypeErrorReason::Custom(format!("other"))),
})
.map_err(|reason| TypeError {
@@ -328,6 +330,8 @@ pub enum TypeErrorReason {
ExpectedPf(Sort, &'static str),
/// A sort should be an array
ExpectedArray(Sort, &'static str),
/// A sort should be a tuple
ExpectedTuple(&'static str),
/// An empty n-ary operator.
EmptyNary(String),
/// Something else
@@ -377,7 +381,7 @@ fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Vec<Sort>, TypeErrorReason> {
match a {
Sort::Tuple(a) => Ok(a),
_ => Err(TypeErrorReason::ExpectedPf(a.clone(), ctx)),
_ => Err(TypeErrorReason::ExpectedTuple(ctx)),
}
}

View File

@@ -150,9 +150,11 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
}
}
}
let terms: FxHashMap<Term, usize> = terms.into_iter().enumerate().map(|(i, t)| (t, i)).collect();
let terms: FxHashMap<Term, usize> =
terms.into_iter().enumerate().map(|(i, t)| (t, i)).collect();
let mut term_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default();
let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = FxHashMap::default();
let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> =
FxHashMap::default();
let mut ilp = Ilp::new();
// build variables for all term assignments

View File

@@ -188,7 +188,13 @@ impl ToABY {
fn remove_cons_gate(&self, circ: String) -> String {
if circ.contains("PutCONSGate(") {
circ.split("PutCONSGate(").last().unwrap_or("").split(",").next().unwrap_or("").to_string()
circ.split("PutCONSGate(")
.last()
.unwrap_or("")
.split(",")
.next()
.unwrap_or("")
.to_string()
} else {
panic!("PutCONSGate not found in: {}", circ)
}
@@ -510,9 +516,7 @@ impl ToABY {
_ => panic!("Invalid bv-op in BvBinOp: {:?}", o),
}
}
Op::BvExtract(_start, _end) => {
}
Op::BvExtract(_start, _end) => {}
_ => panic!("Non-field in embed_bv: {:?}", t),
}

View File

@@ -4,10 +4,10 @@ use ff::{PrimeField, PrimeFieldBits};
use gmp_mpfr_sys::gmp::limb_t;
use log::debug;
use std::collections::HashMap;
use std::path::Path;
use std::fs::File;
use std::str::FromStr;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::str::FromStr;
use super::*;
@@ -32,9 +32,15 @@ fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
zero_lc: LinearCombination<F>,
) -> LinearCombination<F> {
let mut lc_bellman = zero_lc;
lc_bellman = lc_bellman + (int_to_ff(&lc.constant), CS::one());
// This zero test is needed until https://github.com/zkcrypto/bellman/pull/78 is resolved
if lc.constant != 0 {
lc_bellman = lc_bellman + (int_to_ff(&lc.constant), CS::one());
}
for (v, c) in &lc.monomials {
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
// ditto
if c != &0 {
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
}
}
lc_bellman
}
@@ -87,7 +93,7 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
.get(&i)
.unwrap();
let ff_val = int_to_ff(i_val);
debug!("witness: {} -> {:?} ({})", s, ff_val, i_val);
debug!("value : {} -> {:?} ({})", s, ff_val, i_val);
ff_val
})
};
@@ -112,6 +118,11 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
|z| lc_to_bellman::<F, CS>(&vars, c, z),
);
}
debug!(
"done with synth: {} vars {} cs",
vars.len(),
self.constraints.len()
);
Ok(())
}
}
@@ -119,11 +130,13 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
/// Convert a (rug) integer to a prime field element.
pub fn parse_instance<P: AsRef<Path>, F: PrimeField>(path: P) -> Vec<F> {
let f = BufReader::new(File::open(path).unwrap());
f.lines().map(|line| {
let s = line.unwrap();
let i = Integer::from_str(&s.trim()).unwrap();
int_to_ff(&i)
}).collect()
f.lines()
.map(|line| {
let s = line.unwrap();
let i = Integer::from_str(&s.trim()).unwrap();
int_to_ff(&i)
})
.collect()
}
#[cfg(test)]

View File

@@ -251,7 +251,11 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
idxs_signals: HashMap::default(),
next_idx: 0,
public_idxs: HashSet::default(),
values: if values { Some(HashMap::default()) } else { None },
values: if values {
Some(HashMap::default())
} else {
None
},
constraints: Vec::new(),
}
}

View File

@@ -58,8 +58,13 @@ impl ToR1cs {
/// Get a new variable, with name dependent on `d`.
/// If values are being recorded, `value` must be provided.
fn fresh_var<D: Display + ?Sized>(&mut self, ctx: &D, value: Option<Integer>, public: bool) -> Lc {
let n = format!("{}_v{}", ctx, self.next_idx);
fn fresh_var<D: Display + ?Sized>(
&mut self,
ctx: &D,
value: Option<Integer>,
public: bool,
) -> Lc {
let n = format!("{}_n{}", ctx, self.next_idx);
self.next_idx += 1;
self.r1cs.add_signal(n.clone(), value);
if public {
@@ -97,7 +102,11 @@ impl ToR1cs {
}),
false,
);
let is_zero = self.fresh_var("is_zero", self.r1cs.eval(&x).map(|x| Integer::from(x == 0)), false);
let is_zero = self.fresh_var(
"is_zero",
self.r1cs.eval(&x).map(|x| Integer::from(x == 0)),
false,
);
self.r1cs.constraint(m, x.clone(), -is_zero.clone() + 1);
self.r1cs.constraint(is_zero.clone(), x, self.r1cs.zero());
is_zero
@@ -127,12 +136,15 @@ impl ToR1cs {
/// Evaluate `var`'s value as an (integer-casted) bit-vector.
/// Returns `None` if values are not stored.
fn eval_bv(&self, var: &str) -> Option<Integer> {
self.values
.as_ref()
.map(|vs| match vs.get(var).expect("missing value") {
self.values.as_ref().map(|vs| {
match vs
.get(var)
.unwrap_or_else(|| panic!("missing value for {}", var))
{
Value::BitVector(b) => b.uint().clone(),
v => panic!("{} should be a bit-vector, but is {:?}", var, v),
})
}
})
}
/// Evaluate `var`'s value as an (integer-casted) field element
@@ -343,6 +355,27 @@ impl ToR1cs {
}
}
fn assert_eq(&mut self, a: &Term, b: &Term) {
match check(a) {
Sort::Bool => {
let a = self.get_bool(a).clone();
let diff = a - self.get_bool(b);
self.assert_zero(diff);
}
Sort::BitVector(_) => {
let a = self.get_bv_uint(a);
let diff = a - &self.get_bv_uint(b);
self.assert_zero(diff);
}
Sort::Field(_) => {
let a = self.get_pf(a).clone();
let diff = a - self.get_pf(b);
self.assert_zero(diff);
}
s => panic!("Unimplemented sort for Eq: {:?}", s),
}
}
fn embed_bool(&mut self, c: Term) -> &Lc {
//println!("Embed: {}", c);
debug_assert!(check(&c) == Sort::Bool);
@@ -427,6 +460,23 @@ impl ToR1cs {
self.get_bool(&c)
}
fn assert_bool(&mut self, t: &Term) {
//println!("Embed: {}", c);
// TODO: skip if already embedded
if &t.op == &Op::Eq {
t.cs.iter().for_each(|c| self.embed(c.clone()));
self.assert_eq(&t.cs[0], &t.cs[1]);
} else if &t.op == &AND {
for c in &t.cs {
self.assert_bool(c);
}
} else {
self.embed(t.clone());
let lc = self.get_bool(&t).clone();
self.assert_zero(lc - 1);
}
}
/// Returns whether `a - b` fits in `size` non-negative bits.
/// i.e. is in `{0, 1, ..., 2^n-1}`.
fn bv_ge(&mut self, a: Lc, b: &Lc, size: usize) -> Lc {
@@ -848,9 +898,8 @@ impl ToR1cs {
}
fn assert(&mut self, t: Term) {
debug!("Assert: {}", Letified(t.clone()));
self.embed(t.clone());
let lc = self.get_bool(&t).clone();
self.assert_zero(lc - 1);
debug_assert!(check(&t) == Sort::Bool, "Non bool in assert");
self.assert_bool(&t);
}
}
@@ -861,7 +910,10 @@ pub fn to_r1cs(cs: Computation, modulus: Integer) -> R1cs<String> {
metadata,
values,
} = cs;
let public_inputs = metadata.public_inputs().map(ToOwned::to_owned).collect();
let public_inputs = metadata
.public_input_names()
.map(ToOwned::to_owned)
.collect();
debug!("public inputs: {:?}", public_inputs);
let mut converter = ToR1cs::new(modulus, values, public_inputs);
debug!(
@@ -871,7 +923,12 @@ pub fn to_r1cs(cs: Computation, modulus: Integer) -> R1cs<String> {
.map(|c| PostOrderIter::new(c.clone()).count())
.sum::<usize>()
);
println!("Printing assertions");
debug!("declaring inputs");
for i in metadata.public_inputs() {
debug!("input {}", i);
converter.embed(i);
}
debug!("Printing assertions");
for c in assertions {
converter.assert(c);
}
@@ -969,8 +1026,7 @@ pub mod test {
} else {
term![Op::Not; t]
};
let cs =
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
r1cs.check_all();
}
@@ -979,9 +1035,9 @@ pub mod test {
fn random_bool(ArbitraryTermEnv(t, values): ArbitraryTermEnv) {
let v = eval(&t, &values);
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs =
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
r1cs.check_all();
}
@@ -990,8 +1046,7 @@ pub mod test {
fn random_pure_bool_opt(ArbitraryBoolEnv(t, values): ArbitraryBoolEnv) {
let v = eval(&t, &values);
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs =
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
r1cs.check_all();
let r1cs2 = reduce_linearities(r1cs);
@@ -1002,9 +1057,9 @@ pub mod test {
fn random_bool_opt(ArbitraryTermEnv(t, values): ArbitraryTermEnv) {
let v = eval(&t, &values);
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs =
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
r1cs.check_all();
let r1cs2 = reduce_linearities(r1cs);
@@ -1016,9 +1071,7 @@ pub mod test {
let cs = Computation::from_constraint_system_parts(
vec![term![Op::Not; term![Op::Eq; bv(0b10110, 8),
term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]]],
vec![
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(8))),
],
vec![leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))],
Some(
vec![(
"b".to_owned(),
@@ -1041,8 +1094,7 @@ pub mod test {
.collect();
let v = eval(&t, &values);
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs =
Computation::from_constraint_system_parts(vec![t], vec![], Some(values));
let cs = Computation::from_constraint_system_parts(vec![t], vec![], Some(values));
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
r1cs.check_all();
let r1cs2 = reduce_linearities(r1cs);
@@ -1170,7 +1222,7 @@ pub mod test {
#[test]
fn tuple() {
let cs = Computation::from_constraint_system_parts(
let mut cs = Computation::from_constraint_system_parts(
vec![
term![Op::Field(0); term![Op::Tuple; leaf_term(Op::Var("a".to_owned(), Sort::Bool)), leaf_term(Op::Const(Value::Bool(false)))]],
term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))],
@@ -1188,7 +1240,7 @@ pub mod test {
.collect(),
),
);
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let r1cs = to_r1cs(cs, Integer::from(17));
r1cs.check_all();
}

View File

@@ -71,14 +71,21 @@ impl Expr2Smt<()> for Value {
}
write!(w, ")")?;
}
Value::Array(s, default, map, _size) => {
Value::Array(Array {
key_sort,
default,
map,
size,
}) => {
for _ in 0..map.len() {
write!(w, "(store ")?;
}
let val_s = check(&leaf_term(Op::Const((**default).clone())));
let s = Sort::Array(Box::new(key_sort.clone()), Box::new(val_s), *size);
write!(
w,
"((as const {}) {})",
SmtSortDisp(&*s),
SmtSortDisp(&s),
SmtDisp(&**default)
)?;
for (k, v) in map {

View File

@@ -23,7 +23,6 @@ mod hash_test {
assert_eq!(v1, v2);
}
#[test]
fn test_map_non_det_iter_order() {
let mut m1: HashMap<usize, usize> = HashMap::new();