mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
IR-based Zokrates front-end (#33)
The ZoKrates front-end now represents ZoK arrays as IR arrays, and ZoK structures as (type-tagged) IR tuples.
During this change, I discovered that IR support for eliminating tuples and arrays was not complete.
Thus the change list is:
The ZoK front-end uses IR arrays and tuples
Improve IR passes for array and tuple elimination
Enforce cargo fmt in CI
Bugfix: handle ZoK accessors in L-values in the correct order
Bugfix: add array evaluation to the IR
This PR does not:
implement an array flattening pass
implement permutation-based memory-checking
Benefits:
The ZoK->R1CS compiler is now ~5.88x faster (as defined by the time it takes to run the tests in master's scripts/zokrates_test.zsh script: this goes from 8.59s to 1.46s)
For benchmarks with multi-dimensional arrays, the ZoK->R1CS compiler can now compile them with reasonable speed. Before it it would time out on even tiny examples.
The ZoK->R1CS compiler will be able to benefit from future memory-checking improvements
IR support for arrays and tuples is complete now, making those parts of the IR more accessible to future front-ends.
alex-ozdemir added 21 commits 11 days ago
This commit is contained in:
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -25,8 +25,10 @@ jobs:
|
||||
- uses: Swatinem/rust-cache@v1
|
||||
- name: Initialize submodules
|
||||
run: make init
|
||||
- name: Check
|
||||
- name: Typecheck
|
||||
run: cargo check --verbose
|
||||
- name: Check format
|
||||
run: cargo fmt -- --check
|
||||
- name: Build
|
||||
run: cargo build --verbose && make build
|
||||
- name: Run tests
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -1,10 +1,11 @@
|
||||
/target
|
||||
/pf
|
||||
assignment.txt
|
||||
params
|
||||
/assignment.txt
|
||||
/params
|
||||
__pycache__
|
||||
/P
|
||||
/V
|
||||
/pi
|
||||
/x
|
||||
perf.data*
|
||||
/perf.data*
|
||||
/.gdb_history
|
||||
|
||||
12
Cargo.lock
generated
12
Cargo.lock
generated
@@ -217,6 +217,7 @@ dependencies = [
|
||||
"good_lp",
|
||||
"hashconsing",
|
||||
"ieee754",
|
||||
"itertools 0.10.3",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"lp-solvers",
|
||||
@@ -531,6 +532,15 @@ dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.8"
|
||||
@@ -657,7 +667,7 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fbf404899169771dd6a32c84248b83cd67a26cc7cc957aac87661490e1227e4"
|
||||
dependencies = [
|
||||
"itertools",
|
||||
"itertools 0.7.11",
|
||||
"proc-macro2 0.4.30",
|
||||
"quote 0.6.13",
|
||||
"single",
|
||||
|
||||
@@ -32,6 +32,7 @@ pest = "2.1"
|
||||
pest_derive = "2.1"
|
||||
pest-ast = "0.3"
|
||||
from-pest = "0.3"
|
||||
itertools = "0.10"
|
||||
|
||||
[dev-dependencies]
|
||||
quickcheck = "1"
|
||||
|
||||
5
Makefile
5
Makefile
@@ -4,7 +4,7 @@ build: init
|
||||
cargo build --release --example circ && ./scripts/build_mpc_zokrates_test.zsh && ./scripts/build_aby.zsh
|
||||
|
||||
test:
|
||||
cargo test && ./scripts/zokrates_test.zsh && python3 ./scripts/test_aby.py && ./scripts/test_zok_to_ilp.zsh && ./scripts/test_zok_to_ilp_pf.zsh ./scripts/test_atalog.zsh
|
||||
cargo test && ./scripts/zokrates_test.zsh && python3 ./scripts/test_aby.py && ./scripts/test_zok_to_ilp.zsh && ./scripts/test_zok_to_ilp_pf.zsh && ./scripts/test_datalog.zsh
|
||||
|
||||
init:
|
||||
git submodule update --init
|
||||
@@ -18,3 +18,6 @@ clean:
|
||||
touch ./third_party/ABY/src/examples/2pc_* && rm -r -- ./third_party/ABY/src/examples/2pc_*
|
||||
sed '/add_subdirectory.*2pc.*/d' -i ./third_party/ABY/src/examples/CMakeLists.txt
|
||||
rm -rf P V pi perf.data perf.data.old flamegraph.svg
|
||||
|
||||
format:
|
||||
cargo fmt --all
|
||||
|
||||
6
TODO.md
Normal file
6
TODO.md
Normal file
@@ -0,0 +1,6 @@
|
||||
Passes to write:
|
||||
|
||||
[ ] shrink bit-vectors using range analysis.
|
||||
[ ] common sub-expression grouping
|
||||
* for commutative/associative ops?
|
||||
* after flattening
|
||||
14
examples/ZoKrates/pf/arr_str_arr_str.zok
Normal file
14
examples/ZoKrates/pf/arr_str_arr_str.zok
Normal file
@@ -0,0 +1,14 @@
|
||||
struct Pt {
|
||||
field x
|
||||
field y
|
||||
}
|
||||
struct Pts {
|
||||
Pt[2] pts
|
||||
}
|
||||
|
||||
def main(private field y) -> field:
|
||||
Pt p1 = Pt {x: 2, y: y}
|
||||
Pt p2 = Pt {x: y, y: 2}
|
||||
Pts[1] pts = [Pts { pts: [p1, p2] }]
|
||||
return pts[0].pts[0].y * pts[0].pts[1].x
|
||||
|
||||
1
examples/ZoKrates/pf/arr_str_arr_str.zok.in
Normal file
1
examples/ZoKrates/pf/arr_str_arr_str.zok.in
Normal file
@@ -0,0 +1 @@
|
||||
y 4
|
||||
1
examples/ZoKrates/pf/arr_str_arr_str.zok.x
Normal file
1
examples/ZoKrates/pf/arr_str_arr_str.zok.x
Normal file
@@ -0,0 +1 @@
|
||||
16
|
||||
3
examples/ZoKrates/pf/many_pub.zok
Normal file
3
examples/ZoKrates/pf/many_pub.zok
Normal file
@@ -0,0 +1,3 @@
|
||||
// Making sure we get input order right
|
||||
def main(public u16 a, public u16 b, public u16 c, public u16 d) -> u16:
|
||||
return a ^ b ^ c ^ d
|
||||
4
examples/ZoKrates/pf/many_pub.zok.in
Normal file
4
examples/ZoKrates/pf/many_pub.zok.in
Normal file
@@ -0,0 +1,4 @@
|
||||
a 1
|
||||
b 2
|
||||
c 3
|
||||
d 4
|
||||
5
examples/ZoKrates/pf/many_pub.zok.x
Normal file
5
examples/ZoKrates/pf/many_pub.zok.x
Normal file
@@ -0,0 +1,5 @@
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
4
|
||||
12
examples/ZoKrates/pf/mm.zok
Normal file
12
examples/ZoKrates/pf/mm.zok
Normal file
@@ -0,0 +1,12 @@
|
||||
def main(private field[2][2] A, private field[2][2] B) -> field[2][2]:
|
||||
field [2][2] AB = [[0; 2]; 2]
|
||||
for field i in 0..2 do
|
||||
for field j in 0..2 do
|
||||
for field k in 0..2 do
|
||||
AB[i][j] = AB[i][j] + A[i][k] * B[k][j]
|
||||
endfor
|
||||
endfor
|
||||
endfor
|
||||
return AB
|
||||
|
||||
|
||||
8
examples/ZoKrates/pf/mm.zok.in
Normal file
8
examples/ZoKrates/pf/mm.zok.in
Normal file
@@ -0,0 +1,8 @@
|
||||
A.0.0 1
|
||||
A.0.1 0
|
||||
A.1.0 0
|
||||
A.1.1 1
|
||||
B.0.0 1
|
||||
B.0.1 0
|
||||
B.1.0 0
|
||||
B.1.1 1
|
||||
4
examples/ZoKrates/pf/mm.zok.x
Normal file
4
examples/ZoKrates/pf/mm.zok.x
Normal file
@@ -0,0 +1,4 @@
|
||||
1
|
||||
0
|
||||
0
|
||||
1
|
||||
2
examples/ZoKrates/pf/mul.zok
Normal file
2
examples/ZoKrates/pf/mul.zok
Normal file
@@ -0,0 +1,2 @@
|
||||
def main(private field x, private field y)-> field:
|
||||
return x * y
|
||||
2
examples/ZoKrates/pf/mul.zok.in
Normal file
2
examples/ZoKrates/pf/mul.zok.in
Normal file
@@ -0,0 +1,2 @@
|
||||
x 4
|
||||
y 5
|
||||
1
examples/ZoKrates/pf/mul.zok.x
Normal file
1
examples/ZoKrates/pf/mul.zok.x
Normal file
@@ -0,0 +1 @@
|
||||
20
|
||||
12
examples/ZoKrates/pf/str_arr_str.zok
Normal file
12
examples/ZoKrates/pf/str_arr_str.zok
Normal file
@@ -0,0 +1,12 @@
|
||||
struct Pt {
|
||||
field x
|
||||
field y
|
||||
}
|
||||
struct Pts {
|
||||
Pt[2] pts
|
||||
}
|
||||
|
||||
def main(field y) -> field:
|
||||
Pt p = Pt {x: 2, y: y}
|
||||
Pts pts = Pts { pts: [p, p] }
|
||||
return pts.pts[0].y + pts.pts[1].x
|
||||
1
examples/ZoKrates/pf/str_arr_str.zok.in
Normal file
1
examples/ZoKrates/pf/str_arr_str.zok.in
Normal file
@@ -0,0 +1 @@
|
||||
y 6
|
||||
2
examples/ZoKrates/pf/str_arr_str.zok.x
Normal file
2
examples/ZoKrates/pf/str_arr_str.zok.x
Normal file
@@ -0,0 +1,2 @@
|
||||
6
|
||||
8
|
||||
10
examples/ZoKrates/pf/str_str.zok
Normal file
10
examples/ZoKrates/pf/str_str.zok
Normal file
@@ -0,0 +1,10 @@
|
||||
struct Pt {
|
||||
field x
|
||||
field y
|
||||
}
|
||||
struct PtWr {
|
||||
Pt p
|
||||
}
|
||||
def main(field x, field y) -> field:
|
||||
PtWr p = PtWr { p: Pt { x: x, y: y } }
|
||||
return p.p.x * p.p.y
|
||||
2
examples/ZoKrates/pf/str_str.zok.in
Normal file
2
examples/ZoKrates/pf/str_str.zok.in
Normal file
@@ -0,0 +1,2 @@
|
||||
x 5
|
||||
y 6
|
||||
3
examples/ZoKrates/pf/str_str.zok.x
Normal file
3
examples/ZoKrates/pf/str_str.zok.x
Normal file
@@ -0,0 +1,3 @@
|
||||
5
|
||||
6
|
||||
30
|
||||
13
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok
Normal file
13
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok
Normal file
@@ -0,0 +1,13 @@
|
||||
struct Pt {
|
||||
field x
|
||||
field y
|
||||
}
|
||||
struct Pts {
|
||||
Pt[2] pts
|
||||
}
|
||||
|
||||
def main(private field y, private field i, private field j, private field k) -> field:
|
||||
Pt p = Pt {x: y, y: y}
|
||||
Pts[1] pts = [Pts { pts: [p, p] }]
|
||||
return pts[i].pts[j].y * pts[i].pts[j].x
|
||||
|
||||
4
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok.in
Normal file
4
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok.in
Normal file
@@ -0,0 +1,4 @@
|
||||
y 6
|
||||
i 0
|
||||
j 0
|
||||
k 1
|
||||
1
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok.x
Normal file
1
examples/ZoKrates/pf/var_idx_arr_str_arr_str.zok.x
Normal file
@@ -0,0 +1 @@
|
||||
36
|
||||
@@ -177,24 +177,31 @@ fn main() {
|
||||
}
|
||||
};
|
||||
let cs = match mode {
|
||||
Mode::Opt => opt(cs, vec![Opt::ConstantFold]),
|
||||
Mode::Opt => opt(cs, vec![Opt::ScalarizeVars, Opt::ConstantFold]),
|
||||
Mode::Mpc(_) => opt(
|
||||
cs,
|
||||
vec![],
|
||||
vec![Opt::ScalarizeVars],
|
||||
// vec![Opt::Sha, Opt::ConstantFold, Opt::Mem, Opt::ConstantFold],
|
||||
),
|
||||
Mode::Proof | Mode::ProofOfHighValue(_) => opt(
|
||||
cs,
|
||||
vec![
|
||||
Opt::ScalarizeVars,
|
||||
Opt::Flatten,
|
||||
Opt::Sha,
|
||||
Opt::ConstantFold,
|
||||
Opt::Flatten,
|
||||
//Opt::FlattenAssertions,
|
||||
Opt::Inline,
|
||||
Opt::Mem,
|
||||
// Tuples must be eliminated before oblivious array elim
|
||||
Opt::Tuple,
|
||||
Opt::ConstantFold,
|
||||
Opt::Obliv,
|
||||
// The obliv elim pass produces more tuples, that must be eliminated
|
||||
Opt::Tuple,
|
||||
Opt::LinearScan,
|
||||
// The linear scan pass produces more tuples, that must be eliminated
|
||||
Opt::Tuple,
|
||||
Opt::Flatten,
|
||||
//Opt::FlattenAssertions,
|
||||
Opt::ConstantFold,
|
||||
Opt::Inline,
|
||||
],
|
||||
@@ -215,12 +222,12 @@ fn main() {
|
||||
let r1cs = to_r1cs(cs, circ::front::zokrates::ZOKRATES_MODULUS.clone());
|
||||
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
|
||||
let r1cs = reduce_linearities(r1cs);
|
||||
println!("Final R1cs size: {}", r1cs.constraints().len());
|
||||
match action {
|
||||
ProofAction::Count => {
|
||||
println!("Final R1cs size: {}", r1cs.constraints().len());
|
||||
}
|
||||
ProofAction::Count => (),
|
||||
ProofAction::Prove => {
|
||||
println!("Proving");
|
||||
r1cs.check_all();
|
||||
let rng = &mut rand::thread_rng();
|
||||
let mut pk_file = File::open(prover_key).unwrap();
|
||||
let pk = Parameters::<Bls12>::read(&mut pk_file, false).unwrap();
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
|
||||
set -ex
|
||||
|
||||
cargo flamegraph --help || (echo "Please install the rust 'flamegraph' binary with 'cargo install flamegraph'" && exit 1)
|
||||
(cargo flamegraph --help > /dev/null) || (echo "Please install the rust 'flamegraph' binary with 'cargo install flamegraph'" && exit 1)
|
||||
|
||||
cargo flamegraph --example circ third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok r1cs --action count
|
||||
|
||||
@@ -15,10 +15,14 @@ $BIN --language datalog ./examples/datalog/arr.pl r1cs --action count || true
|
||||
# Small R1cs b/c too little recursion.
|
||||
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 4 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
|
||||
[ "$size" -lt 10 ]
|
||||
|
||||
# Big R1cs b/c enough recursion
|
||||
($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 5 r1cs --action count || true) | egrep "Final R1cs size: 356"
|
||||
($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 10 r1cs --action count || true) | egrep "Final R1cs size: 356"
|
||||
($BIN --language datalog ./examples/datalog/dec.pl -r 2 r1cs --action count || true) | egrep "Final R1cs size: 356"
|
||||
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 5 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
|
||||
[ "$size" -gt 250 ]
|
||||
size=$(($BIN --language datalog ./examples/datalog/dumb_hash.pl -r 10 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
|
||||
[ "$size" -gt 250 ]
|
||||
size=$(($BIN --language datalog ./examples/datalog/dec.pl -r 2 r1cs --action count || true) | egrep "Final R1cs size:" | egrep -o "\\b[0-9]+")
|
||||
[ "$size" -gt 250 ]
|
||||
|
||||
# Test prim-rec test
|
||||
$BIN --language datalog ./examples/datalog/dec.pl --lint-prim-rec smt
|
||||
|
||||
@@ -5,7 +5,9 @@ set -ex
|
||||
disable -r time
|
||||
|
||||
cargo build --release --example circ
|
||||
#cargo build --example circ
|
||||
|
||||
#BIN=./target/debug/examples/circ
|
||||
BIN=./target/release/examples/circ
|
||||
|
||||
case "$OSTYPE" in
|
||||
@@ -22,6 +24,15 @@ function r1cs_test {
|
||||
measure_time $BIN $zpath r1cs --action count
|
||||
}
|
||||
|
||||
# Test prove workflow, given an example name
|
||||
function pf_test {
|
||||
ex_name=$1
|
||||
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
|
||||
$BIN --inputs examples/ZoKrates/pf/$ex_name.zok.in examples/ZoKrates/pf/$ex_name.zok r1cs --action prove
|
||||
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --instance examples/ZoKrates/pf/$ex_name.zok.x --action verify
|
||||
rm -rf P V pi
|
||||
}
|
||||
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok
|
||||
@@ -36,15 +47,12 @@ r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zo
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok
|
||||
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok
|
||||
|
||||
|
||||
# Test prove workflow, given an example name
|
||||
function pf_test {
|
||||
ex_name=$1
|
||||
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
|
||||
$BIN --inputs examples/ZoKrates/pf/$ex_name.zok.in examples/ZoKrates/pf/$ex_name.zok r1cs --action prove
|
||||
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --instance examples/ZoKrates/pf/$ex_name.zok.x --action verify
|
||||
rm -rf P V pi
|
||||
}
|
||||
|
||||
pf_test 3_plus
|
||||
pf_test xor
|
||||
pf_test mul
|
||||
pf_test many_pub
|
||||
pf_test str_str
|
||||
pf_test str_arr_str
|
||||
pf_test arr_str_arr_str
|
||||
pf_test var_idx_arr_str_arr_str
|
||||
pf_test mm
|
||||
|
||||
@@ -127,17 +127,12 @@ impl MemManager {
|
||||
///
|
||||
/// Returns a (concrete) allocation identifier which can be used to access this allocation.
|
||||
pub fn zero_allocate(&mut self, size: usize, addr_width: usize, val_width: usize) -> AllocId {
|
||||
let sort = Sort::Array(
|
||||
Box::new(Sort::BitVector(addr_width)),
|
||||
Box::new(Sort::BitVector(val_width)),
|
||||
size,
|
||||
);
|
||||
let array = Value::Array(
|
||||
sort,
|
||||
let array = Value::Array(Array::new(
|
||||
Sort::BitVector(addr_width),
|
||||
Box::new(Value::BitVector(BitVector::zeros(val_width))),
|
||||
BTreeMap::new(),
|
||||
size,
|
||||
);
|
||||
));
|
||||
self.allocate(leaf_term(Op::Const(array)))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use super::term::T;
|
||||
use super::parser::ast::Span;
|
||||
use super::term::T;
|
||||
|
||||
use std::fmt::{Display, Formatter, self};
|
||||
use std::convert::From;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
/// An error in circuit translation
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! Datalog parser
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use pest::Parser;
|
||||
use pest::error::Error;
|
||||
use pest::Parser;
|
||||
use pest_derive::Parser;
|
||||
|
||||
// Issue with the proc macro
|
||||
@@ -193,7 +193,6 @@ pub mod ast {
|
||||
pub span: Span<'ast>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct BinaryExpression<'ast> {
|
||||
pub op: BinaryOperator,
|
||||
@@ -314,17 +313,20 @@ pub mod ast {
|
||||
match next.as_rule() {
|
||||
// this happens when we have an expression in parentheses: it needs to be processed as another sequence of terms and operators
|
||||
Rule::paren_expr => Expression::Paren(
|
||||
Box::new(Expression::from_pest(
|
||||
Box::new(
|
||||
Expression::from_pest(
|
||||
&mut pair.into_inner().next().unwrap().into_inner(),
|
||||
).unwrap()),
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
next.as_span(),
|
||||
),
|
||||
Rule::literal => {
|
||||
Expression::Literal(Literal::from_pest(&mut pair.into_inner()).unwrap())
|
||||
}
|
||||
Rule::identifier => Expression::Identifier(
|
||||
Ident::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
),
|
||||
Rule::identifier => {
|
||||
Expression::Identifier(Ident::from_pest(&mut pair.into_inner()).unwrap())
|
||||
}
|
||||
Rule::unary_expression => {
|
||||
let span = next.as_span();
|
||||
let mut inner = next.into_inner();
|
||||
@@ -345,9 +347,9 @@ pub mod ast {
|
||||
span,
|
||||
})
|
||||
}
|
||||
Rule::call_expr => Expression::Call(
|
||||
CallExpression::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
),
|
||||
Rule::call_expr => {
|
||||
Expression::Call(CallExpression::from_pest(&mut pair.into_inner()).unwrap())
|
||||
}
|
||||
Rule::access_expr => Expression::Access(
|
||||
AccessExpression::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
),
|
||||
|
||||
@@ -9,7 +9,7 @@ use super::error::ErrorKind;
|
||||
use super::ty::Ty;
|
||||
|
||||
use crate::circify::{CirCtx, Embeddable};
|
||||
use crate::front::zokrates::ZOKRATES_MODULUS_ARC;
|
||||
use crate::front::zokrates::{ZOKRATES_MODULUS_ARC, ZOK_FIELD_SORT};
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// A term
|
||||
@@ -81,19 +81,22 @@ pub fn uint_lit(v: u64, w: u8) -> T {
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
fn default(&self) -> T {
|
||||
T::new(self.default_ir(), self.clone())
|
||||
}
|
||||
fn default_ir(&self) -> Term {
|
||||
fn sort(&self) -> Sort {
|
||||
match self {
|
||||
Self::Bool => leaf_term(Op::Const(Value::Bool(false))),
|
||||
Self::Uint(w) => bv_lit(0, *w as usize),
|
||||
Self::Field => pf_ir_lit(0),
|
||||
Self::Array(l, t) => {
|
||||
term![Op::ConstArray(Sort::Field(ZOKRATES_MODULUS_ARC.clone()), *l); t.default_ir()]
|
||||
Self::Bool => Sort::Bool,
|
||||
Self::Uint(w) => Sort::BitVector(*w as usize),
|
||||
Self::Field => ZOK_FIELD_SORT.clone(),
|
||||
Self::Array(n, b) => {
|
||||
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
|
||||
}
|
||||
}
|
||||
}
|
||||
fn default_ir_term(&self) -> Term {
|
||||
self.sort().default_term()
|
||||
}
|
||||
fn default(&self) -> T {
|
||||
T::new(self.default_ir_term(), self.clone())
|
||||
}
|
||||
}
|
||||
impl Display for T {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
@@ -397,7 +400,7 @@ impl Embeddable for Datalog {
|
||||
})
|
||||
.enumerate()
|
||||
.fold(
|
||||
ty.default_ir(),
|
||||
ty.default_ir_term(),
|
||||
|arr, (i, v)| term![Op::Store; arr, pf_ir_lit(i), v.ir],
|
||||
),
|
||||
ty.clone(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Types in our datalog variant
|
||||
|
||||
use std::fmt::{Display, Formatter, self};
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
/// A type
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
|
||||
@@ -6,6 +6,7 @@ mod term;
|
||||
use super::FrontEnd;
|
||||
use crate::circify::{Circify, Loc, Val};
|
||||
use crate::ir::proof::{self, ConstraintMetadata};
|
||||
use crate::ir::term::extras::Letified;
|
||||
use crate::ir::term::*;
|
||||
use log::debug;
|
||||
use rug::Integer;
|
||||
@@ -21,6 +22,8 @@ use term::*;
|
||||
pub use term::ZOKRATES_MODULUS;
|
||||
/// The modulus for the ZoKrates language.
|
||||
pub use term::ZOKRATES_MODULUS_ARC;
|
||||
/// The modulus for the ZoKrates language.
|
||||
pub use term::ZOK_FIELD_SORT;
|
||||
|
||||
/// The prover visibility
|
||||
pub const PROVER_VIS: Option<PartyId> = Some(proof::PROVER_ID);
|
||||
@@ -99,18 +102,28 @@ struct ZGen<'ast> {
|
||||
mode: Mode,
|
||||
}
|
||||
|
||||
enum ZLoc {
|
||||
Var(Loc),
|
||||
Member(Box<ZLoc>, String),
|
||||
Idx(Box<ZLoc>, T),
|
||||
struct ZLoc {
|
||||
var: Loc,
|
||||
accesses: Vec<ZAccess>,
|
||||
}
|
||||
|
||||
impl ZLoc {
|
||||
fn loc(&self) -> &Loc {
|
||||
match self {
|
||||
ZLoc::Var(l) => l,
|
||||
ZLoc::Member(i, _) => i.loc(),
|
||||
ZLoc::Idx(i, _) => i.loc(),
|
||||
enum ZAccess {
|
||||
Member(String),
|
||||
Idx(T),
|
||||
}
|
||||
|
||||
fn loc_store(struct_: T, loc: &[ZAccess], val: T) -> Result<T, String> {
|
||||
match loc.first() {
|
||||
None => Ok(val),
|
||||
Some(ZAccess::Member(field)) => {
|
||||
let inner = field_select(&struct_, &field)?;
|
||||
let new_inner = loc_store(inner, &loc[1..], val)?;
|
||||
field_store(struct_, &field, new_inner)
|
||||
}
|
||||
Some(ZAccess::Idx(idx)) => {
|
||||
let old_inner = array_select(struct_.clone(), idx.clone())?;
|
||||
let new_inner = loc_store(old_inner, &loc[1..], val)?;
|
||||
array_store(struct_, idx.clone(), new_inner)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -191,17 +204,16 @@ impl<'ast> ZGen<'ast> {
|
||||
self.unwrap(decl_res, &i.index.span);
|
||||
for j in s..e {
|
||||
self.circ.enter_scope();
|
||||
let ass_res = self
|
||||
.circ
|
||||
.assign(Loc::local(v_name.clone()), Val::Term(
|
||||
match ty {
|
||||
Ty::Uint(8) => T::Uint(8, bv_lit(j, 8)),
|
||||
Ty::Uint(16) => T::Uint(16, bv_lit(j, 16)),
|
||||
Ty::Uint(32) => T::Uint(32, bv_lit(j, 32)),
|
||||
Ty::Field => T::Field(pf_lit(j)),
|
||||
_ => panic!("Unexpected type for iteration: {:?}", ty),
|
||||
}
|
||||
));
|
||||
let ass_res = self.circ.assign(
|
||||
Loc::local(v_name.clone()),
|
||||
Val::Term(match ty {
|
||||
Ty::Uint(8) => uint_lit(j, 8),
|
||||
Ty::Uint(16) => uint_lit(j, 16),
|
||||
Ty::Uint(32) => uint_lit(j, 32),
|
||||
Ty::Field => field_lit(j),
|
||||
_ => panic!("Unexpected type for iteration: {:?}", ty),
|
||||
}),
|
||||
);
|
||||
self.unwrap(ass_res, &i.index.span);
|
||||
for s in &i.statements {
|
||||
self.stmt(s);
|
||||
@@ -217,7 +229,7 @@ impl<'ast> ZGen<'ast> {
|
||||
let ty = e.type_();
|
||||
if let Some(t) = l.ty.as_ref() {
|
||||
let decl_ty = self.type_(t);
|
||||
if decl_ty != ty {
|
||||
if &decl_ty != ty {
|
||||
self.err(
|
||||
format!(
|
||||
"Assignment type mismatch: {} annotated vs {} actual",
|
||||
@@ -242,79 +254,60 @@ impl<'ast> ZGen<'ast> {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_lval_mod(&mut self, base: T, loc: ZLoc, val: T) -> Result<T, String> {
|
||||
match loc {
|
||||
ZLoc::Var(_) => Ok(val),
|
||||
ZLoc::Member(inner_loc, field) => {
|
||||
let old_inner = field_select(&base, &field)?;
|
||||
let new_inner = self.apply_lval_mod(old_inner, *inner_loc, val)?;
|
||||
field_store(base, &field, new_inner)
|
||||
}
|
||||
ZLoc::Idx(inner_loc, idx) => {
|
||||
let old_inner = array_select(base.clone(), idx.clone())?;
|
||||
let new_inner = self.apply_lval_mod(old_inner, *inner_loc, val)?;
|
||||
array_store(base, idx, new_inner)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mod_lval(&mut self, l: ZLoc, t: T) -> Result<(), String> {
|
||||
let var = l.loc().clone();
|
||||
fn mod_lval(&mut self, loc: ZLoc, val: T) -> Result<(), String> {
|
||||
let old = self
|
||||
.circ
|
||||
.get_value(var.clone())
|
||||
.get_value(loc.var.clone())
|
||||
.map_err(|e| format!("{}", e))?
|
||||
.unwrap_term();
|
||||
let new = self.apply_lval_mod(old, l, t)?;
|
||||
let new = loc_store(old, &loc.accesses, val)?;
|
||||
debug!("Assign: {:?} = {}", loc.var, Letified(new.term.clone()));
|
||||
self.circ
|
||||
.assign(var, Val::Term(new))
|
||||
.assign(loc.var, Val::Term(new))
|
||||
.map_err(|e| format!("{}", e))
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
fn lval(&mut self, l: &ast::Assignee<'ast>) -> ZLoc {
|
||||
l.accesses.iter().fold(
|
||||
ZLoc::Var(Loc::local(l.id.value.clone())),
|
||||
|inner, acc| match acc {
|
||||
ast::AssigneeAccess::Member(m) => ZLoc::Member(Box::new(inner), m.id.value.clone()),
|
||||
ast::AssigneeAccess::Select(m) => {
|
||||
let i = if let ast::RangeOrExpression::Expression(e) = &m.expression {
|
||||
let mut loc = ZLoc {
|
||||
var: Loc::local(l.id.value.clone()),
|
||||
accesses: vec![],
|
||||
};
|
||||
for acc in &l.accesses {
|
||||
loc.accesses.push(match acc {
|
||||
ast::AssigneeAccess::Member(m) => ZAccess::Member(m.id.value.clone()),
|
||||
ast::AssigneeAccess::Select(m) => ZAccess::Idx(
|
||||
if let ast::RangeOrExpression::Expression(e) = &m.expression {
|
||||
self.expr(&e)
|
||||
} else {
|
||||
panic!("Cannot assign to slice")
|
||||
};
|
||||
ZLoc::Idx(Box::new(inner), i)
|
||||
}
|
||||
},
|
||||
)
|
||||
},
|
||||
),
|
||||
})
|
||||
}
|
||||
loc
|
||||
}
|
||||
|
||||
fn const_(&mut self, e: &ast::ConstantExpression<'ast>) -> T {
|
||||
match e {
|
||||
ast::ConstantExpression::U8(u) => {
|
||||
T::Uint(8, bv_lit(u8::from_str_radix(&u.value[2..], 16).unwrap(), 8))
|
||||
uint_lit(u8::from_str_radix(&u.value[2..], 16).unwrap(), 8)
|
||||
}
|
||||
ast::ConstantExpression::U16(u) => {
|
||||
uint_lit(u16::from_str_radix(&u.value[2..], 16).unwrap(), 16)
|
||||
}
|
||||
ast::ConstantExpression::U32(u) => {
|
||||
uint_lit(u32::from_str_radix(&u.value[2..], 16).unwrap(), 32)
|
||||
}
|
||||
ast::ConstantExpression::U16(u) => T::Uint(
|
||||
16,
|
||||
bv_lit(u16::from_str_radix(&u.value[2..], 16).unwrap(), 16),
|
||||
),
|
||||
ast::ConstantExpression::U32(u) => T::Uint(
|
||||
32,
|
||||
bv_lit(u32::from_str_radix(&u.value[2..], 16).unwrap(), 32),
|
||||
),
|
||||
ast::ConstantExpression::DecimalNumber(u) => {
|
||||
T::Field(pf_lit(Integer::from_str_radix(&u.value, 10).unwrap()))
|
||||
field_lit(Integer::from_str_radix(&u.value, 10).unwrap())
|
||||
}
|
||||
ast::ConstantExpression::BooleanLiteral(u) => {
|
||||
Self::const_bool(bool::from_str(&u.value).unwrap())
|
||||
z_bool_lit(bool::from_str(&u.value).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn const_bool(b: bool) -> T {
|
||||
T::Bool(leaf_term(Op::Const(Value::Bool(b))))
|
||||
}
|
||||
|
||||
fn bin_op(&self, o: &ast::BinaryOperator) -> fn(T, T) -> Result<T, String> {
|
||||
match o {
|
||||
ast::BinaryOperator::BitXor => bitxor,
|
||||
@@ -364,7 +357,7 @@ impl<'ast> ZGen<'ast> {
|
||||
.flat_map(|x| self.array_lit_elem(x))
|
||||
.collect(),
|
||||
),
|
||||
ast::Expression::InlineStruct(u) => Ok(T::Struct(
|
||||
ast::Expression::InlineStruct(u) => Ok(T::new_struct(
|
||||
u.ty.value.clone(),
|
||||
u.members
|
||||
.iter()
|
||||
@@ -373,12 +366,11 @@ impl<'ast> ZGen<'ast> {
|
||||
)),
|
||||
ast::Expression::ArrayInitializer(a) => {
|
||||
let v = self.expr(&a.value);
|
||||
let ty = v.type_();
|
||||
let n = const_int(self.const_(&a.count))
|
||||
.unwrap()
|
||||
.to_usize()
|
||||
.unwrap();
|
||||
Ok(T::Array(ty, vec![v; n]))
|
||||
array(vec![v; n])
|
||||
}
|
||||
ast::Expression::Postfix(p) => {
|
||||
// Assume no functions in arrays, etc.
|
||||
@@ -417,7 +409,7 @@ impl<'ast> ZGen<'ast> {
|
||||
.circ
|
||||
.exit_fn()
|
||||
.map(|a| a.unwrap_term())
|
||||
.unwrap_or_else(|| Self::const_bool(false));
|
||||
.unwrap_or_else(|| z_bool_lit(false));
|
||||
self.file_stack.pop();
|
||||
ret
|
||||
};
|
||||
@@ -510,9 +502,7 @@ impl<'ast> ZGen<'ast> {
|
||||
let t = ret_terms.into_iter().next().unwrap();
|
||||
match check(&t) {
|
||||
Sort::BitVector(_) => {}
|
||||
s => {
|
||||
panic!("Cannot maximize output of type {}", s)
|
||||
}
|
||||
s => panic!("Cannot maximize output of type {}", s),
|
||||
}
|
||||
self.circ.cir_ctx().cs.borrow_mut().outputs.push(t);
|
||||
}
|
||||
@@ -525,12 +515,8 @@ impl<'ast> ZGen<'ast> {
|
||||
);
|
||||
let t = ret_terms.into_iter().next().unwrap();
|
||||
let cmp = match check(&t) {
|
||||
Sort::BitVector(w) => {
|
||||
term![BV_UGE; t, bv_lit(v, w)]
|
||||
}
|
||||
s => {
|
||||
panic!("Cannot maximize output of type {}", s)
|
||||
}
|
||||
Sort::BitVector(w) => term![BV_UGE; t, bv_lit(v, w)],
|
||||
s => panic!("Cannot maximize output of type {}", s),
|
||||
};
|
||||
self.circ.cir_ctx().cs.borrow_mut().outputs.push(cmp);
|
||||
}
|
||||
@@ -669,13 +655,12 @@ impl<'ast> ZGen<'ast> {
|
||||
);
|
||||
}
|
||||
for s in &f.structs {
|
||||
let ty = Ty::Struct(
|
||||
let ty = Ty::new_struct(
|
||||
s.id.value.clone(),
|
||||
s.fields
|
||||
.clone()
|
||||
.iter()
|
||||
.map(|f| (f.id.value.clone(), self.type_(&f.ty)))
|
||||
.collect(),
|
||||
.map(|f| (f.id.value.clone(), self.type_(&f.ty))),
|
||||
);
|
||||
debug!("struct {}", s.id.value);
|
||||
self.circ.def_type(&s.id.value, ty);
|
||||
|
||||
@@ -19,6 +19,8 @@ lazy_static! {
|
||||
.unwrap();
|
||||
/// The modulus for ZoKrates, as an ARC
|
||||
pub static ref ZOKRATES_MODULUS_ARC: Arc<Integer> = Arc::new(ZOKRATES_MODULUS.clone());
|
||||
/// The modulus for ZoKrates, as an IR sort
|
||||
pub static ref ZOK_FIELD_SORT: Sort = Sort::Field(ZOKRATES_MODULUS_ARC.clone());
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
@@ -26,17 +28,57 @@ pub enum Ty {
|
||||
Uint(usize),
|
||||
Bool,
|
||||
Field,
|
||||
Struct(String, BTreeMap<String, Ty>),
|
||||
Struct(String, FieldList<Ty>),
|
||||
Array(usize, Box<Ty>),
|
||||
}
|
||||
|
||||
pub use field_list::FieldList;
|
||||
|
||||
/// This module contains [FieldList].
|
||||
///
|
||||
/// It gets its own module so that its member can be private.
|
||||
mod field_list {
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct FieldList<T> {
|
||||
// must be kept in sorted order
|
||||
list: Vec<(String, T)>,
|
||||
}
|
||||
|
||||
impl<T> FieldList<T> {
|
||||
pub fn new(mut list: Vec<(String, T)>) -> Self {
|
||||
list.sort_by_cached_key(|p| p.0.clone());
|
||||
FieldList { list }
|
||||
}
|
||||
pub fn search(&self, key: &str) -> Option<(usize, &T)> {
|
||||
let idx = self
|
||||
.list
|
||||
.binary_search_by_key(&key, |p| p.0.as_str())
|
||||
.ok()?;
|
||||
Some((idx, &self.list[idx].1))
|
||||
}
|
||||
pub fn get(&self, idx: usize) -> (&str, &T) {
|
||||
(&self.list[idx].0, &self.list[idx].1)
|
||||
}
|
||||
pub fn fields(&self) -> impl Iterator<Item = &(String, T)> {
|
||||
self.list.iter()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Ty {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Ty::Bool => write!(f, "bool"),
|
||||
Ty::Uint(w) => write!(f, "u{}", w),
|
||||
Ty::Field => write!(f, "field"),
|
||||
Ty::Struct(n, _) => write!(f, "{}", n),
|
||||
Ty::Struct(n, fields) => {
|
||||
let mut o = f.debug_struct(n);
|
||||
for (f_name, f_ty) in fields.fields() {
|
||||
o.field(f_name, f_ty);
|
||||
}
|
||||
o.finish()
|
||||
}
|
||||
Ty::Array(n, b) => write!(f, "{}[{}]", b, n),
|
||||
}
|
||||
}
|
||||
@@ -49,88 +91,113 @@ impl fmt::Debug for Ty {
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
fn default(&self) -> T {
|
||||
fn sort(&self) -> Sort {
|
||||
match self {
|
||||
Self::Bool => T::Bool(leaf_term(Op::Const(Value::Bool(false)))),
|
||||
Self::Uint(w) => T::Uint(*w, bv_lit(0, *w)),
|
||||
Self::Field => T::Field(pf_lit(0)),
|
||||
Self::Array(n, b) => T::Array((**b).clone(), vec![b.default(); *n]),
|
||||
Self::Struct(n, fs) => T::Struct(
|
||||
n.clone(),
|
||||
fs.iter()
|
||||
.map(|(f_name, f_ty)| (f_name.to_owned(), f_ty.default()))
|
||||
.collect(),
|
||||
),
|
||||
Self::Bool => Sort::Bool,
|
||||
Self::Uint(w) => Sort::BitVector(*w),
|
||||
Self::Field => ZOK_FIELD_SORT.clone(),
|
||||
Self::Array(n, b) => {
|
||||
Sort::Array(Box::new(ZOK_FIELD_SORT.clone()), Box::new(b.sort()), *n)
|
||||
}
|
||||
Self::Struct(_name, fs) => {
|
||||
Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
fn default_ir_term(&self) -> Term {
|
||||
self.sort().default_term()
|
||||
}
|
||||
fn default(&self) -> T {
|
||||
T {
|
||||
term: self.default_ir_term(),
|
||||
ty: self.clone(),
|
||||
}
|
||||
}
|
||||
/// Creates a new structure type, sorting the keys.
|
||||
pub fn new_struct<I: IntoIterator<Item = (String, Ty)>>(name: String, fields: I) -> Self {
|
||||
Self::Struct(name, FieldList::new(fields.into_iter().collect()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum T {
|
||||
Uint(usize, Term),
|
||||
Bool(Term),
|
||||
Field(Term),
|
||||
/// TODO: special case primitive arrays with Vec<T>.
|
||||
Array(Ty, Vec<T>),
|
||||
Struct(String, BTreeMap<String, T>),
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct T {
|
||||
pub ty: Ty,
|
||||
pub term: Term,
|
||||
}
|
||||
|
||||
impl T {
|
||||
pub fn type_(&self) -> Ty {
|
||||
match self {
|
||||
T::Uint(w, _) => Ty::Uint(*w),
|
||||
T::Bool(_) => Ty::Bool,
|
||||
T::Field(_) => Ty::Field,
|
||||
T::Array(b, v) => Ty::Array(v.len(), Box::new(b.clone())),
|
||||
T::Struct(name, map) => Ty::Struct(
|
||||
name.clone(),
|
||||
map.iter()
|
||||
.map(|(f_name, f_term)| (f_name.clone(), f_term.type_()))
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
pub fn new(ty: Ty, term: Term) -> Self {
|
||||
Self { ty, term }
|
||||
}
|
||||
pub fn type_(&self) -> &Ty {
|
||||
&self.ty
|
||||
}
|
||||
/// Get all IR terms inside this value, as a list.
|
||||
pub fn terms(&self) -> Vec<Term> {
|
||||
let mut output: Vec<Term> = Vec::new();
|
||||
fn terms_tail(term: &T, output: &mut Vec<Term>) {
|
||||
match term {
|
||||
T::Bool(b) => output.push(b.clone()),
|
||||
T::Uint(_, b) => output.push(b.clone()),
|
||||
T::Field(b) => output.push(b.clone()),
|
||||
T::Array(_, v) => v.iter().for_each(|v| terms_tail(v, output)),
|
||||
T::Struct(_, map) => map.iter().for_each(|(_, v)| terms_tail(v, output)),
|
||||
fn terms_tail(term: &Term, output: &mut Vec<Term>) {
|
||||
match check(term) {
|
||||
Sort::Bool | Sort::BitVector(_) | Sort::Field(_) => output.push(term.clone()),
|
||||
Sort::Array(_k, _v, size) => {
|
||||
for i in 0..size {
|
||||
terms_tail(&term![Op::Select; term.clone(), pf_lit_ir(i)], output)
|
||||
}
|
||||
}
|
||||
Sort::Tuple(sorts) => {
|
||||
for i in 0..sorts.len() {
|
||||
terms_tail(&term![Op::Field(i); term.clone()], output)
|
||||
}
|
||||
}
|
||||
s => unreachable!("Unreachable IR sort {} in ZoK", s),
|
||||
}
|
||||
}
|
||||
terms_tail(self, &mut output);
|
||||
terms_tail(&self.term, &mut output);
|
||||
output
|
||||
}
|
||||
fn unwrap_array_ir(self) -> Result<Vec<Term>, String> {
|
||||
match &self.ty {
|
||||
Ty::Array(size, _sort) => Ok((0..*size)
|
||||
.map(|i| term![Op::Select; self.term.clone(), pf_lit_ir(i)])
|
||||
.collect()),
|
||||
s => Err(format!("Not an array: {}", s)),
|
||||
}
|
||||
}
|
||||
pub fn unwrap_array(self) -> Result<Vec<T>, String> {
|
||||
match self {
|
||||
T::Array(_, v) => Ok(v),
|
||||
match &self.ty {
|
||||
Ty::Array(_size, sort) => {
|
||||
let sort = (**sort).clone();
|
||||
Ok(self
|
||||
.unwrap_array_ir()?
|
||||
.into_iter()
|
||||
.map(|t| T::new(sort.clone(), t))
|
||||
.collect())
|
||||
}
|
||||
s => Err(format!("Not an array: {}", s)),
|
||||
}
|
||||
}
|
||||
pub fn new_array(v: Vec<T>) -> Result<T, String> {
|
||||
array(v)
|
||||
}
|
||||
pub fn new_struct(name: String, fields: Vec<(String, T)>) -> T {
|
||||
let (field_tys, ir_terms): (Vec<_>, Vec<_>) = fields
|
||||
.into_iter()
|
||||
.map(|(name, t)| ((name.clone(), t.ty), (name, t.term)))
|
||||
.unzip();
|
||||
let field_ty_list = FieldList::new(field_tys);
|
||||
let ir_term = term(Op::Tuple, {
|
||||
let with_indices: BTreeMap<usize, Term> = ir_terms
|
||||
.into_iter()
|
||||
.map(|(name, t)| (field_ty_list.search(&name).unwrap().0, t))
|
||||
.collect();
|
||||
with_indices.into_iter().map(|(_i, t)| t).collect()
|
||||
});
|
||||
T::new(Ty::Struct(name, field_ty_list), ir_term)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for T {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
T::Bool(x) => write!(f, "Bool({})", x),
|
||||
T::Uint(_, x) => write!(f, "Uint({})", x),
|
||||
T::Field(x) => write!(f, "Field({})", x),
|
||||
T::Struct(_, _) => write!(f, "struct"),
|
||||
T::Array(_, _) => write!(f, "array"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for T {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
write!(f, "{}", self.term)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,10 +209,16 @@ fn wrap_bin_op(
|
||||
a: T,
|
||||
b: T,
|
||||
) -> Result<T, String> {
|
||||
match (a, b, fu, ff, fb) {
|
||||
(T::Uint(na, a), T::Uint(nb, b), Some(fu), _, _) if na == nb => Ok(T::Uint(na, fu(a, b))),
|
||||
(T::Bool(a), T::Bool(b), _, _, Some(fb)) => Ok(T::Bool(fb(a, b))),
|
||||
(T::Field(a), T::Field(b), _, Some(ff), _) => Ok(T::Field(ff(a, b))),
|
||||
match (&a.ty, &b.ty, fu, ff, fb) {
|
||||
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _) if na == nb => {
|
||||
Ok(T::new(Ty::Uint(*na), fu(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(Ty::Bool, Ty::Bool, _, _, Some(fb)) => {
|
||||
Ok(T::new(Ty::Bool, fb(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(Ty::Field, Ty::Field, _, Some(ff), _) => {
|
||||
Ok(T::new(Ty::Field, ff(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(x, y, _, _, _) => Err(format!("Cannot perform op '{}' on {} and {}", name, x, y)),
|
||||
}
|
||||
}
|
||||
@@ -158,10 +231,16 @@ fn wrap_bin_pred(
|
||||
a: T,
|
||||
b: T,
|
||||
) -> Result<T, String> {
|
||||
match (a, b, fu, ff, fb) {
|
||||
(T::Uint(na, a), T::Uint(nb, b), Some(fu), _, _) if na == nb => Ok(T::Bool(fu(a, b))),
|
||||
(T::Bool(a), T::Bool(b), _, _, Some(fb)) => Ok(T::Bool(fb(a, b))),
|
||||
(T::Field(a), T::Field(b), _, Some(ff), _) => Ok(T::Bool(ff(a, b))),
|
||||
match (&a.ty, &b.ty, fu, ff, fb) {
|
||||
(Ty::Uint(na), Ty::Uint(nb), Some(fu), _, _) if na == nb => {
|
||||
Ok(T::new(Ty::Bool, fu(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(Ty::Bool, Ty::Bool, _, _, Some(fb)) => {
|
||||
Ok(T::new(Ty::Bool, fb(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(Ty::Field, Ty::Field, _, Some(ff), _) => {
|
||||
Ok(T::new(Ty::Bool, ff(a.term.clone(), b.term.clone())))
|
||||
}
|
||||
(x, y, _, _, _) => Err(format!("Cannot perform op '{}' on {} and {}", name, x, y)),
|
||||
}
|
||||
}
|
||||
@@ -317,10 +396,10 @@ fn wrap_un_op(
|
||||
fb: Option<fn(Term) -> Term>,
|
||||
a: T,
|
||||
) -> Result<T, String> {
|
||||
match (a, fu, ff, fb) {
|
||||
(T::Uint(na, a), Some(fu), _, _) => Ok(T::Uint(na, fu(a))),
|
||||
(T::Bool(a), _, _, Some(fb)) => Ok(T::Bool(fb(a))),
|
||||
(T::Field(a), _, Some(ff), _) => Ok(T::Field(ff(a))),
|
||||
match (&a.ty, fu, ff, fb) {
|
||||
(Ty::Uint(_), Some(fu), _, _) => Ok(T::new(a.ty.clone(), fu(a.term.clone()))),
|
||||
(Ty::Bool, _, _, Some(fb)) => Ok(T::new(Ty::Bool, fb(a.term.clone()))),
|
||||
(Ty::Field, _, Some(ff), _) => Ok(T::new(Ty::Field, ff(a.term.clone()))),
|
||||
(x, _, _, _) => Err(format!("Cannot perform op '{}' on {}", name, x)),
|
||||
}
|
||||
}
|
||||
@@ -352,31 +431,25 @@ pub fn not(a: T) -> Result<T, String> {
|
||||
}
|
||||
|
||||
pub fn const_int(a: T) -> Result<Integer, String> {
|
||||
let s = match &a {
|
||||
T::Field(b) => match &b.op {
|
||||
Op::Const(Value::Field(f)) => Some(f.i().clone()),
|
||||
_ => None,
|
||||
},
|
||||
T::Uint(_, i) => match &i.op {
|
||||
Op::Const(Value::BitVector(f)) => Some(f.uint().clone()),
|
||||
_ => None,
|
||||
},
|
||||
match &a.term.op {
|
||||
Op::Const(Value::Field(f)) => Some(f.i().clone()),
|
||||
Op::Const(Value::BitVector(f)) => Some(f.uint().clone()),
|
||||
_ => None,
|
||||
};
|
||||
s.ok_or_else(|| format!("{} is not a constant integer", a))
|
||||
}
|
||||
.ok_or_else(|| format!("{} is not a constant integer", a))
|
||||
}
|
||||
|
||||
pub fn bool(a: T) -> Result<Term, String> {
|
||||
match a {
|
||||
T::Bool(b) => Ok(b),
|
||||
match &a.ty {
|
||||
Ty::Bool => Ok(a.term),
|
||||
a => Err(format!("{} is not a boolean", a)),
|
||||
}
|
||||
}
|
||||
|
||||
fn wrap_shift(name: &str, op: BvBinOp, a: T, b: T) -> Result<T, String> {
|
||||
let bc = const_int(b)?;
|
||||
match a {
|
||||
T::Uint(na, a) => Ok(T::Uint(na, term![Op::BvBinOp(op); a, bv_lit(bc, na)])),
|
||||
match &a.ty {
|
||||
&Ty::Uint(na) => Ok(T::new(a.ty, term![Op::BvBinOp(op); a.term, bv_lit(bc, na)])),
|
||||
x => Err(format!("Cannot perform op '{}' on {} and {}", name, x, bc)),
|
||||
}
|
||||
}
|
||||
@@ -390,30 +463,10 @@ pub fn shr(a: T, b: T) -> Result<T, String> {
|
||||
}
|
||||
|
||||
fn ite(c: Term, a: T, b: T) -> Result<T, String> {
|
||||
match (a, b) {
|
||||
(T::Uint(na, a), T::Uint(nb, b)) if na == nb => Ok(T::Uint(na, term![Op::Ite; c, a, b])),
|
||||
(T::Bool(a), T::Bool(b)) => Ok(T::Bool(term![Op::Ite; c, a, b])),
|
||||
(T::Field(a), T::Field(b)) => Ok(T::Field(term![Op::Ite; c, a, b])),
|
||||
(T::Array(ta, a), T::Array(tb, b)) if a.len() == b.len() && ta == tb => Ok(T::Array(
|
||||
ta,
|
||||
a.into_iter()
|
||||
.zip(b.into_iter())
|
||||
.map(|(a_i, b_i)| ite(c.clone(), a_i, b_i))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)),
|
||||
(T::Struct(na, a), T::Struct(nb, b)) if na == nb => Ok(T::Struct(na.clone(), {
|
||||
a.into_iter()
|
||||
.zip(b.into_iter())
|
||||
.map(|((af, av), (bf, bv))| {
|
||||
if af == bf {
|
||||
Ok((af, ite(c.clone(), av, bv)?))
|
||||
} else {
|
||||
Err(format!("Field mismatch: {} vs {}", af, bf))
|
||||
}
|
||||
})
|
||||
.collect::<Result<BTreeMap<_, _>, String>>()?
|
||||
})),
|
||||
(x, y) => Err(format!("Cannot perform ITE on {} and {}", x, y)),
|
||||
if &a.ty != &b.ty {
|
||||
Err(format!("Cannot perform ITE on {} and {}", a, b))
|
||||
} else {
|
||||
Ok(T::new(a.ty.clone(), term![Op::Ite; c, a.term, b.term]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -421,7 +474,7 @@ pub fn cond(c: T, a: T, b: T) -> Result<T, String> {
|
||||
ite(bool(c)?, a, b)
|
||||
}
|
||||
|
||||
pub fn pf_lit<I>(i: I) -> Term
|
||||
pub fn pf_lit_ir<I>(i: I) -> Term
|
||||
where
|
||||
Integer: From<I>,
|
||||
{
|
||||
@@ -431,76 +484,111 @@ where
|
||||
))))
|
||||
}
|
||||
|
||||
pub fn slice(array: T, start: Option<usize>, end: Option<usize>) -> Result<T, String> {
|
||||
match array {
|
||||
T::Array(b, mut list) => {
|
||||
pub fn field_lit<I>(i: I) -> T
|
||||
where
|
||||
Integer: From<I>,
|
||||
{
|
||||
T::new(Ty::Field, pf_lit_ir(i))
|
||||
}
|
||||
|
||||
pub fn z_bool_lit(v: bool) -> T {
|
||||
T::new(Ty::Bool, leaf_term(Op::Const(Value::Bool(v))))
|
||||
}
|
||||
|
||||
pub fn uint_lit<I>(v: I, bits: usize) -> T
|
||||
where
|
||||
Integer: From<I>,
|
||||
{
|
||||
T::new(Ty::Uint(bits), bv_lit(v, bits))
|
||||
}
|
||||
|
||||
pub fn slice(arr: T, start: Option<usize>, end: Option<usize>) -> Result<T, String> {
|
||||
match &arr.ty {
|
||||
Ty::Array(size, _) => {
|
||||
let start = start.unwrap_or(0);
|
||||
let end = end.unwrap_or(list.len());
|
||||
Ok(T::Array(b, list.drain(start..end).collect()))
|
||||
let end = end.unwrap_or(*size);
|
||||
array(arr.unwrap_array()?.drain(start..end))
|
||||
}
|
||||
a => Err(format!("Cannot slice {}", a)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn field_select(struct_: &T, field: &str) -> Result<T, String> {
|
||||
match struct_ {
|
||||
T::Struct(_, map) => map
|
||||
.get(field)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("No field '{}'", field)),
|
||||
match &struct_.ty {
|
||||
Ty::Struct(_, map) => {
|
||||
if let Some((idx, ty)) = map.search(field) {
|
||||
Ok(T::new(
|
||||
ty.clone(),
|
||||
term![Op::Field(idx); struct_.term.clone()],
|
||||
))
|
||||
} else {
|
||||
Err(format!("No field '{}'", field))
|
||||
}
|
||||
}
|
||||
a => Err(format!("{} is not a struct", a)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn field_store(struct_: T, field: &str, val: T) -> Result<T, String> {
|
||||
match struct_ {
|
||||
T::Struct(name, mut map) => Ok(T::Struct(name, {
|
||||
if map.insert(field.to_owned(), val).is_some() {
|
||||
map
|
||||
match &struct_.ty {
|
||||
Ty::Struct(_, map) => {
|
||||
if let Some((idx, ty)) = map.search(field) {
|
||||
if ty == &val.ty {
|
||||
Ok(T::new(
|
||||
struct_.ty.clone(),
|
||||
term![Op::Update(idx); struct_.term.clone(), val.term],
|
||||
))
|
||||
} else {
|
||||
Err(format!(
|
||||
"term {} assigned to field {} of type {}",
|
||||
val,
|
||||
field,
|
||||
map.get(idx).1
|
||||
))
|
||||
}
|
||||
} else {
|
||||
return Err(format!("No '{}' field", field));
|
||||
Err(format!("No field '{}'", field))
|
||||
}
|
||||
})),
|
||||
}
|
||||
a => Err(format!("{} is not a struct", a)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn array_select(array: T, idx: T) -> Result<T, String> {
|
||||
match (array, idx) {
|
||||
(T::Array(_, list), T::Field(idx)) => {
|
||||
let mut it = list.into_iter().enumerate();
|
||||
let first = it
|
||||
.next()
|
||||
.ok_or_else(|| format!("Cannot index empty array"))?;
|
||||
it.fold(Ok(first.1), |acc, (i, elem)| {
|
||||
ite(term![Op::Eq; pf_lit(i), idx.clone()], elem, acc?)
|
||||
})
|
||||
match (array.ty, idx.ty) {
|
||||
(Ty::Array(_size, elem_ty), Ty::Field) => {
|
||||
Ok(T::new(*elem_ty, term![Op::Select; array.term, idx.term]))
|
||||
}
|
||||
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> {
|
||||
match (array, idx) {
|
||||
(T::Array(ty, list), T::Field(idx)) => Ok(T::Array(
|
||||
ty,
|
||||
list.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| ite(term![Op::Eq; pf_lit(i), idx.clone()], val.clone(), elem))
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
match (&array.ty, idx.ty) {
|
||||
(Ty::Array(_, _), Ty::Field) => Ok(T::new(
|
||||
array.ty,
|
||||
term![Op::Store; array.term, idx.term, val.term],
|
||||
)),
|
||||
(a, b) => Err(format!("Cannot index {} by {}", b, a)),
|
||||
}
|
||||
}
|
||||
|
||||
fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
|
||||
fn ir_array<I: IntoIterator<Item = Term>>(sort: Sort, elems: I) -> Term {
|
||||
make_array(ZOK_FIELD_SORT.clone(), sort, elems.into_iter().collect())
|
||||
}
|
||||
|
||||
pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
|
||||
let v: Vec<T> = elems.into_iter().collect();
|
||||
if let Some(e) = v.first() {
|
||||
let ty = e.type_();
|
||||
if v.iter().skip(1).any(|a| a.type_() != ty) {
|
||||
Err(format!("Inconsistent types in array"))
|
||||
} else {
|
||||
Ok(T::Array(ty, v))
|
||||
let sort = check(&e.term);
|
||||
Ok(T::new(
|
||||
Ty::Array(v.len(), Box::new(ty.clone())),
|
||||
ir_array(sort, v.into_iter().map(|t| t.term)),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(format!("Empty array"))
|
||||
@@ -508,27 +596,29 @@ fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> {
|
||||
}
|
||||
|
||||
pub fn uint_to_bits(u: T) -> Result<T, String> {
|
||||
match u {
|
||||
T::Uint(n, t) => Ok(T::Array(
|
||||
Ty::Bool,
|
||||
(0..n)
|
||||
.map(|i| T::Bool(term![Op::BvBit(i); t.clone()]))
|
||||
.collect(),
|
||||
match &u.ty {
|
||||
Ty::Uint(n) => Ok(T::new(
|
||||
Ty::Array(*n, Box::new(Ty::Bool)),
|
||||
ir_array(
|
||||
Sort::Bool,
|
||||
(0..*n).map(|i| term![Op::BvBit(i); u.term.clone()]),
|
||||
),
|
||||
)),
|
||||
u => Err(format!("Cannot do uint-to-bits on {}", u)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn uint_from_bits(u: T) -> Result<T, String> {
|
||||
match u {
|
||||
T::Array(Ty::Bool, list) => match list.len() {
|
||||
8 | 16 | 32 => Ok(T::Uint(
|
||||
list.len(),
|
||||
match &u.ty {
|
||||
Ty::Array(bits, elem_ty) if &**elem_ty == &Ty::Bool => match bits {
|
||||
8 | 16 | 32 => Ok(T::new(
|
||||
Ty::Uint(*bits),
|
||||
term(
|
||||
Op::BvConcat,
|
||||
list.into_iter()
|
||||
.map(|z: T| -> Result<Term, String> { Ok(term![Op::BoolToBv; bool(z)?]) })
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
u.unwrap_array_ir()?
|
||||
.into_iter()
|
||||
.map(|z: Term| -> Term { term![Op::BoolToBv; z] })
|
||||
.collect(),
|
||||
),
|
||||
)),
|
||||
l => Err(format!("Cannot do uint-from-bits on len {} array", l,)),
|
||||
@@ -538,17 +628,9 @@ pub fn uint_from_bits(u: T) -> Result<T, String> {
|
||||
}
|
||||
|
||||
pub fn field_to_bits(f: T) -> Result<T, String> {
|
||||
match f {
|
||||
T::Field(t) => {
|
||||
let u = term![Op::PfToBv(254); t];
|
||||
Ok(T::Array(
|
||||
Ty::Bool,
|
||||
(0..254)
|
||||
.map(|i| T::Bool(term![Op::BvBit(i); u.clone()]))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
u => Err(format!("Cannot do field-to-bits on {}", u)),
|
||||
match &f.ty {
|
||||
Ty::Field => uint_to_bits(T::new(Ty::Uint(254), term![Op::PfToBv(254); f.term])),
|
||||
u => Err(format!("Cannot do uint-to-bits on {}", u)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -598,20 +680,26 @@ impl Embeddable for ZoKrates {
|
||||
.unwrap_or_else(|| Integer::from(0))
|
||||
};
|
||||
match ty {
|
||||
Ty::Bool => T::Bool(ctx.cs.borrow_mut().new_var(
|
||||
&raw_name,
|
||||
Sort::Bool,
|
||||
|| Value::Bool(get_int_val() != 0),
|
||||
visibility,
|
||||
)),
|
||||
Ty::Field => T::Field(ctx.cs.borrow_mut().new_var(
|
||||
&raw_name,
|
||||
Sort::Field(self.modulus.clone()),
|
||||
|| Value::Field(FieldElem::new(get_int_val(), self.modulus.clone())),
|
||||
visibility,
|
||||
)),
|
||||
Ty::Uint(w) => T::Uint(
|
||||
*w,
|
||||
Ty::Bool => T::new(
|
||||
Ty::Bool,
|
||||
ctx.cs.borrow_mut().new_var(
|
||||
&raw_name,
|
||||
Sort::Bool,
|
||||
|| Value::Bool(get_int_val() != 0),
|
||||
visibility,
|
||||
),
|
||||
),
|
||||
Ty::Field => T::new(
|
||||
Ty::Field,
|
||||
ctx.cs.borrow_mut().new_var(
|
||||
&raw_name,
|
||||
Sort::Field(self.modulus.clone()),
|
||||
|| Value::Field(FieldElem::new(get_int_val(), self.modulus.clone())),
|
||||
visibility,
|
||||
),
|
||||
),
|
||||
Ty::Uint(w) => T::new(
|
||||
Ty::Uint(*w),
|
||||
ctx.cs.borrow_mut().new_var(
|
||||
&raw_name,
|
||||
Sort::BitVector(*w),
|
||||
@@ -619,23 +707,19 @@ impl Embeddable for ZoKrates {
|
||||
visibility,
|
||||
),
|
||||
),
|
||||
Ty::Array(n, ty) => T::Array(
|
||||
(**ty).clone(),
|
||||
(0..*n)
|
||||
.map(|i| {
|
||||
self.declare(
|
||||
ctx,
|
||||
&*ty,
|
||||
idx_name(&raw_name, i),
|
||||
user_name.as_ref().map(|u| idx_name(u, i)),
|
||||
visibility.clone(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
Ty::Struct(n, fs) => T::Struct(
|
||||
Ty::Array(n, ty) => array((0..*n).map(|i| {
|
||||
self.declare(
|
||||
ctx,
|
||||
&*ty,
|
||||
idx_name(&raw_name, i),
|
||||
user_name.as_ref().map(|u| idx_name(u, i)),
|
||||
visibility.clone(),
|
||||
)
|
||||
}))
|
||||
.unwrap(),
|
||||
Ty::Struct(n, fs) => T::new_struct(
|
||||
n.clone(),
|
||||
fs.iter()
|
||||
fs.fields()
|
||||
.map(|(f_name, f_ty)| {
|
||||
(
|
||||
f_name.clone(),
|
||||
@@ -652,33 +736,8 @@ impl Embeddable for ZoKrates {
|
||||
),
|
||||
}
|
||||
}
|
||||
fn ite(&self, ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
|
||||
match (t, f) {
|
||||
(T::Bool(a), T::Bool(b)) => T::Bool(term![Op::Ite; cond, a, b]),
|
||||
(T::Uint(wa, a), T::Uint(wb, b)) if wa == wb => T::Uint(wa, term![Op::Ite; cond, a, b]),
|
||||
(T::Field(a), T::Field(b)) => T::Field(term![Op::Ite; cond, a, b]),
|
||||
(T::Array(a_ty, a), T::Array(b_ty, b)) if a_ty == b_ty => T::Array(
|
||||
a_ty,
|
||||
a.into_iter()
|
||||
.zip(b.into_iter())
|
||||
.map(|(a_i, b_i)| self.ite(ctx, cond.clone(), a_i, b_i))
|
||||
.collect(),
|
||||
),
|
||||
(T::Struct(a_nm, a), T::Struct(b_nm, b)) if a_nm == b_nm => T::Struct(
|
||||
a_nm,
|
||||
a.into_iter()
|
||||
.zip(b.into_iter())
|
||||
.map(|((a_f, a_i), (b_f, b_i))| {
|
||||
if a_f == b_f {
|
||||
(a_f, self.ite(ctx, cond.clone(), a_i, b_i))
|
||||
} else {
|
||||
panic!("Field mismatch: '{}' vs '{}'", a_f, b_f)
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
(t, f) => panic!("Cannot ITE {} and {}", t, f),
|
||||
}
|
||||
fn ite(&self, _ctx: &mut CirCtx, cond: Term, t: Self::T, f: Self::T) -> Self::T {
|
||||
ite(cond, t, f).unwrap()
|
||||
}
|
||||
fn assign(
|
||||
&self,
|
||||
@@ -688,47 +747,15 @@ impl Embeddable for ZoKrates {
|
||||
t: Self::T,
|
||||
visibility: Option<PartyId>,
|
||||
) -> Self::T {
|
||||
assert!(&t.type_() == ty);
|
||||
match (ty, t) {
|
||||
(_, T::Bool(b)) => T::Bool(ctx.cs.borrow_mut().assign(&name, b, visibility)),
|
||||
(_, T::Field(b)) => T::Field(ctx.cs.borrow_mut().assign(&name, b, visibility)),
|
||||
(_, T::Uint(w, b)) => T::Uint(w, ctx.cs.borrow_mut().assign(&name, b, visibility)),
|
||||
(_, T::Array(ety, list)) => T::Array(
|
||||
ety.clone(),
|
||||
list.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| {
|
||||
self.assign(ctx, &ety, idx_name(&name, i), elem, visibility.clone())
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
(Ty::Struct(_, tys), T::Struct(s_name, list)) => T::Struct(
|
||||
s_name,
|
||||
list.into_iter()
|
||||
.zip(tys.into_iter())
|
||||
.map(|((f_name, elem), (_, f_ty))| {
|
||||
(
|
||||
f_name.clone(),
|
||||
self.assign(
|
||||
ctx,
|
||||
&f_ty,
|
||||
field_name(&name, &f_name),
|
||||
elem,
|
||||
visibility.clone(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
assert!(t.type_() == ty);
|
||||
T::new(t.ty, ctx.cs.borrow_mut().assign(&name, t.term, visibility))
|
||||
}
|
||||
fn values(&self) -> bool {
|
||||
self.values.is_some()
|
||||
}
|
||||
|
||||
fn type_of(&self, term: &Self::T) -> Self::Ty {
|
||||
term.type_()
|
||||
term.type_().clone()
|
||||
}
|
||||
|
||||
fn initialize_return(&self, ty: &Self::Ty, _ssa_name: &String) -> Self::T {
|
||||
|
||||
@@ -1,110 +1,95 @@
|
||||
//! Linear Memory implementation.
|
||||
//!
|
||||
//! The idea is to replace each array with a term sequence and use ITEs to linearly scan the array
|
||||
//! when needed. A SELECT produces an ITE reduce chain, a STORE produces an ITE map over the
|
||||
//! sequence.
|
||||
//!
|
||||
//! E.g., for length-3 arrays.
|
||||
//!
|
||||
//! (select A k) => (ite (= k 2) A2 (ite (= k 1) A1 A0))
|
||||
//! (store A k v) => (ite (= k 0) v A0), (ite (= k 1) v A1), (ite (= k 2))
|
||||
|
||||
use super::visit::MemVisitor;
|
||||
//! The idea is to replace each array with a tuple, and use ITEs to account for variable indexing.
|
||||
use super::super::visit::RewritePass;
|
||||
use crate::ir::term::*;
|
||||
|
||||
use std::iter::repeat;
|
||||
struct Linearizer;
|
||||
|
||||
struct ArrayLinearizer {
|
||||
/// A map from (original) replaced terms, to what they were replaced with.
|
||||
sequences: TermMap<Vec<Term>>,
|
||||
/// The maximum size of arrays that will be replaced.
|
||||
size_thresh: usize,
|
||||
fn arr_val_to_tup(v: &Value) -> Value {
|
||||
match v {
|
||||
Value::Array(Array {
|
||||
default, map, size, ..
|
||||
}) => Value::Tuple({
|
||||
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
|
||||
for (i, v) in map {
|
||||
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
|
||||
}
|
||||
vec
|
||||
}),
|
||||
v => v.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
impl MemVisitor for ArrayLinearizer {
|
||||
fn visit_const_array(&mut self, orig: &Term, _key_sort: &Sort, val: &Term, size: usize) {
|
||||
if size <= self.size_thresh {
|
||||
self.sequences
|
||||
.insert(orig.clone(), repeat(val).cloned().take(size).collect());
|
||||
}
|
||||
fn arr_sort_to_tup(v: &Sort) -> Sort {
|
||||
match v {
|
||||
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
|
||||
v => v.clone(),
|
||||
}
|
||||
fn visit_eq(&mut self, orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
|
||||
// don't map b/c self borrow lifetime & NLL
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let b_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent eq");
|
||||
let eqs: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Eq; a.clone(), b.clone()])
|
||||
.collect();
|
||||
Some(term(Op::BoolNaryOp(BoolNaryOp::And), eqs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn visit_ite(&mut self, orig: &Term, c: &Term, _t: &Term, _f: &Term) {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[1]) {
|
||||
let b_seq = self.sequences.get(&orig.cs[2]).expect("inconsistent ite");
|
||||
let ites: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Ite; c.clone(), a.clone(), b.clone()])
|
||||
.collect();
|
||||
self.sequences.insert(orig.clone(), ites);
|
||||
}
|
||||
}
|
||||
fn visit_store(&mut self, orig: &Term, _a: &Term, k: &Term, v: &Term) {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let key_sort = check(k);
|
||||
let ites: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(key_sort.elems_iter())
|
||||
.map(|(a_i, key_i)| {
|
||||
let eq_idx = term![Op::Eq; key_i, k.clone()];
|
||||
term![Op::Ite; eq_idx, v.clone(), a_i.clone()]
|
||||
})
|
||||
.collect();
|
||||
self.sequences.insert(orig.clone(), ites);
|
||||
}
|
||||
}
|
||||
fn visit_select(&mut self, orig: &Term, _a: &Term, k: &Term) -> Option<Term> {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let key_sort = check(k);
|
||||
let first = a_seq.first().expect("empty array in visit_select").clone();
|
||||
Some(a_seq.iter().zip(key_sort.elems_iter()).skip(1).fold(
|
||||
first,
|
||||
|acc, (a_i, key_i)| {
|
||||
let eq_idx = term![Op::Eq; key_i, k.clone()];
|
||||
term![Op::Ite; eq_idx, a_i.clone(), acc]
|
||||
},
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
|
||||
if let Sort::Array(_k, v, size) = s {
|
||||
if *size <= self.size_thresh {
|
||||
self.sequences.insert(
|
||||
orig.clone(),
|
||||
(0..*size)
|
||||
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
impl RewritePass for Linearizer {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
match &orig.op {
|
||||
Op::Const(v @ Value::Array(..)) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))),
|
||||
Op::Var(name, sort @ Sort::Array(_k, _v, _size)) => {
|
||||
let new_value = computation
|
||||
.values
|
||||
.as_ref()
|
||||
.map(|vs| arr_val_to_tup(vs.get(name).unwrap()));
|
||||
let vis = computation.metadata.get_input_visibility(name);
|
||||
let new_sort = arr_sort_to_tup(sort);
|
||||
let new_var_info = vec![(name.clone(), new_sort.clone(), new_value, vis)];
|
||||
computation.replace_input(orig.clone(), new_var_info);
|
||||
Some(leaf_term(Op::Var(name.clone(), new_sort)))
|
||||
}
|
||||
} else {
|
||||
unreachable!("should only visit array vars")
|
||||
Op::Select => {
|
||||
let cs = rewritten_children();
|
||||
let idx = &cs[1];
|
||||
let tup = &cs[0];
|
||||
if let Sort::Array(key_sort, _, size) = check(&orig.cs[0]) {
|
||||
assert!(size > 0);
|
||||
let mut fields = (0..size).map(|idx| term![Op::Field(idx); tup.clone()]);
|
||||
let first = fields.next().unwrap();
|
||||
Some(key_sort.elems_iter().take(size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
|
||||
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
|
||||
}))
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
Op::Store => {
|
||||
let cs = rewritten_children();
|
||||
let tup = &cs[0];
|
||||
let idx = &cs[1];
|
||||
let val = &cs[2];
|
||||
if let Sort::Array(key_sort, _, size) = check(&orig.cs[0]) {
|
||||
assert!(size > 0);
|
||||
let mut updates =
|
||||
(0..size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]);
|
||||
let first = updates.next().unwrap();
|
||||
Some(key_sort.elems_iter().take(size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| {
|
||||
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], update, acc]
|
||||
}))
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
// ITEs and EQs are correctly rewritten by default.
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Eliminate arrays using linear scans. See module documentation.
|
||||
pub fn linearize(t: &Term, size_thresh: usize) -> Term {
|
||||
let mut pass = ArrayLinearizer {
|
||||
size_thresh,
|
||||
sequences: TermMap::new(),
|
||||
};
|
||||
pass.traverse(t)
|
||||
pub fn linearize(c: &mut Computation) {
|
||||
let mut pass = Linearizer;
|
||||
pass.traverse(c);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -138,7 +123,12 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn select_ite_stores() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
|
||||
let z = term![Op::Const(Value::Array(Array::new(
|
||||
Sort::BitVector(4),
|
||||
Box::new(Sort::BitVector(4).default_value()),
|
||||
Default::default(),
|
||||
6
|
||||
)))];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
@@ -147,14 +137,21 @@ mod test {
|
||||
],
|
||||
bv_lit(3, 4)
|
||||
];
|
||||
let tt = linearize(&t, 6);
|
||||
assert!(array_free(&tt));
|
||||
assert_eq!(6 + 6 + 6 + 5, count_ites(&tt));
|
||||
let mut c = Computation::default();
|
||||
c.outputs.push(t);
|
||||
linearize(&mut c);
|
||||
assert!(array_free(&c.outputs[0]));
|
||||
assert_eq!(5 + 5 + 1 + 5, count_ites(&c.outputs[0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_ite_stores_field() {
|
||||
let z = term![Op::ConstArray(Sort::Field(Arc::new(Integer::from(TEST_FIELD))), 6); bv_lit(0, 4)];
|
||||
let z = term![Op::Const(Value::Array(Array::new(
|
||||
Sort::Field(Arc::new(Integer::from(TEST_FIELD))),
|
||||
Box::new(Sort::BitVector(4).default_value()),
|
||||
Default::default(),
|
||||
6
|
||||
)))];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
@@ -163,8 +160,10 @@ mod test {
|
||||
],
|
||||
field_lit(3)
|
||||
];
|
||||
let tt = linearize(&t, 6);
|
||||
assert!(array_free(&tt));
|
||||
assert_eq!(6 + 6 + 6 + 5, count_ites(&tt));
|
||||
let mut c = Computation::default();
|
||||
c.outputs.push(t);
|
||||
linearize(&mut c);
|
||||
assert!(array_free(&c.outputs[0]));
|
||||
assert_eq!(5 + 5 + 1 + 5, count_ites(&c.outputs[0]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
//! Memory optimizations
|
||||
|
||||
/// Oblivious array elimination.
|
||||
///
|
||||
/// Replace arrays with tuples, using ITEs to handle variable indexing.
|
||||
pub mod lin;
|
||||
/// Oblivious array elimination.
|
||||
///
|
||||
/// Replace arrays that are accessed at constant indices with tuples.
|
||||
pub mod obliv;
|
||||
mod visit;
|
||||
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// Eliminates arrays, first oblivious ones, and then all arrays.
|
||||
pub fn array_elim(t: &Term) -> Term {
|
||||
lin::linearize(&obliv::elim_obliv(t), usize::MAX)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Oblivious Array Elimination
|
||||
//!
|
||||
//! This module attempts to identify *oblivious* arrays: those that are only accessed at constant
|
||||
//! indices. These arrays can be replaced with normal terms.
|
||||
//! indices. These arrays can be replaced with tuples. Then, a tuple elimination pass can be run.
|
||||
//!
|
||||
//! It operates in two passes:
|
||||
//!
|
||||
//! 1. determine which arrays are oblivious
|
||||
//! 2. replace oblivious arrays with (haskell) lists of terms.
|
||||
//! 2. replace oblivious arrays with tuples
|
||||
//!
|
||||
//!
|
||||
//! ## Pass 1: Identifying oblivious arrays
|
||||
@@ -26,104 +26,116 @@
|
||||
//!
|
||||
//! In this pass, the goal is to
|
||||
//!
|
||||
//! * map array terms to haskell lists of value terms
|
||||
//! * map array selections to specific value terms
|
||||
//!
|
||||
//! The pass maintains:
|
||||
//!
|
||||
//! * a map from array terms to lists of values
|
||||
//!
|
||||
//! It then does a bottom-up formula traversal, performing the following
|
||||
//! transformations:
|
||||
//!
|
||||
//! * oblivious array variables are mapped to a list of (derivative) variables
|
||||
//! * oblivious constant arrays are mapped to a list that replicates the constant
|
||||
//! * accesses to oblivious arrays are mapped to the appropriate term from the
|
||||
//! value list of the array
|
||||
//! * stores to oblivious arrays are mapped to updated value lists
|
||||
//! * equalities between oblivious arrays are mapped to conjunctions of equalities
|
||||
//! * map array terms to tuple terms
|
||||
//! * map array selections to tuple field gets
|
||||
|
||||
use super::visit::*;
|
||||
use crate::ir::term::*;
|
||||
use super::super::visit::*;
|
||||
use crate::ir::term::extras::as_uint_constant;
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::debug;
|
||||
|
||||
use std::iter::repeat;
|
||||
|
||||
struct NonOblivComputer {
|
||||
not_obliv: TermSet,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
impl NonOblivComputer {
|
||||
fn mark(&mut self, a: &Term) {
|
||||
if !self.not_obliv.contains(a) {
|
||||
self.not_obliv.insert(a.clone());
|
||||
self.progress = true;
|
||||
fn mark(&mut self, a: &Term) -> bool {
|
||||
if !a.is_const() && self.not_obliv.insert(a.clone()) {
|
||||
debug!("Not obliv: {}", a);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn bi_implicate(&mut self, a: &Term, b: &Term) {
|
||||
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
|
||||
(false, true) => {
|
||||
self.not_obliv.insert(a.clone());
|
||||
self.progress = true;
|
||||
fn bi_implicate(&mut self, a: &Term, b: &Term) -> bool {
|
||||
if !a.is_const() && !b.is_const() {
|
||||
match (self.not_obliv.contains(a), self.not_obliv.contains(b)) {
|
||||
(false, true) => {
|
||||
self.not_obliv.insert(a.clone());
|
||||
true
|
||||
}
|
||||
(true, false) => {
|
||||
self.not_obliv.insert(b.clone());
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
(true, false) => {
|
||||
self.not_obliv.insert(b.clone());
|
||||
self.progress = true;
|
||||
}
|
||||
_ => {}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
progress: true,
|
||||
not_obliv: TermSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MemVisitor for NonOblivComputer {
|
||||
fn visit_eq(&mut self, _orig: &Term, a: &Term, b: &Term) -> Option<Term> {
|
||||
self.bi_implicate(a, b);
|
||||
None
|
||||
}
|
||||
fn visit_ite(&mut self, orig: &Term, _c: &Term, t: &Term, f: &Term) {
|
||||
self.bi_implicate(orig, t);
|
||||
self.bi_implicate(t, f);
|
||||
self.bi_implicate(orig, f);
|
||||
}
|
||||
fn visit_store(&mut self, orig: &Term, a: &Term, k: &Term, _v: &Term) {
|
||||
if let Op::Const(_) = k.op {
|
||||
self.bi_implicate(orig, a);
|
||||
} else {
|
||||
self.mark(a);
|
||||
self.mark(orig);
|
||||
impl ProgressAnalysisPass for NonOblivComputer {
|
||||
fn visit(&mut self, term: &Term) -> bool {
|
||||
match &term.op {
|
||||
Op::Store => {
|
||||
let a = &term.cs[0];
|
||||
let i = &term.cs[1];
|
||||
let v = &term.cs[2];
|
||||
let mut progress = false;
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
if let Op::Const(_) = i.op {
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
} else {
|
||||
progress = self.mark(a) || progress;
|
||||
progress = self.mark(term) || progress;
|
||||
}
|
||||
if let Sort::Array(..) = check(v) {
|
||||
// Imprecisely, mark v as non-obliv iff the array is.
|
||||
progress = self.bi_implicate(term, v) || progress;
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Select => {
|
||||
let a = &term.cs[0];
|
||||
let i = &term.cs[1];
|
||||
if let Op::Const(_) = i.op {
|
||||
false
|
||||
} else {
|
||||
self.mark(a)
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
let t = &term.cs[1];
|
||||
let f = &term.cs[2];
|
||||
if let Sort::Array(..) = check(t) {
|
||||
let mut progress = self.bi_implicate(term, t);
|
||||
progress = self.bi_implicate(t, f) || progress;
|
||||
progress = self.bi_implicate(term, f) || progress;
|
||||
progress
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Op::Eq => {
|
||||
let a = &term.cs[0];
|
||||
let b = &term.cs[1];
|
||||
if let Sort::Array(..) = check(a) {
|
||||
self.bi_implicate(a, b)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Op::Tuple => {
|
||||
panic!("Tuple in obliv")
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn visit_select(&mut self, _orig: &Term, a: &Term, k: &Term) -> Option<Term> {
|
||||
if let Op::Const(_) = k.op {
|
||||
} else {
|
||||
self.mark(a);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl ProgressVisitor for NonOblivComputer {
|
||||
fn check_progress(&self) -> bool {
|
||||
self.progress
|
||||
}
|
||||
fn reset_progress(&mut self) {
|
||||
self.progress = false;
|
||||
}
|
||||
}
|
||||
|
||||
struct Replacer {
|
||||
/// A map from (original) replaced terms, to what they were replaced with.
|
||||
sequences: TermMap<Vec<Term>>,
|
||||
/// The maximum size of arrays that will be replaced.
|
||||
not_obliv: TermSet,
|
||||
}
|
||||
@@ -133,96 +145,120 @@ impl Replacer {
|
||||
!self.not_obliv.contains(a)
|
||||
}
|
||||
}
|
||||
fn arr_val_to_tup(v: &Value) -> Value {
|
||||
match v {
|
||||
Value::Array(Array {
|
||||
default, map, size, ..
|
||||
}) => Value::Tuple({
|
||||
let mut vec: Vec<Value> = vec![arr_val_to_tup(default); *size];
|
||||
for (i, v) in map {
|
||||
vec[i.as_usize().expect("non usize key")] = arr_val_to_tup(v);
|
||||
}
|
||||
vec
|
||||
}),
|
||||
v => v.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
impl MemVisitor for Replacer {
|
||||
fn visit_const_array(&mut self, orig: &Term, _key_sort: &Sort, val: &Term, size: usize) {
|
||||
if self.should_replace(orig) {
|
||||
debug!("Will replace constant: {}", orig);
|
||||
self.sequences
|
||||
.insert(orig.clone(), repeat(val).cloned().take(size).collect());
|
||||
}
|
||||
fn term_arr_val_to_tup(a: Term) -> Term {
|
||||
match &a.op {
|
||||
Op::Const(v @ Value::Array(..)) => leaf_term(Op::Const(arr_val_to_tup(v))),
|
||||
_ => a,
|
||||
}
|
||||
fn visit_eq(&mut self, orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
let b_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent eq");
|
||||
let eqs: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Eq; a.clone(), b.clone()])
|
||||
.collect();
|
||||
Some(term(Op::BoolNaryOp(BoolNaryOp::And), eqs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn arr_sort_to_tup(v: &Sort) -> Sort {
|
||||
match v {
|
||||
Sort::Array(_key, value, size) => Sort::Tuple(vec![arr_sort_to_tup(value); *size]),
|
||||
v => v.clone(),
|
||||
}
|
||||
fn visit_ite(&mut self, orig: &Term, c: &Term, _t: &Term, _f: &Term) {
|
||||
if self.should_replace(orig) {
|
||||
let a_seq = self.sequences.get(&orig.cs[1]).expect("inconsistent ite");
|
||||
let b_seq = self.sequences.get(&orig.cs[2]).expect("inconsistent ite");
|
||||
let ites: Vec<Term> = a_seq
|
||||
.iter()
|
||||
.zip(b_seq.iter())
|
||||
.map(|(a, b)| term![Op::Ite; c.clone(), a.clone(), b.clone()])
|
||||
.collect();
|
||||
self.sequences.insert(orig.clone(), ites);
|
||||
}
|
||||
}
|
||||
fn visit_store(&mut self, orig: &Term, _a: &Term, k: &Term, v: &Term) {
|
||||
if self.should_replace(orig) {
|
||||
let mut a_seq = self
|
||||
.sequences
|
||||
.get(&orig.cs[0])
|
||||
.expect("inconsistent store")
|
||||
.clone();
|
||||
let k_const = as_uint_constant(k)
|
||||
.expect("not obliv!")
|
||||
.to_usize()
|
||||
.expect("oversize index");
|
||||
a_seq[k_const] = v.clone();
|
||||
self.sequences.insert(orig.clone(), a_seq);
|
||||
}
|
||||
}
|
||||
fn visit_select(&mut self, orig: &Term, _a: &Term, k: &Term) -> Option<Term> {
|
||||
if let Some(a_seq) = self.sequences.get(&orig.cs[0]) {
|
||||
debug!("Will replace select: {}", orig);
|
||||
let k_const = as_uint_constant(k)
|
||||
.expect("not obliv!")
|
||||
.to_usize()
|
||||
.expect("oversize index");
|
||||
if k_const < a_seq.len() {
|
||||
Some(a_seq[k_const].clone())
|
||||
} else {
|
||||
panic!("Oversize index: {}", k_const)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn get_const(t: &Term) -> usize {
|
||||
as_uint_constant(t)
|
||||
.unwrap_or_else(|| panic!("non-const {}", t))
|
||||
.to_usize()
|
||||
.expect("oversize")
|
||||
}
|
||||
|
||||
impl RewritePass for Replacer {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
//debug!("Visit {}", extras::Letified(orig.clone()));
|
||||
let get_cs = || -> Vec<Term> {
|
||||
rewritten_children()
|
||||
.into_iter()
|
||||
.map(term_arr_val_to_tup)
|
||||
.collect()
|
||||
};
|
||||
match &orig.op {
|
||||
Op::Var(name, sort @ Sort::Array(_k, _v, _size)) => {
|
||||
if self.should_replace(orig) {
|
||||
let new_value = computation
|
||||
.values
|
||||
.as_ref()
|
||||
.map(|vs| arr_val_to_tup(vs.get(name).unwrap()));
|
||||
let vis = computation.metadata.get_input_visibility(name);
|
||||
let new_sort = arr_sort_to_tup(sort);
|
||||
let new_var_info = vec![(name.clone(), new_sort.clone(), new_value, vis)];
|
||||
computation.replace_input(orig.clone(), new_var_info);
|
||||
Some(leaf_term(Op::Var(name.clone(), new_sort)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
fn visit_var(&mut self, orig: &Term, name: &String, s: &Sort) {
|
||||
if let Sort::Array(_k, v, size) = s {
|
||||
if self.should_replace(orig) {
|
||||
self.sequences.insert(
|
||||
orig.clone(),
|
||||
(0..*size)
|
||||
.map(|i| leaf_term(Op::Var(format!("{}_{}", name, i), (**v).clone())))
|
||||
.collect(),
|
||||
);
|
||||
Op::Select => {
|
||||
if self.should_replace(&orig.cs[0]) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 2);
|
||||
let k_const = get_const(&cs.pop().unwrap());
|
||||
Some(term(Op::Field(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unreachable!("should only visit array vars")
|
||||
Op::Store => {
|
||||
if self.should_replace(&orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
Some(term(Op::Update(k_const), cs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
if self.should_replace(&orig) {
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Eq => {
|
||||
if self.should_replace(&orig.cs[0]) {
|
||||
Some(term(Op::Eq, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Eliminate oblivious arrays. See module documentation.
|
||||
pub fn elim_obliv(t: &Term) -> Term {
|
||||
pub fn elim_obliv(t: &mut Computation) {
|
||||
let mut prop_pass = NonOblivComputer::new();
|
||||
prop_pass.traverse_to_fixpoint(t);
|
||||
prop_pass.traverse(t);
|
||||
let mut replace_pass = Replacer {
|
||||
not_obliv: prop_pass.not_obliv,
|
||||
sequences: TermMap::new(),
|
||||
};
|
||||
replace_pass.traverse(t)
|
||||
<Replacer as RewritePass>::traverse(&mut replace_pass, t)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -244,7 +280,12 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn obliv() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
|
||||
let z = term![Op::Const(Value::Array(Array::new(
|
||||
Sort::BitVector(4),
|
||||
Box::new(Sort::BitVector(4).default_value()),
|
||||
Default::default(),
|
||||
6
|
||||
)))];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
@@ -253,13 +294,20 @@ mod test {
|
||||
],
|
||||
bv_lit(3, 4)
|
||||
];
|
||||
let tt = elim_obliv(&t);
|
||||
assert!(array_free(&tt));
|
||||
let mut c = Computation::default();
|
||||
c.outputs.push(t);
|
||||
elim_obliv(&mut c);
|
||||
assert!(array_free(&c.outputs[0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_obliv() {
|
||||
let z = term![Op::ConstArray(Sort::BitVector(4), 6); bv_lit(0, 4)];
|
||||
let z = term![Op::Const(Value::Array(Array::new(
|
||||
Sort::BitVector(4),
|
||||
Box::new(Sort::BitVector(4).default_value()),
|
||||
Default::default(),
|
||||
6
|
||||
)))];
|
||||
let t = term![Op::Select;
|
||||
term![Op::Ite;
|
||||
leaf_term(Op::Const(Value::Bool(true))),
|
||||
@@ -268,7 +316,9 @@ mod test {
|
||||
],
|
||||
bv_lit(3, 4)
|
||||
];
|
||||
let tt = elim_obliv(&t);
|
||||
assert!(!array_free(&tt));
|
||||
let mut c = Computation::default();
|
||||
c.outputs.push(t);
|
||||
elim_obliv(&mut c);
|
||||
assert!(!array_free(&c.outputs[0]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::debug;
|
||||
|
||||
/// A visitor for traversing terms, and visiting the array-related parts.
|
||||
///
|
||||
/// Visits:
|
||||
/// * EQs over arrays
|
||||
/// * Constant arrays
|
||||
/// * ITEs over arrays
|
||||
/// * array variables
|
||||
/// * STOREs
|
||||
/// * SELECTs
|
||||
///
|
||||
/// For the EQs and SELECTs, you have the ability to (optionally) return a replacement for the
|
||||
/// term. This can be used to "cut" the array out of the term, since EQs and SELECTs are how
|
||||
/// information leaves an array.
|
||||
///
|
||||
/// All visitors receive the original term and the rewritten children.
|
||||
pub trait MemVisitor {
|
||||
/// Visit a const array
|
||||
fn visit_const_array(&mut self, _orig: &Term, _key_sort: &Sort, _val: &Term, _size: usize) {}
|
||||
/// Visit an equality, whose children are `a` and `b`.
|
||||
fn visit_eq(&mut self, _orig: &Term, _a: &Term, _b: &Term) -> Option<Term> {
|
||||
None
|
||||
}
|
||||
/// Visit an array-valued ITE
|
||||
fn visit_ite(&mut self, _orig: &Term, _c: &Term, _t: &Term, _f: &Term) {}
|
||||
/// Visit a STORE
|
||||
fn visit_store(&mut self, _orig: &Term, _a: &Term, _k: &Term, _v: &Term) {}
|
||||
/// Visit a SELECT
|
||||
fn visit_select(&mut self, _orig: &Term, _a: &Term, _k: &Term) -> Option<Term> {
|
||||
None
|
||||
}
|
||||
fn visit_var(&mut self, _orig: &Term, _name: &String, _s: &Sort) {}
|
||||
|
||||
/// Traverse a node, visiting memory-related terms.
|
||||
///
|
||||
/// Can be used to remove memory-related terms by replacing the EQs and SELECTs which extract
|
||||
/// other terms from them.
|
||||
///
|
||||
/// Returns the transformed term.
|
||||
fn traverse(&mut self, node: &Term) -> Term {
|
||||
let mut cache = TermMap::<Term>::new();
|
||||
for t in PostOrderIter::new(node.clone()) {
|
||||
let c_get = |x: &Term| cache.get(x).unwrap();
|
||||
let get = |i: usize| c_get(&t.cs[i]);
|
||||
let new_t_opt = {
|
||||
let s = check(&t);
|
||||
match &s {
|
||||
Sort::Array(_, _, _) => {
|
||||
match &t.op {
|
||||
Op::Var(name, s) => {
|
||||
self.visit_var(&t, name, s);
|
||||
}
|
||||
Op::Ite => {
|
||||
self.visit_ite(&t, get(0), get(1), get(2));
|
||||
}
|
||||
Op::Store => {
|
||||
self.visit_store(&t, get(0), get(1), get(2));
|
||||
}
|
||||
Op::ConstArray(s, n) => {
|
||||
self.visit_const_array(&t, s, get(0), *n);
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
None
|
||||
}
|
||||
_ => match &t.op {
|
||||
Op::Eq => {
|
||||
let a = get(0);
|
||||
if let Sort::Array(_, _, _) = check(&a) {
|
||||
self.visit_eq(&t, a, get(1))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Select => self.visit_select(&t, get(0), get(1)),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
};
|
||||
let new_t = new_t_opt.unwrap_or_else(|| {
|
||||
term(
|
||||
t.op.clone(),
|
||||
t.cs.iter().map(|c| c_get(c).clone()).collect(),
|
||||
)
|
||||
});
|
||||
debug!("rebuild: {}", new_t);
|
||||
cache.insert(t.clone(), new_t);
|
||||
}
|
||||
cache.remove(node).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProgressVisitor: MemVisitor {
|
||||
fn reset_progress(&mut self);
|
||||
fn check_progress(&self) -> bool;
|
||||
|
||||
fn traverse_to_fixpoint(&mut self, a: &Term) {
|
||||
self.traverse(a);
|
||||
while self.check_progress() {
|
||||
self.reset_progress();
|
||||
self.traverse(a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ pub mod cfold;
|
||||
pub mod flat;
|
||||
pub mod inline;
|
||||
pub mod mem;
|
||||
pub mod scalarize_vars;
|
||||
pub mod sha;
|
||||
pub mod tuple;
|
||||
mod visit;
|
||||
|
||||
use super::term::*;
|
||||
use log::debug;
|
||||
@@ -12,14 +14,19 @@ use log::debug;
|
||||
#[derive(Debug)]
|
||||
/// An optimization pass
|
||||
pub enum Opt {
|
||||
/// Convert non-scalar (tuple, array) inputs to scalar ones
|
||||
/// The scalar variable names are suffixed with .N, where N indicates the array/tuple position
|
||||
ScalarizeVars,
|
||||
/// Fold constants
|
||||
ConstantFold,
|
||||
/// Flatten n-ary operators
|
||||
Flatten,
|
||||
/// SHA-2 peephole optimizations
|
||||
Sha,
|
||||
/// Memory elimination
|
||||
Mem,
|
||||
/// Replace oblivious arrays with tuples
|
||||
Obliv,
|
||||
/// Replace arrays with linear scans
|
||||
LinearScan,
|
||||
/// Extract top-level ANDs as distinct outputs
|
||||
FlattenAssertions,
|
||||
/// Find outputs like `(= variable term)`, and substitute out `variable`
|
||||
@@ -33,6 +40,9 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
|
||||
for i in optimizations {
|
||||
debug!("Applying: {:?}", i);
|
||||
match i {
|
||||
Opt::ScalarizeVars => {
|
||||
scalarize_vars::scalarize_inputs(&mut cs);
|
||||
}
|
||||
Opt::ConstantFold => {
|
||||
let mut cache = TermMap::new();
|
||||
for a in &mut cs.outputs {
|
||||
@@ -44,10 +54,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
|
||||
*a = sha::sha_rewrites(a);
|
||||
}
|
||||
}
|
||||
Opt::Mem => {
|
||||
for a in &mut cs.outputs {
|
||||
*a = mem::array_elim(a);
|
||||
}
|
||||
Opt::Obliv => {
|
||||
mem::obliv::elim_obliv(&mut cs);
|
||||
}
|
||||
Opt::LinearScan => {
|
||||
mem::lin::linearize(&mut cs);
|
||||
}
|
||||
Opt::FlattenAssertions => {
|
||||
let mut new_outputs = Vec::new();
|
||||
@@ -68,14 +79,19 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computation, optimizations: I) -
|
||||
}
|
||||
}
|
||||
Opt::Inline => {
|
||||
let public_inputs = cs.metadata.public_inputs().map(ToOwned::to_owned).collect();
|
||||
let public_inputs = cs
|
||||
.metadata
|
||||
.public_input_names()
|
||||
.map(ToOwned::to_owned)
|
||||
.collect();
|
||||
inline::inline(&mut cs.outputs, &public_inputs);
|
||||
}
|
||||
Opt::Tuple => {
|
||||
cs = tuple::eliminate_tuples(cs);
|
||||
tuple::eliminate_tuples(&mut cs);
|
||||
}
|
||||
}
|
||||
debug!("After {:?}: {} outputs", i, cs.outputs.len());
|
||||
//debug!("After {:?}: {}", i, Letified(cs.outputs[0].clone()));
|
||||
debug!("After {:?}: {} terms", i, cs.terms());
|
||||
}
|
||||
garbage_collect();
|
||||
|
||||
132
src/ir/opt/scalarize_vars.rs
Normal file
132
src/ir/opt/scalarize_vars.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
//! Replacing array and tuple variables with scalars.
|
||||
use crate::ir::opt::visit::RewritePass;
|
||||
use crate::ir::term::*;
|
||||
|
||||
struct Pass;
|
||||
|
||||
fn create_vars(
|
||||
prefix: &str,
|
||||
sort: &Sort,
|
||||
value: Option<Value>,
|
||||
party: Option<PartyId>,
|
||||
new_var_requests: &mut Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
|
||||
) -> Term {
|
||||
match sort {
|
||||
Sort::Tuple(sorts) => {
|
||||
let mut values = value.map(|v| match v {
|
||||
Value::Tuple(t) => t,
|
||||
_ => panic!(),
|
||||
});
|
||||
term(
|
||||
Op::Tuple,
|
||||
sorts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, sort)| {
|
||||
create_vars(
|
||||
&format!("{}.{}", prefix, i),
|
||||
sort,
|
||||
values
|
||||
.as_mut()
|
||||
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
|
||||
party,
|
||||
new_var_requests,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
Sort::Array(key_s, val_s, size) => {
|
||||
let mut values = value.map(|v| match v {
|
||||
Value::Array(Array {
|
||||
default, map, size, ..
|
||||
}) => {
|
||||
let mut vals = vec![*default; size];
|
||||
for (key_val, val_val) in map.into_iter() {
|
||||
let idx = key_val.as_usize().unwrap();
|
||||
vals[idx] = val_val;
|
||||
}
|
||||
vals
|
||||
}
|
||||
_ => panic!(),
|
||||
});
|
||||
make_array(
|
||||
(**key_s).clone(),
|
||||
(**val_s).clone(),
|
||||
(0..*size)
|
||||
.map(|i| {
|
||||
create_vars(
|
||||
&format!("{}.{}", prefix, i),
|
||||
val_s,
|
||||
values
|
||||
.as_mut()
|
||||
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
|
||||
party,
|
||||
new_var_requests,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
new_var_requests.push((prefix.into(), sort.clone(), value, party));
|
||||
leaf_term(Op::Var(prefix.into(), sort.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RewritePass for Pass {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
_rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
if let Op::Var(name, sort) = &orig.op {
|
||||
let party_visibility = computation.metadata.get_input_visibility(name);
|
||||
let mut new_var_reqs = Vec::new();
|
||||
let new = create_vars(
|
||||
name,
|
||||
sort,
|
||||
computation
|
||||
.values
|
||||
.as_ref()
|
||||
.map(|v| v.get(name).unwrap().clone()),
|
||||
party_visibility,
|
||||
&mut new_var_reqs,
|
||||
);
|
||||
if new_var_reqs.len() > 0 {
|
||||
computation.replace_input(orig.clone(), new_var_reqs);
|
||||
}
|
||||
Some(new)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the tuple elimination pass.
|
||||
pub fn scalarize_inputs(cs: &mut Computation) {
|
||||
let mut pass = Pass;
|
||||
pass.traverse(cs);
|
||||
#[cfg(debug_assertions)]
|
||||
assert_all_vars_are_scalars(cs);
|
||||
}
|
||||
|
||||
/// Check that every variables is a scalar.
|
||||
pub fn assert_all_vars_are_scalars(cs: &Computation) {
|
||||
for t in cs
|
||||
.terms_postorder()
|
||||
.into_iter()
|
||||
.chain(cs.metadata.inputs.iter().cloned())
|
||||
{
|
||||
if let Op::Var(_name, sort) = &t.op {
|
||||
match sort {
|
||||
Sort::Array(..) | Sort::Tuple(..) => {
|
||||
panic!("Variable {} is non-scalar", t);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,9 +170,9 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn undo() {
|
||||
let a = bv_lit(0, 1);
|
||||
let b = bv_lit(0, 1);
|
||||
let c = bv_lit(0, 1);
|
||||
let a = bool_lit(false);
|
||||
let b = bool_lit(false);
|
||||
let c = bool_lit(false);
|
||||
let t = term![Op::BoolMaj; a.clone(), b.clone(),c.clone()];
|
||||
let tt = term![OR; term![AND; a.clone(), b.clone()], term![AND; b.clone(), c.clone()], term![AND; c.clone(), a.clone()]];
|
||||
assert_eq!(tt, sha_maj_elim(&t));
|
||||
|
||||
@@ -2,225 +2,268 @@
|
||||
//!
|
||||
//! Elimates tuple-related terms.
|
||||
//!
|
||||
//! The idea is to do a bottom-up pass mapping all terms to tuple-free trees of terms.
|
||||
//! The idea is to do a bottom-up pass, in which all tuple's are lift to the top level, and then
|
||||
//! then removed.
|
||||
//!
|
||||
//! * Tuple variables are suffixed. `x: (bool, bool)` becomes `x.0: bool, x.1: bool`.
|
||||
//! * Tuple constants are replaced with trees
|
||||
//! * Tuple constructors make trees
|
||||
//! * Tuple accesses open up trees
|
||||
//! * Tuple ITEs yield trees of ITEs
|
||||
//! * Tuple EQs yield conjunctions of EQs
|
||||
|
||||
use std::rc::Rc;
|
||||
//! ## Phase 1
|
||||
//!
|
||||
//! Phase 1 (lifting tuples) is defined by the following big-step rewrites.
|
||||
//!
|
||||
//! Notational conventions:
|
||||
//! * lowercase letters are used to match sorts/terms before rewriting.
|
||||
//! * uppercase letters denote their (big-step) rewritten counterparts.
|
||||
//! * () denote AST structure
|
||||
//! * [] denote repeated structures, i.e. like var-args.
|
||||
//! * `f(x, *)` denotes a partial application of `f`, i.e. the function that sends `y` to `f(x,y)`.
|
||||
//! * In the results of rewriting we often have terms which are tuples at the top-level. I.e. their
|
||||
//! sort contains no tuple sort that is not the child of a tuple sort. Similarly, the terms are
|
||||
//! tuples at the top-level: no field or update operators are present, and the only tuple
|
||||
//! operators are the children of other tuple operators.
|
||||
//! * Such sorts/terms can be viewed as structured collections of non-tuple elements, i.e., as
|
||||
//! functors whose pure elements are non-tuples.
|
||||
//! * The `map`, `bimap`, and `list` functions apply to those functors.
|
||||
//! * i.e., `map f (tuple x tuple y z))` is `(tuple (f x) (tuple (f y) (f z)))`.
|
||||
//! * i.e., `list (tuple x tuple y z))` is `(x y z)`
|
||||
//!
|
||||
//! Assumptions:
|
||||
//! * We assume that array keys are of scalar sort
|
||||
//! * We assume that variables are of scalar sort. See [super::scalarize_vars].
|
||||
//! * We *do not* describe the pass as applied to constant values. That part of the pass is
|
||||
//! entirely analagous to terms.
|
||||
//!
|
||||
//! Sort rewrites:
|
||||
//! * `(tuple [t_i]_i) -> (tuple [T_i]_i)`
|
||||
//! * `(array k t) -> map (array k *) T
|
||||
//!
|
||||
//! Term rewrites:
|
||||
//! * `(ite c t f) -> bimap (ite C * *) T F`
|
||||
//! * `(eq c t f) -> (and (list (bimap (= * *) T F)))`
|
||||
//! * `(tuple [t_i]_i) -> (tuple [T_i]_i)`
|
||||
//! * `(field_j t) -> get T j`
|
||||
//! * `(update_j t) -> update T j V`
|
||||
//! * `(select a i) -> map (select * I) A`
|
||||
//! * `(store a i v) -> bimap (store * I *) A V`
|
||||
//! * `(OTHER [t_i]_i) -> (OTHER [T_i]_i)`
|
||||
//! * constants: *omitted*
|
||||
//!
|
||||
//! The result of this phase is a computation whose only tuple-terms are at the top of the
|
||||
//! computation graph
|
||||
//!
|
||||
//! ## Phase 2
|
||||
//!
|
||||
//! We replace each output `t` with the sequence of outputs `(list T)`.
|
||||
|
||||
use crate::ir::opt::visit::RewritePass;
|
||||
use crate::ir::term::{
|
||||
check, leaf_term, term, BoolNaryOp, Computation, Op, PartyId, PostOrderIter, Sort, Term,
|
||||
TermMap, Value,
|
||||
check, extras, leaf_term, term, Array, Computation, Op, PostOrderIter, Sort, Term, Value, AND,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
type Tree = Rc<TreeData>;
|
||||
use itertools::zip_eq;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum TreeData {
|
||||
Leaf(Term),
|
||||
Tuple(Vec<Tree>),
|
||||
}
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
struct TupleTree(Term);
|
||||
|
||||
impl TreeData {
|
||||
fn unwrap_leaf(&self) -> &Term {
|
||||
match self {
|
||||
TreeData::Leaf(l) => l,
|
||||
d => panic!("expected leaf, got {:?}", d),
|
||||
}
|
||||
}
|
||||
fn unwrap_tuple(&self) -> &Vec<Tree> {
|
||||
match self {
|
||||
TreeData::Tuple(l) => l,
|
||||
d => panic!("expected tuple, got {:?}", d),
|
||||
}
|
||||
}
|
||||
fn unfold_tuple_into(&self, terms: &mut Vec<Term>) {
|
||||
match self {
|
||||
TreeData::Leaf(l) => terms.push(l.clone()),
|
||||
TreeData::Tuple(l) => l.iter().for_each(|x| x.unfold_tuple_into(terms)),
|
||||
}
|
||||
}
|
||||
fn unfold_tuple(&self) -> Vec<Term> {
|
||||
let mut terms = Vec::new();
|
||||
self.unfold_tuple_into(&mut terms);
|
||||
terms
|
||||
}
|
||||
fn from_value(v: Value) -> TreeData {
|
||||
match v {
|
||||
Value::Tuple(vs) => {
|
||||
TreeData::Tuple(vs.into_iter().map(Self::from_value).map(Rc::new).collect())
|
||||
impl TupleTree {
|
||||
fn flatten(&self) -> impl Iterator<Item = Term> {
|
||||
let mut out = Vec::new();
|
||||
fn rec_unroll_into(t: &Term, out: &mut Vec<Term>) {
|
||||
if &t.op == &Op::Tuple {
|
||||
for c in &t.cs {
|
||||
rec_unroll_into(c, out);
|
||||
}
|
||||
} else {
|
||||
out.push(t.clone());
|
||||
}
|
||||
v => TreeData::Leaf(leaf_term(Op::Const(v))),
|
||||
}
|
||||
rec_unroll_into(&self.0, &mut out);
|
||||
out.into_iter()
|
||||
}
|
||||
fn structure(&self, flattened: impl IntoIterator<Item = Term>) -> Self {
|
||||
fn term_structure(t: &Term, iter: &mut impl Iterator<Item = Term>) -> Term {
|
||||
if &t.op == &Op::Tuple {
|
||||
term(
|
||||
Op::Tuple,
|
||||
t.cs.iter().map(|c| term_structure(c, iter)).collect(),
|
||||
)
|
||||
} else {
|
||||
iter.next().expect("bad structure")
|
||||
}
|
||||
}
|
||||
Self(term_structure(&self.0, &mut flattened.into_iter()))
|
||||
}
|
||||
fn well_formed(&self) -> bool {
|
||||
for t in PostOrderIter::new(self.0.clone()) {
|
||||
if &t.op != &Op::Tuple {
|
||||
for c in &t.cs {
|
||||
if &c.op == &Op::Tuple {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
fn assert_well_formed(&self) {
|
||||
assert!(
|
||||
self.well_formed(),
|
||||
"The following is not a well-formed tuple tree {}",
|
||||
extras::Letified(self.0.clone())
|
||||
);
|
||||
}
|
||||
fn map(&self, f: impl FnMut(Term) -> Term) -> Self {
|
||||
self.structure(self.flatten().map(f))
|
||||
}
|
||||
fn bimap(&self, mut f: impl FnMut(Term, Term) -> Term, other: &Self) -> Self {
|
||||
self.structure(itertools::zip_eq(self.flatten(), other.flatten()).map(|(a, b)| f(a, b)))
|
||||
}
|
||||
fn get(&self, i: usize) -> Self {
|
||||
assert_eq!(&self.0.op, &Op::Tuple);
|
||||
assert!(i < self.0.cs.len());
|
||||
Self(self.0.cs[i].clone())
|
||||
}
|
||||
fn update(&self, i: usize, v: &Term) -> Self {
|
||||
assert_eq!(&self.0.op, &Op::Tuple);
|
||||
assert!(i < self.0.cs.len());
|
||||
let mut cs = self.0.cs.clone();
|
||||
cs[i] = v.clone();
|
||||
Self(term(Op::Tuple, cs))
|
||||
}
|
||||
}
|
||||
|
||||
fn restructure(sort: &Sort, items: &mut impl Iterator<Item = Term>) -> Tree {
|
||||
if let Sort::Tuple(sorts) = sort {
|
||||
Rc::new(TreeData::Tuple(
|
||||
sorts.iter().map(|c| restructure(c, items)).collect(),
|
||||
))
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
struct ValueTupleTree(Value);
|
||||
|
||||
impl ValueTupleTree {
|
||||
fn flatten(&self) -> Vec<Value> {
|
||||
let mut out = Vec::new();
|
||||
fn rec_unroll_into(t: &Value, out: &mut Vec<Value>) {
|
||||
match t {
|
||||
Value::Tuple(vs) => {
|
||||
for c in vs {
|
||||
rec_unroll_into(c, out);
|
||||
}
|
||||
}
|
||||
_ => out.push(t.clone()),
|
||||
}
|
||||
}
|
||||
rec_unroll_into(&self.0, &mut out);
|
||||
out
|
||||
}
|
||||
fn structure(&self, flattened: impl IntoIterator<Item = Value>) -> Self {
|
||||
fn term_structure(t: &Value, iter: &mut impl Iterator<Item = Value>) -> Value {
|
||||
match t {
|
||||
Value::Tuple(vs) => {
|
||||
Value::Tuple(vs.iter().map(|c| term_structure(c, iter)).collect())
|
||||
}
|
||||
_ => iter.next().expect("bad structure"),
|
||||
}
|
||||
}
|
||||
Self(term_structure(&self.0, &mut flattened.into_iter()))
|
||||
}
|
||||
}
|
||||
|
||||
fn termify_val_tuples(v: Value) -> Term {
|
||||
if let Value::Tuple(vs) = v {
|
||||
term(Op::Tuple, vs.into_iter().map(termify_val_tuples).collect())
|
||||
} else {
|
||||
Rc::new(TreeData::Leaf(
|
||||
items
|
||||
.next()
|
||||
.unwrap_or_else(|| panic!("No term when building {}", sort)),
|
||||
))
|
||||
leaf_term(Op::Const(v))
|
||||
}
|
||||
}
|
||||
|
||||
struct Pass {
|
||||
map: TermMap<Tree>,
|
||||
cs: Computation,
|
||||
}
|
||||
|
||||
impl Pass {
|
||||
fn new(cs: Computation) -> Self {
|
||||
Self {
|
||||
map: TermMap::new(),
|
||||
cs,
|
||||
fn untuple_value(v: &Value) -> Value {
|
||||
match v {
|
||||
Value::Tuple(xs) => Value::Tuple(xs.iter().map(untuple_value).collect()),
|
||||
Value::Array(Array {
|
||||
key_sort,
|
||||
default,
|
||||
map,
|
||||
size,
|
||||
}) => {
|
||||
let def = untuple_value(default);
|
||||
let flat_def = ValueTupleTree(def.clone()).flatten();
|
||||
let mut map: BTreeMap<_, _> = map
|
||||
.iter()
|
||||
.map(|(k, v)| (k, ValueTupleTree(untuple_value(v)).flatten()))
|
||||
.collect();
|
||||
let mut flat_out: Vec<Value> = flat_def
|
||||
.into_iter()
|
||||
.rev()
|
||||
.map(|d| {
|
||||
let mut submap: BTreeMap<Value, Value> = BTreeMap::new();
|
||||
for (k, v) in &mut map {
|
||||
submap.insert((**k).clone(), v.pop().unwrap());
|
||||
}
|
||||
Value::Array(Array::new(key_sort.clone(), Box::new(d), submap, *size))
|
||||
})
|
||||
.collect();
|
||||
flat_out.reverse();
|
||||
ValueTupleTree(def).structure(flat_out).0
|
||||
}
|
||||
_ => v.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn get_tree(&self, t: &Term) -> &Tree {
|
||||
self.map
|
||||
.get(t)
|
||||
.unwrap_or_else(|| panic!("missing tree for term: {} in {:?}", t, self.map))
|
||||
}
|
||||
struct TupleLifter;
|
||||
|
||||
fn create_vars(
|
||||
impl RewritePass for TupleLifter {
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
prefix: &str,
|
||||
sort: &Sort,
|
||||
value: Option<Value>,
|
||||
party: Option<PartyId>,
|
||||
) -> Tree {
|
||||
match sort {
|
||||
Sort::Tuple(sorts) => {
|
||||
let mut values = value.map(|v| match v {
|
||||
Value::Tuple(t) => t,
|
||||
_ => panic!(),
|
||||
});
|
||||
Rc::new(TreeData::Tuple(
|
||||
sorts
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, sort)| {
|
||||
self.create_vars(
|
||||
&format!("{}.{}", prefix, i),
|
||||
sort,
|
||||
values
|
||||
.as_mut()
|
||||
.map(|v| std::mem::replace(&mut v[i], Value::Bool(true))),
|
||||
party,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
_computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
match &orig.op {
|
||||
Op::Const(v) => Some(termify_val_tuples(untuple_value(v))),
|
||||
Op::Ite => {
|
||||
let mut cs = rewritten_children();
|
||||
let f = TupleTree(cs.pop().unwrap());
|
||||
let t = TupleTree(cs.pop().unwrap());
|
||||
let c = cs.pop().unwrap();
|
||||
debug_assert!(cs.is_empty());
|
||||
Some(t.bimap(|a, b| term![Op::Ite; c.clone(), a, b], &f).0)
|
||||
}
|
||||
_ => Rc::new(TreeData::Leaf(
|
||||
if self.cs.metadata.is_input(prefix)
|
||||
&& self.cs.metadata.is_input_public(prefix)
|
||||
&& self
|
||||
.cs
|
||||
.values
|
||||
.as_ref()
|
||||
.map(|v| v.contains_key(prefix))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
leaf_term(Op::Var(prefix.into(), sort.clone()))
|
||||
} else {
|
||||
self.cs
|
||||
.new_var(prefix, sort.clone(), || value.unwrap().clone(), party)
|
||||
},
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn embed_step(&mut self, t: &Term) {
|
||||
let s = check(t);
|
||||
let tree = if let Sort::Tuple(_) = s {
|
||||
match &t.op {
|
||||
Op::Const(v) => Rc::new(TreeData::from_value(v.clone())),
|
||||
Op::Var(name, sort) => {
|
||||
let party_visibility = self.cs.metadata.get_input_visibility(name);
|
||||
self.create_vars(
|
||||
name,
|
||||
sort,
|
||||
self.cs
|
||||
.values
|
||||
.as_ref()
|
||||
.map(|v| v.get(name).unwrap().clone()),
|
||||
party_visibility,
|
||||
)
|
||||
}
|
||||
Op::Tuple => Rc::new(TreeData::Tuple(
|
||||
t.cs.iter().map(|c| self.get_tree(c).clone()).collect(),
|
||||
)),
|
||||
Op::Ite => {
|
||||
let cond = self.get_tree(&t.cs[0]).unwrap_leaf();
|
||||
let trues = self.get_tree(&t.cs[1]).unfold_tuple();
|
||||
let falses = self.get_tree(&t.cs[2]).unfold_tuple();
|
||||
restructure(
|
||||
&s,
|
||||
&mut trues
|
||||
.into_iter()
|
||||
.zip(falses.into_iter())
|
||||
.map(|(a, b)| term![Op::Ite; cond.clone(), a, b]),
|
||||
)
|
||||
}
|
||||
Op::Field(i) => self.get_tree(&t.cs[0]).unwrap_tuple()[*i].clone(),
|
||||
o => panic!("Bad tuple operator: {}", o),
|
||||
Op::Eq => {
|
||||
let mut cs = rewritten_children();
|
||||
let b = TupleTree(cs.pop().unwrap());
|
||||
let a = TupleTree(cs.pop().unwrap());
|
||||
debug_assert!(cs.is_empty());
|
||||
let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]);
|
||||
Some(term(AND, eqs.collect()))
|
||||
}
|
||||
} else {
|
||||
match &t.op {
|
||||
Op::Field(i) => self.get_tree(&t.cs[0]).unwrap_tuple()[*i].clone(),
|
||||
Op::Eq => {
|
||||
let c_sort = check(&t.cs[0]);
|
||||
Rc::new(TreeData::Leaf(if let Sort::Tuple(_) = c_sort {
|
||||
let xs = self.get_tree(&t.cs[0]).unfold_tuple();
|
||||
let ys = self.get_tree(&t.cs[1]).unfold_tuple();
|
||||
term(
|
||||
Op::BoolNaryOp(BoolNaryOp::And),
|
||||
xs.into_iter()
|
||||
.zip(ys.into_iter())
|
||||
.map(|(x, y)| term![Op::Eq; x, y])
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
term(
|
||||
t.op.clone(),
|
||||
t.cs.iter()
|
||||
.map(|c| self.get_tree(c).unwrap_leaf().clone())
|
||||
.collect(),
|
||||
)
|
||||
}))
|
||||
}
|
||||
_ => Rc::new(TreeData::Leaf(term(
|
||||
t.op.clone(),
|
||||
t.cs.iter()
|
||||
.map(|c| self.get_tree(c).unwrap_leaf().clone())
|
||||
.collect(),
|
||||
))),
|
||||
Op::Store => {
|
||||
let mut cs = rewritten_children();
|
||||
let v = TupleTree(cs.pop().unwrap());
|
||||
let i = cs.pop().unwrap();
|
||||
let a = TupleTree(cs.pop().unwrap());
|
||||
debug_assert!(cs.is_empty());
|
||||
Some(a.bimap(|a, v| term![Op::Store; a, i.clone(), v], &v).0)
|
||||
}
|
||||
};
|
||||
if let TreeData::Leaf(t) = &*tree {
|
||||
debug_assert!(tuple_free(t.clone()), "Tuple in {:?}", tree);
|
||||
Op::Select => {
|
||||
let mut cs = rewritten_children();
|
||||
let i = cs.pop().unwrap();
|
||||
let a = TupleTree(cs.pop().unwrap());
|
||||
debug_assert!(cs.is_empty());
|
||||
Some(a.map(|a| term![Op::Select; a, i.clone()]).0)
|
||||
}
|
||||
Op::Field(i) => {
|
||||
let mut cs = rewritten_children();
|
||||
let t = TupleTree(cs.pop().unwrap());
|
||||
debug_assert!(cs.is_empty());
|
||||
Some(t.get(*i).0)
|
||||
}
|
||||
Op::Update(i) => {
|
||||
let mut cs = rewritten_children();
|
||||
let v = cs.pop().unwrap();
|
||||
let t = TupleTree(cs.pop().unwrap());
|
||||
debug_assert!(cs.is_empty());
|
||||
Some(t.update(*i, &v).0)
|
||||
}
|
||||
// The default rewrite is correct here.
|
||||
Op::Tuple => None,
|
||||
_ => None,
|
||||
}
|
||||
self.map.insert(t.clone(), tree);
|
||||
}
|
||||
|
||||
fn embed(&mut self, t: &Term) -> Tree {
|
||||
for c in PostOrderIter::new(t.clone()) {
|
||||
self.embed_step(&c);
|
||||
}
|
||||
self.get_tree(t).clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,16 +275,15 @@ fn tuple_free(t: Term) -> bool {
|
||||
}
|
||||
|
||||
/// Run the tuple elimination pass.
|
||||
pub fn eliminate_tuples(mut cs: Computation) -> Computation {
|
||||
let outputs = std::mem::take(&mut cs.outputs);
|
||||
let mut pass = Pass::new(cs);
|
||||
for output in &outputs {
|
||||
pass.embed(&output);
|
||||
}
|
||||
let new_ouputs: Vec<Term> = outputs
|
||||
.iter()
|
||||
.map(|c| pass.get_tree(c).unwrap_leaf().clone())
|
||||
pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
let mut pass = TupleLifter;
|
||||
pass.traverse(cs);
|
||||
cs.outputs = std::mem::take(&mut cs.outputs)
|
||||
.into_iter()
|
||||
.flat_map(|o| TupleTree(o).flatten())
|
||||
.collect();
|
||||
pass.cs.outputs = new_ouputs;
|
||||
pass.cs
|
||||
#[cfg(debug_assertions)]
|
||||
for o in &cs.outputs {
|
||||
assert!(tuple_free(o.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
76
src/ir/opt/visit.rs
Normal file
76
src/ir/opt/visit.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// A rewriting pass.
|
||||
pub trait RewritePass {
|
||||
/// Visit (and possibly rewrite) a term.
|
||||
/// Given the original term and a function to get its rewritten childen.
|
||||
/// Returns a term if a rewrite happens.
|
||||
fn visit<F: Fn() -> Vec<Term>>(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term>;
|
||||
fn traverse(&mut self, computation: &mut Computation) {
|
||||
let mut cache = TermMap::<Term>::new();
|
||||
let mut children_added = TermSet::new();
|
||||
let mut stack = Vec::new();
|
||||
stack.extend(computation.outputs.iter().cloned());
|
||||
while let Some(top) = stack.pop() {
|
||||
if !cache.contains_key(&top) {
|
||||
// was it missing?
|
||||
if children_added.insert(top.clone()) {
|
||||
stack.push(top.clone());
|
||||
stack.extend(top.cs.iter().filter(|c| !cache.contains_key(c)).cloned());
|
||||
} else {
|
||||
let get_children = || -> Vec<Term> {
|
||||
top.cs
|
||||
.iter()
|
||||
.map(|c| cache.get(c).unwrap())
|
||||
.cloned()
|
||||
.collect()
|
||||
};
|
||||
let new_t_opt = self.visit(computation, &top, get_children);
|
||||
let new_t = new_t_opt.unwrap_or_else(|| term(top.op.clone(), get_children()));
|
||||
cache.insert(top.clone(), new_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
computation.outputs = computation
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|o| cache.get(o).unwrap().clone())
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
/// An analysis pass that repeated sweeps all terms, visiting them, untill a pass makes no more
|
||||
/// progress.
|
||||
pub trait ProgressAnalysisPass {
|
||||
/// The visit function. Returns whether progress was made.
|
||||
fn visit(&mut self, term: &Term) -> bool;
|
||||
/// Repeatedly sweep till progress is no longer made.
|
||||
fn traverse(&mut self, computation: &Computation) {
|
||||
let mut progress = true;
|
||||
let mut order = Vec::new();
|
||||
let mut visited = TermSet::new();
|
||||
let mut stack = Vec::new();
|
||||
stack.extend(computation.outputs.iter().cloned());
|
||||
while let Some(top) = stack.pop() {
|
||||
stack.extend(top.cs.iter().filter(|c| !visited.contains(c)).cloned());
|
||||
// was it missing?
|
||||
if visited.insert(top.clone()) {
|
||||
order.push(top);
|
||||
}
|
||||
}
|
||||
while progress {
|
||||
progress = false;
|
||||
for t in &order {
|
||||
progress = self.visit(t) || progress;
|
||||
}
|
||||
for t in order.iter().rev() {
|
||||
progress = self.visit(t) || progress;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -119,16 +119,6 @@ pub enum Op {
|
||||
/// Takes the modulus.
|
||||
UbvToPf(Arc<Integer>),
|
||||
|
||||
// key sort, size
|
||||
/// A unary operator.
|
||||
///
|
||||
/// Make an array from keys of the given sort, which is equal to the provided argument at all
|
||||
/// places.
|
||||
///
|
||||
/// Has space for the provided number of elements. Note that this assumes an order and starting
|
||||
/// point for keys.
|
||||
ConstArray(Sort, usize),
|
||||
|
||||
/// Binary operator, with arguments (array, index).
|
||||
///
|
||||
/// Gets the value at index in array.
|
||||
@@ -142,6 +132,8 @@ pub enum Op {
|
||||
Tuple,
|
||||
/// Get the n'th element of a tuple
|
||||
Field(usize),
|
||||
/// Update (tuple, element)
|
||||
Update(usize),
|
||||
}
|
||||
|
||||
/// Boolean AND
|
||||
@@ -247,11 +239,11 @@ impl Op {
|
||||
Op::PfUnOp(_) => Some(1),
|
||||
Op::PfNaryOp(_) => None,
|
||||
Op::UbvToPf(_) => Some(1),
|
||||
Op::ConstArray(_, _) => Some(1),
|
||||
Op::Select => Some(2),
|
||||
Op::Store => Some(3),
|
||||
Op::Tuple => None,
|
||||
Op::Field(_) => Some(1),
|
||||
Op::Update(_) => Some(2),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -289,11 +281,11 @@ impl Display for Op {
|
||||
Op::PfUnOp(a) => write!(f, "{}", a),
|
||||
Op::PfNaryOp(a) => write!(f, "{}", a),
|
||||
Op::UbvToPf(a) => write!(f, "bv2pf {}", a),
|
||||
Op::ConstArray(_, s) => write!(f, "const-array {}", s),
|
||||
Op::Select => write!(f, "select"),
|
||||
Op::Store => write!(f, "store"),
|
||||
Op::Tuple => write!(f, "tuple"),
|
||||
Op::Field(i) => write!(f, "field{}", i),
|
||||
Op::Update(i) => write!(f, "update{}", i),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -629,11 +621,61 @@ pub enum Value {
|
||||
/// Boolean
|
||||
Bool(bool),
|
||||
/// Array
|
||||
Array(Sort, Box<Value>, BTreeMap<Value, Value>, usize),
|
||||
Array(Array),
|
||||
/// Tuple
|
||||
Tuple(Vec<Value>),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, PartialOrd, Hash)]
|
||||
/// An IR array value.
|
||||
///
|
||||
/// A sized, space array.
|
||||
pub struct Array {
|
||||
/// Key sort
|
||||
pub key_sort: Sort,
|
||||
/// Default (fill) value. What is stored when a key is missing from the next member
|
||||
pub default: Box<Value>,
|
||||
/// Key-> Value map
|
||||
pub map: BTreeMap<Value, Value>,
|
||||
/// Size of array. There are this many valid keys.
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
impl Array {
|
||||
/// Create a new [Array] from components
|
||||
pub fn new(
|
||||
key_sort: Sort,
|
||||
default: Box<Value>,
|
||||
map: BTreeMap<Value, Value>,
|
||||
size: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
key_sort,
|
||||
default,
|
||||
map,
|
||||
size,
|
||||
}
|
||||
}
|
||||
/// Create a new, default-initialized [Array]
|
||||
pub fn default(key_sort: Sort, val_sort: &Sort, size: usize) -> Self {
|
||||
Self::new(
|
||||
key_sort,
|
||||
Box::new(val_sort.default_value()),
|
||||
Default::default(),
|
||||
size,
|
||||
)
|
||||
}
|
||||
/// Store
|
||||
pub fn store(mut self, idx: Value, val: Value) -> Self {
|
||||
self.map.insert(idx, val);
|
||||
self
|
||||
}
|
||||
/// Select
|
||||
pub fn select(&self, idx: &Value) -> Value {
|
||||
self.map.get(idx).unwrap_or(&*self.default).clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Value {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
match self {
|
||||
@@ -650,13 +692,21 @@ impl Display for Value {
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
Value::Array(_s, d, map, size) => {
|
||||
write!(f, "(map default:{} size:{} {:?})", d, size, map)
|
||||
}
|
||||
Value::Array(a) => write!(f, "{}", a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Array {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"(map default:{} size:{} {:?})",
|
||||
self.default, self.size, self.map
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::Eq for Value {}
|
||||
impl std::cmp::Ord for Value {
|
||||
fn cmp(&self, o: &Self) -> std::cmp::Ordering {
|
||||
@@ -672,12 +722,7 @@ impl std::hash::Hash for Value {
|
||||
Value::Int(bv) => bv.hash(state),
|
||||
Value::Field(bv) => bv.hash(state),
|
||||
Value::Bool(bv) => bv.hash(state),
|
||||
Value::Array(s, d, a, size) => {
|
||||
s.hash(state);
|
||||
d.hash(state);
|
||||
a.hash(state);
|
||||
size.hash(state);
|
||||
}
|
||||
Value::Array(a) => a.hash(state),
|
||||
Value::Tuple(s) => {
|
||||
s.hash(state);
|
||||
}
|
||||
@@ -784,6 +829,37 @@ impl Sort {
|
||||
_ => panic!("Cannot iterate over {}", self),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the default term for this sort.
|
||||
///
|
||||
/// * booleans: false
|
||||
/// * bit-vectors: zero
|
||||
/// * field elements: zero
|
||||
/// * floats: zero
|
||||
/// * tuples/arrays: recursively default
|
||||
pub fn default_term(&self) -> Term {
|
||||
leaf_term(Op::Const(self.default_value()))
|
||||
}
|
||||
|
||||
/// Compute the default value for this sort.
|
||||
///
|
||||
/// * booleans: false
|
||||
/// * bit-vectors: zero
|
||||
/// * field elements: zero
|
||||
/// * floats: zero
|
||||
/// * tuples/arrays: recursively default
|
||||
pub fn default_value(&self) -> Value {
|
||||
match self {
|
||||
Sort::Bool => Value::Bool(false),
|
||||
Sort::BitVector(w) => Value::BitVector(BitVector::new(0.into(), *w)),
|
||||
Sort::Field(m) => Value::Field(FieldElem::new(Integer::from(0), m.clone())),
|
||||
Sort::Int => Value::Int(0.into()),
|
||||
Sort::F32 => Value::F32(0.0f32),
|
||||
Sort::F64 => Value::F64(0.0),
|
||||
Sort::Tuple(t) => Value::Tuple(t.iter().map(Sort::default_value).collect()),
|
||||
Sort::Array(k, v, n) => Value::Array(Array::default((**k).clone(), v, *n)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Sort {
|
||||
@@ -939,6 +1015,14 @@ impl TermData {
|
||||
false
|
||||
}
|
||||
}
|
||||
/// Is this a value
|
||||
pub fn is_const(&self) -> bool {
|
||||
if let Op::Const(..) = &self.op {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
@@ -951,7 +1035,12 @@ impl Value {
|
||||
Value::F64(_) => Sort::F64,
|
||||
Value::F32(_) => Sort::F32,
|
||||
Value::BitVector(b) => Sort::BitVector(b.width()),
|
||||
Value::Array(s, _, _, _) => s.clone(),
|
||||
Value::Array(Array {
|
||||
key_sort,
|
||||
default,
|
||||
size,
|
||||
..
|
||||
}) => Sort::Array(Box::new(key_sort.clone()), Box::new(default.sort()), *size),
|
||||
Value::Tuple(v) => Sort::Tuple(v.iter().map(Value::sort).collect()),
|
||||
}
|
||||
}
|
||||
@@ -992,6 +1081,16 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
/// Unwrap the constituent value of this array, panicking otherwise.
|
||||
pub fn as_array(&self) -> &Array {
|
||||
if let Value::Array(w) = self {
|
||||
&w
|
||||
} else {
|
||||
panic!("{} is not an aray", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the underlying boolean constant, if possible.
|
||||
pub fn as_bool_opt(&self) -> Option<bool> {
|
||||
if let Value::Bool(b) = self {
|
||||
@@ -1008,6 +1107,16 @@ impl Value {
|
||||
None
|
||||
}
|
||||
}
|
||||
/// Compute the sort of this value
|
||||
pub fn as_usize(&self) -> Option<usize> {
|
||||
match &self {
|
||||
Value::Bool(b) => Some(*b as usize),
|
||||
Value::Field(f) => f.i().to_usize(),
|
||||
Value::Int(i) => i.to_usize(),
|
||||
Value::BitVector(b) => b.uint().to_usize(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the term `t`, using variable values in `h`.
|
||||
@@ -1148,12 +1257,32 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_bv().clone();
|
||||
field::FieldElem::new(a.uint().clone(), m.clone())
|
||||
}),
|
||||
// tuple
|
||||
Op::Tuple => Value::Tuple(c.cs.iter().map(|c| vs.get(c).unwrap().clone()).collect()),
|
||||
Op::Field(i) => {
|
||||
let t = vs.get(&c.cs[0]).unwrap().as_tuple();
|
||||
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs[0]);
|
||||
t[*i].clone()
|
||||
}
|
||||
Op::Update(i) => {
|
||||
let mut t = vs.get(&c.cs[0]).unwrap().as_tuple().clone();
|
||||
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs[0]);
|
||||
let e = vs.get(&c.cs[1]).unwrap().clone();
|
||||
t[*i] = e;
|
||||
Value::Tuple(t)
|
||||
}
|
||||
// array
|
||||
Op::Store => {
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_array().clone();
|
||||
let i = vs.get(&c.cs[1]).unwrap().clone();
|
||||
let v = vs.get(&c.cs[2]).unwrap().clone();
|
||||
Value::Array(a.clone().store(i, v))
|
||||
}
|
||||
Op::Select => {
|
||||
let a = vs.get(&c.cs[0]).unwrap().as_array().clone();
|
||||
let i = vs.get(&c.cs[1]).unwrap();
|
||||
a.clone().select(i)
|
||||
}
|
||||
o => unimplemented!("eval: {:?}", o),
|
||||
};
|
||||
//println!("Eval {}\nAs {}", c, v);
|
||||
@@ -1162,14 +1291,31 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
|
||||
vs.get(t).unwrap().clone()
|
||||
}
|
||||
|
||||
/// Make an array from a sequence of terms.
|
||||
///
|
||||
/// Requires
|
||||
///
|
||||
/// * a key sort, as all arrays do. This sort must be iterable (i.e., bool, int, bit-vector, or field).
|
||||
/// * a value sort, for the array's default
|
||||
pub fn make_array(key_sort: Sort, value_sort: Sort, i: Vec<Term>) -> Term {
|
||||
let d = Sort::Array(Box::new(key_sort.clone()), Box::new(value_sort), i.len()).default_term();
|
||||
i.into_iter()
|
||||
.zip(key_sort.elems_iter())
|
||||
.fold(d, |arr, (val, idx)| term(Op::Store, vec![arr, idx, val]))
|
||||
}
|
||||
|
||||
/// Make a term with no arguments, just an operator.
|
||||
pub fn leaf_term(op: Op) -> Term {
|
||||
term(op, Vec::new())
|
||||
}
|
||||
|
||||
/// Make a term with arguments.
|
||||
#[track_caller]
|
||||
pub fn term(op: Op, cs: Vec<Term>) -> Term {
|
||||
mk(TermData { op, cs })
|
||||
let t = mk(TermData { op, cs });
|
||||
#[cfg(debug_assertions)]
|
||||
check_rec(&t);
|
||||
t
|
||||
}
|
||||
|
||||
/// Make a bit-vector constant term.
|
||||
@@ -1183,6 +1329,11 @@ where
|
||||
))))
|
||||
}
|
||||
|
||||
/// Make a bit-vector constant term.
|
||||
pub fn bool_lit(b: bool) -> Term {
|
||||
leaf_term(Op::Const(Value::Bool(b)))
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Make a term.
|
||||
///
|
||||
@@ -1280,6 +1431,46 @@ impl ComputationMetadata {
|
||||
self.input_vis.insert(input_name.clone(), party);
|
||||
self.inputs.push(term);
|
||||
}
|
||||
/// Replace the `original` computation input with `new`, in the order given.
|
||||
///
|
||||
/// If the old input order was
|
||||
///
|
||||
/// w x y z x1 x2 x3
|
||||
///
|
||||
/// and `x` was replaced with `x1`, `x2`, `x3`, then the new input order is
|
||||
///
|
||||
/// w x1 x2 x3 y z
|
||||
///
|
||||
/// and other metadata associated with `x` is removed.
|
||||
///
|
||||
/// This is probably called after making the new inputs with `new_input`.
|
||||
pub fn replace_input(
|
||||
&mut self,
|
||||
original: Term,
|
||||
new: Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
|
||||
) {
|
||||
let mut i = self.inputs.iter().position(|t| t == &original).unwrap();
|
||||
self.inputs.remove(i);
|
||||
let name = if let Op::Var(n, _) = &original.op {
|
||||
n.to_string()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
self.input_vis.remove(&name).unwrap();
|
||||
for (input_name, sort, _, party) in new {
|
||||
let term = leaf_term(Op::Var(input_name.clone(), sort));
|
||||
debug_assert!(
|
||||
!self.input_vis.contains_key(&input_name),
|
||||
"Tried to create input {} (visibility {:?}), but it already existed (visibility {:?})",
|
||||
input_name,
|
||||
party,
|
||||
self.input_vis.get(&input_name).unwrap()
|
||||
);
|
||||
self.input_vis.insert(input_name.clone(), party);
|
||||
self.inputs.insert(i, term);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
/// Returns None if the value is public. Otherwise, the unique party that knows it.
|
||||
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
|
||||
self.input_vis
|
||||
@@ -1289,15 +1480,14 @@ impl ComputationMetadata {
|
||||
}
|
||||
/// Is this input public?
|
||||
pub fn is_input(&self, input_name: &str) -> bool {
|
||||
self.input_vis
|
||||
.contains_key(input_name)
|
||||
self.input_vis.contains_key(input_name)
|
||||
}
|
||||
/// Is this input public?
|
||||
pub fn is_input_public(&self, input_name: &str) -> bool {
|
||||
self.get_input_visibility(input_name).is_none()
|
||||
}
|
||||
/// Get all public inputs.
|
||||
pub fn public_inputs(&self) -> impl Iterator<Item = &str> {
|
||||
pub fn public_input_names(&self) -> impl Iterator<Item = &str> {
|
||||
self.input_vis.iter().filter_map(|(name, party)| {
|
||||
if party.is_none() {
|
||||
Some(name.as_str())
|
||||
@@ -1306,6 +1496,21 @@ impl ComputationMetadata {
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Get all public inputs.
|
||||
pub fn public_inputs<'a>(&'a self) -> impl Iterator<Item = Term> + 'a {
|
||||
self.inputs.iter().filter_map(move |input| {
|
||||
if let Op::Var(name, _) = &input.op {
|
||||
let party = self.get_input_visibility(name);
|
||||
if party.is_none() {
|
||||
Some(input.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -1352,6 +1557,44 @@ impl Computation {
|
||||
}
|
||||
leaf_term(Op::Var(name.to_owned(), s))
|
||||
}
|
||||
/// Replace the `original` computation input with `new`, in the order given.
|
||||
///
|
||||
/// If the old input order was
|
||||
///
|
||||
/// w x y z
|
||||
///
|
||||
/// and `x` was replaced with `x1`, `x2`, `x3`, then the new input order is
|
||||
///
|
||||
/// w x1 x2 x3 y z
|
||||
///
|
||||
/// and other metadata associated with `x` is removed.
|
||||
///
|
||||
/// This is called in place of `new_var` during transformations.
|
||||
pub fn replace_input(
|
||||
&mut self,
|
||||
original: Term,
|
||||
mut new: Vec<(String, Sort, Option<Value>, Option<PartyId>)>,
|
||||
) {
|
||||
if let Some(vs) = self.values.as_mut() {
|
||||
if let Op::Var(name, _) = &original.op {
|
||||
vs.remove(name);
|
||||
for (name, _, val_opt, _) in &mut new {
|
||||
vs.insert(name.clone(), std::mem::take(val_opt).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
self.metadata.replace_input(original, new);
|
||||
}
|
||||
|
||||
/// Change the value associated with an input
|
||||
pub fn map_value(&mut self, name: &str, f: impl FnOnce(Value) -> Value) {
|
||||
if let Some(vs) = self.values.as_mut() {
|
||||
let loc = vs.get_mut(name).unwrap();
|
||||
let v = std::mem::replace(loc, Value::Bool(false));
|
||||
*loc = f(v);
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new variable, `name` in the constraint system, and set it equal to `term`.
|
||||
/// `public` indicates whether this variable is public in the constraint system.
|
||||
pub fn assign(&mut self, name: &str, term: Term, party: Option<PartyId>) -> Term {
|
||||
@@ -1364,7 +1607,7 @@ impl Computation {
|
||||
/// Assert `s` in the system.
|
||||
pub fn assert(&mut self, s: Term) {
|
||||
assert!(check(&s) == Sort::Bool);
|
||||
debug!("Assert: {}", s);
|
||||
debug!("Assert: {}", extras::Letified(s.clone()));
|
||||
self.outputs.push(s);
|
||||
}
|
||||
/// If tracking values, evaluate `term`, and set the result to `name`.
|
||||
@@ -1383,41 +1626,21 @@ impl Computation {
|
||||
Self {
|
||||
outputs: Vec::new(),
|
||||
metadata: ComputationMetadata::default(),
|
||||
values: if values { Some(FxHashMap::default()) } else { None },
|
||||
values: if values {
|
||||
Some(FxHashMap::default())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
// TODO: rm
|
||||
// /// Make `s` a public input.
|
||||
// pub fn publicize(&mut self, s: String) {
|
||||
// self.public_inputs.insert(s);
|
||||
// }
|
||||
|
||||
/// Get the outputs of the computation.
|
||||
///
|
||||
/// For proof systems, these are the assertions that must hold.
|
||||
pub fn outputs(&self) -> &Vec<Term> {
|
||||
&self.outputs
|
||||
}
|
||||
// TODO: rm
|
||||
// /// Consume this system, yielding its parts: (assertions, public inputs, values)
|
||||
// pub fn consume(self) -> (Vec<Term>, ComputationMetadata, Option<FxHashMap<String, Value>>) {
|
||||
// (self.assertions, self.metadata, self.values)
|
||||
// }
|
||||
// /// Build a system from its parts: (assertions, public inputs, values)
|
||||
// pub fn from_parts(
|
||||
// assertions: Vec<Term>,
|
||||
// public_inputs: FxHashSet<String>,
|
||||
// values: Option<FxHashMap<String, Value>>,
|
||||
// ) -> Self {
|
||||
// Self {
|
||||
// assertions,
|
||||
// public_inputs,
|
||||
// values,
|
||||
// }
|
||||
// }
|
||||
// /// Get the term, (AND all assertions)
|
||||
// pub fn assertions_as_term(&self) -> Term {
|
||||
// term(Op::BoolNaryOp(BoolNaryOp::And), self.assertions.clone())
|
||||
// }
|
||||
|
||||
/// How many total (unique) terms are there?
|
||||
pub fn terms(&self) -> usize {
|
||||
let mut terms = FxHashSet::<Term>::default();
|
||||
@@ -1428,6 +1651,14 @@ impl Computation {
|
||||
}
|
||||
terms.len()
|
||||
}
|
||||
|
||||
/// An iterator that visits each term in the computation, once.
|
||||
pub fn terms_postorder(&self) -> impl Iterator<Item = Term> {
|
||||
let mut terms: Vec<_> = PostOrderIter::new(term(Op::Tuple, self.outputs.clone())).collect();
|
||||
// drop the top-level tuple term.
|
||||
terms.pop();
|
||||
terms.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -85,11 +85,6 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
Op::PfUnOp(_) => Ok(check_raw(&t.cs[0])?),
|
||||
Op::PfNaryOp(_) => Ok(check_raw(&t.cs[0])?),
|
||||
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
|
||||
Op::ConstArray(s, n) => Ok(Sort::Array(
|
||||
Box::new(s.clone()),
|
||||
Box::new(check_raw(&t.cs[0])?),
|
||||
*n,
|
||||
)),
|
||||
Op::Select => array_or(&check_raw(&t.cs[0])?, "select").map(|(_, v)| v.clone()),
|
||||
Op::Store => Ok(check_raw(&t.cs[0])?),
|
||||
Op::Tuple => Ok(Sort::Tuple(
|
||||
@@ -109,6 +104,7 @@ pub fn check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
)))
|
||||
}
|
||||
}
|
||||
Op::Update(_i) => Ok(check_raw(&t.cs[0])?),
|
||||
o => Err(TypeErrorReason::Custom(format!("other operator: {}", o))),
|
||||
};
|
||||
let mut term_tys = TERM_TYPES.write().unwrap();
|
||||
@@ -260,13 +256,8 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
.and_then(|t| pf_or(t, ctx))
|
||||
.map(|a| a.clone())
|
||||
}
|
||||
(Op::UbvToPf(m), &[a]) => {
|
||||
bv_or(a, "sbv-to-fp").map(|_| Sort::Field(m.clone()))
|
||||
}
|
||||
(Op::UbvToPf(m), &[a]) => bv_or(a, "sbv-to-fp").map(|_| Sort::Field(m.clone())),
|
||||
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
|
||||
(Op::ConstArray(s, n), &[a]) => {
|
||||
Ok(Sort::Array(Box::new(s.clone()), Box::new(a.clone()), *n))
|
||||
}
|
||||
(Op::Select, &[Sort::Array(k, v, _), a]) => {
|
||||
eq_or(k, a, "select").map(|_| (**v).clone())
|
||||
}
|
||||
@@ -286,6 +277,17 @@ pub fn rec_check_raw(t: &Term) -> Result<Sort, TypeError> {
|
||||
)))
|
||||
}
|
||||
}),
|
||||
(Op::Update(i), &[a, b]) => tuple_or(a, "tuple field update").and_then(|t| {
|
||||
if i < &t.len() {
|
||||
eq_or(&t[*i], b, "tuple update")?;
|
||||
Ok(a.clone())
|
||||
} else {
|
||||
Err(TypeErrorReason::OutOfBounds(format!(
|
||||
"index {} in tuple of sort {}",
|
||||
i, a
|
||||
)))
|
||||
}
|
||||
}),
|
||||
(_, _) => Err(TypeErrorReason::Custom(format!("other"))),
|
||||
})
|
||||
.map_err(|reason| TypeError {
|
||||
@@ -328,6 +330,8 @@ pub enum TypeErrorReason {
|
||||
ExpectedPf(Sort, &'static str),
|
||||
/// A sort should be an array
|
||||
ExpectedArray(Sort, &'static str),
|
||||
/// A sort should be a tuple
|
||||
ExpectedTuple(&'static str),
|
||||
/// An empty n-ary operator.
|
||||
EmptyNary(String),
|
||||
/// Something else
|
||||
@@ -377,7 +381,7 @@ fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
|
||||
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Vec<Sort>, TypeErrorReason> {
|
||||
match a {
|
||||
Sort::Tuple(a) => Ok(a),
|
||||
_ => Err(TypeErrorReason::ExpectedPf(a.clone(), ctx)),
|
||||
_ => Err(TypeErrorReason::ExpectedTuple(ctx)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -150,9 +150,11 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
|
||||
}
|
||||
}
|
||||
}
|
||||
let terms: FxHashMap<Term, usize> = terms.into_iter().enumerate().map(|(i, t)| (t, i)).collect();
|
||||
let terms: FxHashMap<Term, usize> =
|
||||
terms.into_iter().enumerate().map(|(i, t)| (t, i)).collect();
|
||||
let mut term_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default();
|
||||
let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = FxHashMap::default();
|
||||
let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> =
|
||||
FxHashMap::default();
|
||||
let mut ilp = Ilp::new();
|
||||
|
||||
// build variables for all term assignments
|
||||
|
||||
@@ -188,7 +188,13 @@ impl ToABY {
|
||||
|
||||
fn remove_cons_gate(&self, circ: String) -> String {
|
||||
if circ.contains("PutCONSGate(") {
|
||||
circ.split("PutCONSGate(").last().unwrap_or("").split(",").next().unwrap_or("").to_string()
|
||||
circ.split("PutCONSGate(")
|
||||
.last()
|
||||
.unwrap_or("")
|
||||
.split(",")
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
panic!("PutCONSGate not found in: {}", circ)
|
||||
}
|
||||
@@ -510,9 +516,7 @@ impl ToABY {
|
||||
_ => panic!("Invalid bv-op in BvBinOp: {:?}", o),
|
||||
}
|
||||
}
|
||||
Op::BvExtract(_start, _end) => {
|
||||
|
||||
}
|
||||
Op::BvExtract(_start, _end) => {}
|
||||
_ => panic!("Non-field in embed_bv: {:?}", t),
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ use ff::{PrimeField, PrimeFieldBits};
|
||||
use gmp_mpfr_sys::gmp::limb_t;
|
||||
use log::debug;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::fs::File;
|
||||
use std::str::FromStr;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -32,9 +32,15 @@ fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
|
||||
zero_lc: LinearCombination<F>,
|
||||
) -> LinearCombination<F> {
|
||||
let mut lc_bellman = zero_lc;
|
||||
lc_bellman = lc_bellman + (int_to_ff(&lc.constant), CS::one());
|
||||
// This zero test is needed until https://github.com/zkcrypto/bellman/pull/78 is resolved
|
||||
if lc.constant != 0 {
|
||||
lc_bellman = lc_bellman + (int_to_ff(&lc.constant), CS::one());
|
||||
}
|
||||
for (v, c) in &lc.monomials {
|
||||
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
|
||||
// ditto
|
||||
if c != &0 {
|
||||
lc_bellman = lc_bellman + (int_to_ff(c), vars.get(v).unwrap().clone());
|
||||
}
|
||||
}
|
||||
lc_bellman
|
||||
}
|
||||
@@ -87,7 +93,7 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
|
||||
.get(&i)
|
||||
.unwrap();
|
||||
let ff_val = int_to_ff(i_val);
|
||||
debug!("witness: {} -> {:?} ({})", s, ff_val, i_val);
|
||||
debug!("value : {} -> {:?} ({})", s, ff_val, i_val);
|
||||
ff_val
|
||||
})
|
||||
};
|
||||
@@ -112,6 +118,11 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
|
||||
|z| lc_to_bellman::<F, CS>(&vars, c, z),
|
||||
);
|
||||
}
|
||||
debug!(
|
||||
"done with synth: {} vars {} cs",
|
||||
vars.len(),
|
||||
self.constraints.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -119,11 +130,13 @@ impl<'a, F: PrimeField + PrimeFieldBits, S: Display + Eq + Hash + Ord> Circuit<F
|
||||
/// Convert a (rug) integer to a prime field element.
|
||||
pub fn parse_instance<P: AsRef<Path>, F: PrimeField>(path: P) -> Vec<F> {
|
||||
let f = BufReader::new(File::open(path).unwrap());
|
||||
f.lines().map(|line| {
|
||||
let s = line.unwrap();
|
||||
let i = Integer::from_str(&s.trim()).unwrap();
|
||||
int_to_ff(&i)
|
||||
}).collect()
|
||||
f.lines()
|
||||
.map(|line| {
|
||||
let s = line.unwrap();
|
||||
let i = Integer::from_str(&s.trim()).unwrap();
|
||||
int_to_ff(&i)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -251,7 +251,11 @@ impl<S: Clone + Hash + Eq + Display> R1cs<S> {
|
||||
idxs_signals: HashMap::default(),
|
||||
next_idx: 0,
|
||||
public_idxs: HashSet::default(),
|
||||
values: if values { Some(HashMap::default()) } else { None },
|
||||
values: if values {
|
||||
Some(HashMap::default())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
constraints: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,8 +58,13 @@ impl ToR1cs {
|
||||
|
||||
/// Get a new variable, with name dependent on `d`.
|
||||
/// If values are being recorded, `value` must be provided.
|
||||
fn fresh_var<D: Display + ?Sized>(&mut self, ctx: &D, value: Option<Integer>, public: bool) -> Lc {
|
||||
let n = format!("{}_v{}", ctx, self.next_idx);
|
||||
fn fresh_var<D: Display + ?Sized>(
|
||||
&mut self,
|
||||
ctx: &D,
|
||||
value: Option<Integer>,
|
||||
public: bool,
|
||||
) -> Lc {
|
||||
let n = format!("{}_n{}", ctx, self.next_idx);
|
||||
self.next_idx += 1;
|
||||
self.r1cs.add_signal(n.clone(), value);
|
||||
if public {
|
||||
@@ -97,7 +102,11 @@ impl ToR1cs {
|
||||
}),
|
||||
false,
|
||||
);
|
||||
let is_zero = self.fresh_var("is_zero", self.r1cs.eval(&x).map(|x| Integer::from(x == 0)), false);
|
||||
let is_zero = self.fresh_var(
|
||||
"is_zero",
|
||||
self.r1cs.eval(&x).map(|x| Integer::from(x == 0)),
|
||||
false,
|
||||
);
|
||||
self.r1cs.constraint(m, x.clone(), -is_zero.clone() + 1);
|
||||
self.r1cs.constraint(is_zero.clone(), x, self.r1cs.zero());
|
||||
is_zero
|
||||
@@ -127,12 +136,15 @@ impl ToR1cs {
|
||||
/// Evaluate `var`'s value as an (integer-casted) bit-vector.
|
||||
/// Returns `None` if values are not stored.
|
||||
fn eval_bv(&self, var: &str) -> Option<Integer> {
|
||||
self.values
|
||||
.as_ref()
|
||||
.map(|vs| match vs.get(var).expect("missing value") {
|
||||
self.values.as_ref().map(|vs| {
|
||||
match vs
|
||||
.get(var)
|
||||
.unwrap_or_else(|| panic!("missing value for {}", var))
|
||||
{
|
||||
Value::BitVector(b) => b.uint().clone(),
|
||||
v => panic!("{} should be a bit-vector, but is {:?}", var, v),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluate `var`'s value as an (integer-casted) field element
|
||||
@@ -343,6 +355,27 @@ impl ToR1cs {
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_eq(&mut self, a: &Term, b: &Term) {
|
||||
match check(a) {
|
||||
Sort::Bool => {
|
||||
let a = self.get_bool(a).clone();
|
||||
let diff = a - self.get_bool(b);
|
||||
self.assert_zero(diff);
|
||||
}
|
||||
Sort::BitVector(_) => {
|
||||
let a = self.get_bv_uint(a);
|
||||
let diff = a - &self.get_bv_uint(b);
|
||||
self.assert_zero(diff);
|
||||
}
|
||||
Sort::Field(_) => {
|
||||
let a = self.get_pf(a).clone();
|
||||
let diff = a - self.get_pf(b);
|
||||
self.assert_zero(diff);
|
||||
}
|
||||
s => panic!("Unimplemented sort for Eq: {:?}", s),
|
||||
}
|
||||
}
|
||||
|
||||
fn embed_bool(&mut self, c: Term) -> &Lc {
|
||||
//println!("Embed: {}", c);
|
||||
debug_assert!(check(&c) == Sort::Bool);
|
||||
@@ -427,6 +460,23 @@ impl ToR1cs {
|
||||
self.get_bool(&c)
|
||||
}
|
||||
|
||||
fn assert_bool(&mut self, t: &Term) {
|
||||
//println!("Embed: {}", c);
|
||||
// TODO: skip if already embedded
|
||||
if &t.op == &Op::Eq {
|
||||
t.cs.iter().for_each(|c| self.embed(c.clone()));
|
||||
self.assert_eq(&t.cs[0], &t.cs[1]);
|
||||
} else if &t.op == &AND {
|
||||
for c in &t.cs {
|
||||
self.assert_bool(c);
|
||||
}
|
||||
} else {
|
||||
self.embed(t.clone());
|
||||
let lc = self.get_bool(&t).clone();
|
||||
self.assert_zero(lc - 1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether `a - b` fits in `size` non-negative bits.
|
||||
/// i.e. is in `{0, 1, ..., 2^n-1}`.
|
||||
fn bv_ge(&mut self, a: Lc, b: &Lc, size: usize) -> Lc {
|
||||
@@ -848,9 +898,8 @@ impl ToR1cs {
|
||||
}
|
||||
fn assert(&mut self, t: Term) {
|
||||
debug!("Assert: {}", Letified(t.clone()));
|
||||
self.embed(t.clone());
|
||||
let lc = self.get_bool(&t).clone();
|
||||
self.assert_zero(lc - 1);
|
||||
debug_assert!(check(&t) == Sort::Bool, "Non bool in assert");
|
||||
self.assert_bool(&t);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -861,7 +910,10 @@ pub fn to_r1cs(cs: Computation, modulus: Integer) -> R1cs<String> {
|
||||
metadata,
|
||||
values,
|
||||
} = cs;
|
||||
let public_inputs = metadata.public_inputs().map(ToOwned::to_owned).collect();
|
||||
let public_inputs = metadata
|
||||
.public_input_names()
|
||||
.map(ToOwned::to_owned)
|
||||
.collect();
|
||||
debug!("public inputs: {:?}", public_inputs);
|
||||
let mut converter = ToR1cs::new(modulus, values, public_inputs);
|
||||
debug!(
|
||||
@@ -871,7 +923,12 @@ pub fn to_r1cs(cs: Computation, modulus: Integer) -> R1cs<String> {
|
||||
.map(|c| PostOrderIter::new(c.clone()).count())
|
||||
.sum::<usize>()
|
||||
);
|
||||
println!("Printing assertions");
|
||||
debug!("declaring inputs");
|
||||
for i in metadata.public_inputs() {
|
||||
debug!("input {}", i);
|
||||
converter.embed(i);
|
||||
}
|
||||
debug!("Printing assertions");
|
||||
for c in assertions {
|
||||
converter.assert(c);
|
||||
}
|
||||
@@ -969,8 +1026,7 @@ pub mod test {
|
||||
} else {
|
||||
term![Op::Not; t]
|
||||
};
|
||||
let cs =
|
||||
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
|
||||
r1cs.check_all();
|
||||
}
|
||||
@@ -979,9 +1035,9 @@ pub mod test {
|
||||
fn random_bool(ArbitraryTermEnv(t, values): ArbitraryTermEnv) {
|
||||
let v = eval(&t, &values);
|
||||
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
|
||||
let cs =
|
||||
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
|
||||
let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
|
||||
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
|
||||
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
|
||||
r1cs.check_all();
|
||||
}
|
||||
@@ -990,8 +1046,7 @@ pub mod test {
|
||||
fn random_pure_bool_opt(ArbitraryBoolEnv(t, values): ArbitraryBoolEnv) {
|
||||
let v = eval(&t, &values);
|
||||
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
|
||||
let cs =
|
||||
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
|
||||
r1cs.check_all();
|
||||
let r1cs2 = reduce_linearities(r1cs);
|
||||
@@ -1002,9 +1057,9 @@ pub mod test {
|
||||
fn random_bool_opt(ArbitraryTermEnv(t, values): ArbitraryTermEnv) {
|
||||
let v = eval(&t, &values);
|
||||
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
|
||||
let cs =
|
||||
Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
|
||||
let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new(), Some(values));
|
||||
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
|
||||
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
|
||||
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
|
||||
r1cs.check_all();
|
||||
let r1cs2 = reduce_linearities(r1cs);
|
||||
@@ -1016,9 +1071,7 @@ pub mod test {
|
||||
let cs = Computation::from_constraint_system_parts(
|
||||
vec![term![Op::Not; term![Op::Eq; bv(0b10110, 8),
|
||||
term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]]],
|
||||
vec![
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(8))),
|
||||
],
|
||||
vec![leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))],
|
||||
Some(
|
||||
vec![(
|
||||
"b".to_owned(),
|
||||
@@ -1041,8 +1094,7 @@ pub mod test {
|
||||
.collect();
|
||||
let v = eval(&t, &values);
|
||||
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
|
||||
let cs =
|
||||
Computation::from_constraint_system_parts(vec![t], vec![], Some(values));
|
||||
let cs = Computation::from_constraint_system_parts(vec![t], vec![], Some(values));
|
||||
let r1cs = to_r1cs(cs, Integer::from(crate::ir::term::field::TEST_FIELD));
|
||||
r1cs.check_all();
|
||||
let r1cs2 = reduce_linearities(r1cs);
|
||||
@@ -1170,7 +1222,7 @@ pub mod test {
|
||||
|
||||
#[test]
|
||||
fn tuple() {
|
||||
let cs = Computation::from_constraint_system_parts(
|
||||
let mut cs = Computation::from_constraint_system_parts(
|
||||
vec![
|
||||
term![Op::Field(0); term![Op::Tuple; leaf_term(Op::Var("a".to_owned(), Sort::Bool)), leaf_term(Op::Const(Value::Bool(false)))]],
|
||||
term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))],
|
||||
@@ -1188,7 +1240,7 @@ pub mod test {
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
let cs = crate::ir::opt::tuple::eliminate_tuples(cs);
|
||||
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
|
||||
let r1cs = to_r1cs(cs, Integer::from(17));
|
||||
r1cs.check_all();
|
||||
}
|
||||
|
||||
@@ -71,14 +71,21 @@ impl Expr2Smt<()> for Value {
|
||||
}
|
||||
write!(w, ")")?;
|
||||
}
|
||||
Value::Array(s, default, map, _size) => {
|
||||
Value::Array(Array {
|
||||
key_sort,
|
||||
default,
|
||||
map,
|
||||
size,
|
||||
}) => {
|
||||
for _ in 0..map.len() {
|
||||
write!(w, "(store ")?;
|
||||
}
|
||||
let val_s = check(&leaf_term(Op::Const((**default).clone())));
|
||||
let s = Sort::Array(Box::new(key_sort.clone()), Box::new(val_s), *size);
|
||||
write!(
|
||||
w,
|
||||
"((as const {}) {})",
|
||||
SmtSortDisp(&*s),
|
||||
SmtSortDisp(&s),
|
||||
SmtDisp(&**default)
|
||||
)?;
|
||||
for (k, v) in map {
|
||||
|
||||
@@ -23,7 +23,6 @@ mod hash_test {
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_map_non_det_iter_order() {
|
||||
let mut m1: HashMap<usize, usize> = HashMap::new();
|
||||
|
||||
Reference in New Issue
Block a user