mirror of
https://github.com/circify/circ.git
synced 2026-01-10 06:08:02 -05:00
Add extension operators and new operators (#150)
modifies some opt passes w/ new ops
This commit is contained in:
44
Cargo.lock
generated
44
Cargo.lock
generated
@@ -291,6 +291,7 @@ dependencies = [
|
||||
"rand_chacha 0.3.1",
|
||||
"rsmt2",
|
||||
"rug",
|
||||
"rug-polynomial",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
"serde_json",
|
||||
@@ -655,6 +656,17 @@ dependencies = [
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flint-sys"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd39a21fba2a1dd249af52d118f5b6cc8f1a98c87732de386dc5e8e53d22a24b"
|
||||
dependencies = [
|
||||
"gmp-mpfr-sys",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@@ -735,12 +747,12 @@ checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
|
||||
|
||||
[[package]]
|
||||
name = "gmp-mpfr-sys"
|
||||
version = "1.5.1"
|
||||
version = "1.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a32868092c26fb25bb33c5ca7d8a937647979dfaa12f1f4f464beb57d726662c"
|
||||
checksum = "781d258743e6e07d38c5bcd7468cfaa068d56789d10e8de0d3836d5bb338b5d7"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.42.0",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1403,9 +1415,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rug"
|
||||
version = "1.19.1"
|
||||
version = "1.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a465f6576b9f0844bd35749197576d68e3db169120532c4e0f868ecccad3d449"
|
||||
checksum = "55313a5bab6820d1439c0266db37f8084e50618d35d16d81f692eee14d8b01b9"
|
||||
dependencies = [
|
||||
"az",
|
||||
"gmp-mpfr-sys",
|
||||
@@ -1413,6 +1425,28 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rug-fft"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "333f0e04ec25ccaee0885a94377702d00d6fa631023464ff38bec2efe7247d90"
|
||||
dependencies = [
|
||||
"rug",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rug-polynomial"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e78837d2c84bd38a3c9e666f802e21837df9f46d8413f4787a823c9b448d392d"
|
||||
dependencies = [
|
||||
"flint-sys",
|
||||
"gmp-mpfr-sys",
|
||||
"rug",
|
||||
"rug-fft",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.21"
|
||||
|
||||
@@ -23,6 +23,7 @@ typed-arena = { version = "2.0", optional = true }
|
||||
log = "0.4"
|
||||
thiserror = "1.0"
|
||||
bellman = { git = "https://github.com/alex-ozdemir/bellman.git", branch = "mirage", optional = true }
|
||||
rug-polynomial = { version = "0.2.5", optional = true }
|
||||
ff = { version = "0.12", optional = true }
|
||||
fxhash = "0.2"
|
||||
good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false, optional = true }
|
||||
@@ -70,6 +71,7 @@ aby = ["lp"]
|
||||
kahip = ["aby"]
|
||||
kahypar = ["aby"]
|
||||
r1cs = []
|
||||
poly = ["rug-polynomial"]
|
||||
spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"]
|
||||
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys", "byteorder"]
|
||||
|
||||
|
||||
@@ -27,8 +27,7 @@ fn main() {
|
||||
}
|
||||
let cs = Computation {
|
||||
outputs: vec![term![Op::Eq; t, v]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
precomputes: Default::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let _assignment = ilp::assign(&cs, "hycc");
|
||||
//dbg!(&assignment);
|
||||
|
||||
@@ -310,6 +310,26 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
Op::Array(k, v) => t
|
||||
.cs()
|
||||
.iter()
|
||||
.map(|c| c_get(c).as_value_opt().cloned())
|
||||
.collect::<Option<_>>()
|
||||
.map(|cs| {
|
||||
leaf_term(Op::Const(Value::Array(Array::from_vec(
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
cs,
|
||||
))))
|
||||
}),
|
||||
Op::Fill(k, s) => c_get(&t.cs()[0]).as_value_opt().map(|v| {
|
||||
leaf_term(Op::Const(Value::Array(Array::new(
|
||||
k.clone(),
|
||||
Box::new(v.clone()),
|
||||
Default::default(),
|
||||
*s,
|
||||
))))
|
||||
}),
|
||||
Op::Select => match (get(0).as_array_opt(), get(1).as_value_opt()) {
|
||||
(Some(arr), Some(idx)) => Some(leaf_term(Op::Const(arr.select(idx)))),
|
||||
_ => None,
|
||||
|
||||
@@ -37,6 +37,11 @@ impl RewritePass for Linearizer {
|
||||
computation.extend_precomputation(new_name.clone(), precomp);
|
||||
Some(leaf_term(Op::Var(new_name, new_sort)))
|
||||
}
|
||||
Op::Array(..) => Some(term(Op::Tuple, rewritten_children())),
|
||||
Op::Fill(_, size) => Some(term(
|
||||
Op::Tuple,
|
||||
vec![rewritten_children().pop().unwrap(); *size],
|
||||
)),
|
||||
Op::Select => {
|
||||
let cs = rewritten_children();
|
||||
let idx = &cs[1];
|
||||
|
||||
@@ -117,7 +117,7 @@ impl NonOblivComputer {
|
||||
impl ProgressAnalysisPass for NonOblivComputer {
|
||||
fn visit(&mut self, term: &Term) -> bool {
|
||||
match &term.op() {
|
||||
Op::Store => {
|
||||
Op::Store | Op::CStore => {
|
||||
let a = &term.cs()[0];
|
||||
let i = &term.cs()[1];
|
||||
let v = &term.cs()[2];
|
||||
@@ -138,6 +138,32 @@ impl ProgressAnalysisPass for NonOblivComputer {
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Array(..) => {
|
||||
let mut progress = false;
|
||||
if !term.cs().is_empty() {
|
||||
if let Sort::Array(..) = check(&term.cs()[0]) {
|
||||
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
|
||||
for i in 0..term.cs().len() - 1 {
|
||||
progress =
|
||||
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
|
||||
}
|
||||
for i in (0..term.cs().len() - 1).rev() {
|
||||
progress =
|
||||
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
|
||||
}
|
||||
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
|
||||
}
|
||||
}
|
||||
progress
|
||||
}
|
||||
Op::Fill(..) => {
|
||||
let v = &term.cs()[0];
|
||||
if let Sort::Array(..) = check(v) {
|
||||
self.bi_implicate(term, &term.cs()[0])
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Op::Select => {
|
||||
// Even though the selected value may not have array sort, we still flag it as
|
||||
// non-oblivious so we know whether to replace it or not.
|
||||
@@ -153,6 +179,13 @@ impl ProgressAnalysisPass for NonOblivComputer {
|
||||
progress = self.bi_implicate(term, a) || progress;
|
||||
progress
|
||||
}
|
||||
Op::Var(..) => {
|
||||
if let Sort::Array(..) = check(term) {
|
||||
self.mark(term)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
let t = &term.cs()[1];
|
||||
let f = &term.cs()[2];
|
||||
@@ -269,6 +302,32 @@ impl RewritePass for Replacer {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::CStore => {
|
||||
if self.should_replace(orig) {
|
||||
let mut cs = get_cs();
|
||||
debug_assert_eq!(cs.len(), 4);
|
||||
let cond = cs.remove(3);
|
||||
let k_const = get_const(&cs.remove(1));
|
||||
let orig = cs[0].clone();
|
||||
Some(term![ITE; cond, term(Op::Update(k_const), cs), orig])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Array(..) => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Tuple, get_cs()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Fill(_, size) => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Tuple, vec![get_cs().pop().unwrap(); *size]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Op::Ite => {
|
||||
if self.should_replace(orig) {
|
||||
Some(term(Op::Ite, get_cs()))
|
||||
|
||||
@@ -68,12 +68,18 @@ impl RewritePass for Pass {
|
||||
_rewritten_children: F,
|
||||
) -> Option<Term> {
|
||||
if let Op::Var(name, sort) = &orig.op() {
|
||||
let mut new_var_reqs = Vec::new();
|
||||
let new = create_vars(name, orig.clone(), sort, &mut new_var_reqs, true);
|
||||
for (name, term) in new_var_reqs {
|
||||
computation.extend_precomputation(name, term);
|
||||
debug!("Considering var: {}", name);
|
||||
if !computation.metadata.lookup(name).committed {
|
||||
let mut new_var_reqs = Vec::new();
|
||||
let new = create_vars(name, orig.clone(), sort, &mut new_var_reqs, true);
|
||||
for (name, term) in new_var_reqs {
|
||||
computation.extend_precomputation(name, term);
|
||||
}
|
||||
Some(new)
|
||||
} else {
|
||||
debug!("Skipping b/c it is commited.");
|
||||
None
|
||||
}
|
||||
Some(new)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -89,25 +95,28 @@ pub fn scalarize_inputs(cs: &mut Computation) {
|
||||
remove_non_scalar_vars_from_main_computation(cs);
|
||||
}
|
||||
|
||||
/// Check that every variables is a scalar.
|
||||
/// Check that every variables is a scalar (or committed)
|
||||
pub fn assert_all_vars_are_scalars(cs: &Computation) {
|
||||
for t in cs.terms_postorder() {
|
||||
if let Op::Var(_name, sort) = &t.op() {
|
||||
match sort {
|
||||
Sort::Array(..) | Sort::Tuple(..) => {
|
||||
panic!("Variable {} is non-scalar", t);
|
||||
if let Op::Var(name, sort) = &t.op() {
|
||||
if !cs.metadata.lookup(name).committed {
|
||||
match sort {
|
||||
Sort::Array(..) | Sort::Tuple(..) => {
|
||||
panic!("Variable {} is non-scalar", t);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that every variables is a scalar.
|
||||
/// Remove all variables that are non-scalar (and not committed).
|
||||
fn remove_non_scalar_vars_from_main_computation(cs: &mut Computation) {
|
||||
for input in cs.metadata.ordered_inputs() {
|
||||
if !check(&input).is_scalar() {
|
||||
cs.metadata.remove_var(input.as_var_name());
|
||||
let name = input.as_var_name();
|
||||
if !check(&input).is_scalar() && !cs.metadata.lookup(&name).committed {
|
||||
cs.metadata.remove_var(name);
|
||||
}
|
||||
}
|
||||
assert_all_vars_are_scalars(cs);
|
||||
|
||||
@@ -106,6 +106,17 @@ impl TupleTree {
|
||||
fn bimap(&self, mut f: impl FnMut(Term, Term) -> Term, other: &Self) -> Self {
|
||||
self.structure(itertools::zip_eq(self.flatten(), other.flatten()).map(|(a, b)| f(a, b)))
|
||||
}
|
||||
fn transpose_map(vs: Vec<Self>, f: impl FnMut(Vec<Term>) -> Term) -> Self {
|
||||
assert!(!vs.is_empty());
|
||||
let n = vs[0].flatten().count();
|
||||
let mut ts = vec![Vec::new(); n];
|
||||
for v in &vs {
|
||||
for (i, t) in v.flatten().enumerate() {
|
||||
ts[i].push(t);
|
||||
}
|
||||
}
|
||||
vs[0].structure(ts.into_iter().map(f))
|
||||
}
|
||||
fn get(&self, i: usize) -> Self {
|
||||
match self {
|
||||
TupleTree::NonTuple(cs) => {
|
||||
@@ -250,6 +261,15 @@ pub fn eliminate_tuples(cs: &mut Computation) {
|
||||
debug_assert!(cs.is_empty());
|
||||
a.bimap(|a, v| term![Op::Store; a, i.clone(), v], &v)
|
||||
}
|
||||
Op::Array(k, _v) => TupleTree::transpose_map(cs, |children| {
|
||||
assert!(!children.is_empty());
|
||||
let v_s = check(&children[0]);
|
||||
term(Op::Array(k.clone(), v_s), children)
|
||||
}),
|
||||
Op::Fill(key_sort, size) => {
|
||||
let values = cs.pop().unwrap();
|
||||
values.map(|v| term![Op::Fill(key_sort.clone(), *size); v])
|
||||
}
|
||||
Op::Select => {
|
||||
let i = cs.pop().unwrap().unwrap_non_tuple();
|
||||
let a = cs.pop().unwrap();
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::trace;
|
||||
|
||||
/// A rewriting pass.
|
||||
pub trait RewritePass {
|
||||
/// Visit (and possibly rewrite) a term.
|
||||
@@ -11,11 +13,49 @@ pub trait RewritePass {
|
||||
orig: &Term,
|
||||
rewritten_children: F,
|
||||
) -> Option<Term>;
|
||||
|
||||
/// What
|
||||
fn visit_cache(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
orig: &Term,
|
||||
cache: &TermMap<Term>,
|
||||
) -> Option<Term> {
|
||||
let get_children = || -> Vec<Term> {
|
||||
orig.cs()
|
||||
.iter()
|
||||
.map(|c| cache.get(c).unwrap())
|
||||
.cloned()
|
||||
.collect()
|
||||
};
|
||||
self.visit(computation, orig, get_children)
|
||||
}
|
||||
|
||||
fn traverse(&mut self, computation: &mut Computation) {
|
||||
self.traverse_full(computation, false, true);
|
||||
}
|
||||
|
||||
fn traverse_full(
|
||||
&mut self,
|
||||
computation: &mut Computation,
|
||||
precompute: bool,
|
||||
persistent_arrays: bool,
|
||||
) {
|
||||
let mut cache = TermMap::<Term>::default();
|
||||
let mut children_added = TermSet::default();
|
||||
let mut stack = Vec::new();
|
||||
if persistent_arrays {
|
||||
stack.extend(
|
||||
computation
|
||||
.persistent_arrays
|
||||
.iter()
|
||||
.map(|(_name, final_term)| final_term.clone()),
|
||||
);
|
||||
}
|
||||
stack.extend(computation.outputs.iter().cloned());
|
||||
if precompute {
|
||||
stack.extend(computation.precomputes.outputs().values().cloned());
|
||||
}
|
||||
while let Some(top) = stack.pop() {
|
||||
if !cache.contains_key(&top) {
|
||||
// was it missing?
|
||||
@@ -23,24 +63,40 @@ pub trait RewritePass {
|
||||
stack.push(top.clone());
|
||||
stack.extend(top.cs().iter().filter(|c| !cache.contains_key(c)).cloned());
|
||||
} else {
|
||||
let get_children = || -> Vec<Term> {
|
||||
top.cs()
|
||||
.iter()
|
||||
.map(|c| cache.get(c).unwrap())
|
||||
.cloned()
|
||||
.collect()
|
||||
};
|
||||
let new_t_opt = self.visit(computation, &top, get_children);
|
||||
let new_t = new_t_opt.unwrap_or_else(|| term(top.op().clone(), get_children()));
|
||||
let new_t_opt = self.visit_cache(computation, &top, &cache);
|
||||
let new_t = new_t_opt.unwrap_or_else(|| {
|
||||
term(
|
||||
top.op().clone(),
|
||||
top.cs()
|
||||
.iter()
|
||||
.map(|c| cache.get(c).unwrap())
|
||||
.cloned()
|
||||
.collect(),
|
||||
)
|
||||
});
|
||||
cache.insert(top.clone(), new_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
if persistent_arrays {
|
||||
for (_name, final_term) in &mut computation.persistent_arrays {
|
||||
let new_final_term = cache.get(final_term).unwrap().clone();
|
||||
trace!("Array {} -> {}", final_term, new_final_term);
|
||||
*final_term = new_final_term;
|
||||
}
|
||||
}
|
||||
computation.outputs = computation
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|o| cache.get(o).unwrap().clone())
|
||||
.collect();
|
||||
if precompute {
|
||||
let os = computation.precomputes.outputs().clone();
|
||||
for (name, old_term) in os {
|
||||
let new_term = cache.get(&old_term).unwrap().clone();
|
||||
computation.precomputes.change_output(&name, new_term);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ impl Constraints for Computation {
|
||||
outputs: assertions,
|
||||
metadata,
|
||||
precomputes: Default::default(),
|
||||
persistent_arrays: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
//! IR extensions
|
||||
|
||||
use super::ty::TypeErrorReason;
|
||||
use super::{Sort, Term, Value};
|
||||
use circ_hc::Node;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod poly;
|
||||
mod ram;
|
||||
mod sort;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
/// An extension operator. Not externally supported.
|
||||
///
|
||||
/// Often evaluatable, but not compilable.
|
||||
pub enum ExtOp {
|
||||
/// See [ram::eval]
|
||||
PersistentRamSplit,
|
||||
/// Given an array of tuples, returns a reordering such that the result is sorted.
|
||||
Sort,
|
||||
/// See [poly].
|
||||
UniqDeriGcd,
|
||||
}
|
||||
|
||||
impl ExtOp {
|
||||
/// Its arity
|
||||
pub fn arity(&self) -> Option<usize> {
|
||||
match self {
|
||||
ExtOp::PersistentRamSplit => Some(2),
|
||||
ExtOp::Sort => Some(2),
|
||||
ExtOp::UniqDeriGcd => Some(1),
|
||||
}
|
||||
}
|
||||
/// Type-check, given argument sorts
|
||||
pub fn check(&self, arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
match self {
|
||||
ExtOp::PersistentRamSplit => ram::check(arg_sorts),
|
||||
ExtOp::Sort => sort::check(arg_sorts),
|
||||
ExtOp::UniqDeriGcd => poly::check(arg_sorts),
|
||||
}
|
||||
}
|
||||
/// Evaluate, given argument values
|
||||
pub fn eval(&self, args: &[&Value]) -> Value {
|
||||
match self {
|
||||
ExtOp::PersistentRamSplit => ram::eval(args),
|
||||
ExtOp::Sort => sort::eval(args),
|
||||
ExtOp::UniqDeriGcd => poly::eval(args),
|
||||
}
|
||||
}
|
||||
/// Indicate which children of `t` must be typed to type `t`.
|
||||
pub fn check_dependencies(&self, t: &Term) -> Vec<Term> {
|
||||
t.cs().to_vec()
|
||||
}
|
||||
/// Parse, from bytes.
|
||||
pub fn parse(bytes: &[u8]) -> Option<Self> {
|
||||
match bytes {
|
||||
b"persistent_ram_split" => Some(ExtOp::PersistentRamSplit),
|
||||
b"uniq_deri_gcd" => Some(ExtOp::UniqDeriGcd),
|
||||
b"sort" => Some(ExtOp::Sort),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
75
src/ir/term/ext/poly.rs
Normal file
75
src/ir/term/ext/poly.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! Operator UniqDeriGcd
|
||||
//!
|
||||
//! Given an array of (root, cond) tuples (root is a field element, cond is a boolean),
|
||||
//! define f(X) = prod_i (cond_i ? X - root_i : 1)
|
||||
//!
|
||||
//! Compute f'(X) and s,t s.t. fs + f't = 1. Return an array of coefficients for s and one for t
|
||||
//! (as a tuple).
|
||||
|
||||
use crate::ir::term::ty::*;
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// Type-check [super::ExtOp::UniqDeriGcd].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
if let &[pairs] = arg_sorts {
|
||||
let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?;
|
||||
let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?;
|
||||
let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?;
|
||||
if let &[root, cond] = &value_tup {
|
||||
eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?;
|
||||
eq_or(
|
||||
cond,
|
||||
&Sort::Bool,
|
||||
"UniqDeriGcd pairs: second element must be a bool",
|
||||
)?;
|
||||
let box_f = Box::new(f.clone());
|
||||
let arr = Sort::Array(box_f.clone(), box_f, size);
|
||||
Ok(Sort::Tuple(Box::new([arr.clone(), arr])))
|
||||
} else {
|
||||
// non-pair entries value
|
||||
Err(TypeErrorReason::Custom(
|
||||
"UniqDeriGcd: pairs value must be a pair".into(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
// wrong arg count
|
||||
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::UniqDeriGcd].
|
||||
#[cfg(feature = "poly")]
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
use rug_polynomial::ModPoly;
|
||||
let sort = args[0].sort().as_array().0.clone();
|
||||
let field = sort.as_pf().clone();
|
||||
let mut roots: Vec<Integer> = Vec::new();
|
||||
let deg = args[0].as_array().size;
|
||||
for t in args[0].as_array().values() {
|
||||
let tuple = t.as_tuple();
|
||||
let cond = tuple[1].as_bool();
|
||||
if cond {
|
||||
roots.push(tuple[0].as_pf().i());
|
||||
}
|
||||
}
|
||||
let p = ModPoly::with_roots(roots, field.modulus());
|
||||
let dp = p.derivative();
|
||||
let (g, s, t) = p.xgcd(&dp);
|
||||
assert_eq!(g.len(), 1);
|
||||
assert_eq!(g.get_coefficient(0), 1);
|
||||
let coeff_arr = |s: ModPoly| {
|
||||
let v: Vec<Value> = (0..deg)
|
||||
.map(|i| Value::Field(field.new_v(s.get_coefficient(i))))
|
||||
.collect();
|
||||
Value::Array(Array::from_vec(sort.clone(), sort.clone(), v))
|
||||
};
|
||||
let s_cs = coeff_arr(s);
|
||||
let t_cs = coeff_arr(t);
|
||||
Value::Tuple(Box::new([s_cs, t_cs]))
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::UniqDeriGcd].
|
||||
#[cfg(not(feature = "poly"))]
|
||||
pub fn eval(_args: &[&Value]) -> Value {
|
||||
panic!("Cannot evalute Op::UniqDeriGcd without 'poly' feature")
|
||||
}
|
||||
126
src/ir/term/ext/ram.rs
Normal file
126
src/ir/term/ext/ram.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Operator PersistentRamSplit
|
||||
|
||||
use crate::ir::term::ty::*;
|
||||
use crate::ir::term::*;
|
||||
use fxhash::FxHashSet as HashSet;
|
||||
|
||||
/// Type-check [super::ExtOp::PersistentRamSplit].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
if let &[entries, indices] = arg_sorts {
|
||||
let (key, value, size) = ty::array_or(entries, "PersistentRamSplit entries")?;
|
||||
let f = pf_or(key, "PersistentRamSplit entries: indices must be field")?;
|
||||
let value_tup = ty::tuple_or(value, "PersistentRamSplit entries: value must be a tuple")?;
|
||||
if let &[old, new] = &value_tup {
|
||||
eq_or(
|
||||
f,
|
||||
old,
|
||||
"PersistentRamSplit entries: value must be a field pair",
|
||||
)?;
|
||||
eq_or(
|
||||
f,
|
||||
new,
|
||||
"PersistentRamSplit entries: value must be a field pair",
|
||||
)?;
|
||||
let (i_key, i_value, i_size) = ty::array_or(indices, "PersistentRamSplit indices")?;
|
||||
eq_or(f, i_key, "PersistentRamSplit indices: key must be a field")?;
|
||||
eq_or(
|
||||
f,
|
||||
i_value,
|
||||
"PersistentRamSplit indices: value must be a field",
|
||||
)?;
|
||||
let n_touched = i_size.min(size);
|
||||
let n_ignored = size - n_touched;
|
||||
let box_f = Box::new(f.clone());
|
||||
let f_pair = Sort::Tuple(Box::new([f.clone(), f.clone()]));
|
||||
let ignored_entries_sort =
|
||||
Sort::Array(box_f.clone(), Box::new(f_pair.clone()), n_ignored);
|
||||
let selected_entries_sort = Sort::Array(box_f, Box::new(f_pair), n_touched);
|
||||
Ok(Sort::Tuple(Box::new([
|
||||
ignored_entries_sort,
|
||||
selected_entries_sort.clone(),
|
||||
selected_entries_sort,
|
||||
])))
|
||||
} else {
|
||||
// non-pair entries value
|
||||
Err(TypeErrorReason::Custom(
|
||||
"PersistentRamSplit: entries value must be a pair".into(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
// wrong arg count
|
||||
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::PersistentRamSplit].
|
||||
///
|
||||
/// Takes two arguments:
|
||||
///
|
||||
/// * entries: [(val_init_i, val_fin_i)] (len E)
|
||||
/// * indices: [idx_i] (len I)
|
||||
///
|
||||
/// assumes I < E and 0 <= idx_i < E.
|
||||
///
|
||||
/// Let dedup_i be idx_i with duplicates removed
|
||||
/// Let ext_i be dedup_i padded up to length I. The added elements are each i in 0.. (so long as i hasn't occured in ext_i already).
|
||||
/// Let
|
||||
///
|
||||
///
|
||||
/// Returns:
|
||||
/// * a bunch of sequences of index-value pairs:
|
||||
/// * untouched_entries (array field (tuple (field field)) (length I - E))
|
||||
/// * init_reads (array field (tuple (field field)) (length I))
|
||||
/// * fin_writes (array field (tuple (field field)) (length I))
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
let entries = &args[0].as_array().values();
|
||||
let (init_vals, fin_vals): (Vec<Value>, Vec<Value>) = entries
|
||||
.iter()
|
||||
.map(|t| (t.as_tuple()[0].clone(), t.as_tuple()[1].clone()))
|
||||
.unzip();
|
||||
let indices = &args[1].as_array().values();
|
||||
let num_accesses = indices.len();
|
||||
let field = args[0].as_array().key_sort.as_pf();
|
||||
let uniq_indices = {
|
||||
let mut uniq_indices = Vec::<usize>::new();
|
||||
let mut used_indices = HashSet::<usize>::default();
|
||||
for i in indices.iter().map(|i| i.as_usize().unwrap()).chain(0..) {
|
||||
if uniq_indices.len() == num_accesses {
|
||||
break;
|
||||
}
|
||||
if !used_indices.contains(&i) {
|
||||
uniq_indices.push(i);
|
||||
used_indices.insert(i);
|
||||
}
|
||||
}
|
||||
uniq_indices.sort();
|
||||
uniq_indices
|
||||
};
|
||||
let mut init_reads = Vec::new();
|
||||
let mut fin_writes = Vec::new();
|
||||
let mut untouched_entries = Vec::new();
|
||||
let mut j = 0;
|
||||
for (i, (init_val, fin_val)) in init_vals.into_iter().zip(fin_vals).enumerate() {
|
||||
if j < uniq_indices.len() && uniq_indices[j] == i {
|
||||
init_reads.push((i, init_val));
|
||||
fin_writes.push((i, fin_val));
|
||||
j += 1;
|
||||
} else {
|
||||
untouched_entries.push((i, init_val));
|
||||
}
|
||||
}
|
||||
let key_sort = Sort::Field(field.clone());
|
||||
let entry_to_vals =
|
||||
|e: (usize, Value)| Value::Tuple(Box::new([Value::Field(field.new_v(e.0)), e.1]));
|
||||
let vec_to_arr = |v: Vec<(usize, Value)>| {
|
||||
let vals: Vec<Value> = v.into_iter().map(entry_to_vals).collect();
|
||||
Value::Array(Array::from_vec(
|
||||
key_sort.clone(),
|
||||
vals.first().unwrap().sort(),
|
||||
vals,
|
||||
))
|
||||
};
|
||||
let init_reads = vec_to_arr(init_reads);
|
||||
let untouched_entries = vec_to_arr(untouched_entries);
|
||||
let fin_writes = vec_to_arr(fin_writes);
|
||||
Value::Tuple(vec![untouched_entries, init_reads, fin_writes].into_boxed_slice())
|
||||
}
|
||||
22
src/ir/term/ext/sort.rs
Normal file
22
src/ir/term/ext/sort.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Sort operator
|
||||
|
||||
use crate::ir::term::ty::*;
|
||||
use crate::ir::term::*;
|
||||
|
||||
/// Type-check [super::ExtOp::Sort].
|
||||
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
|
||||
array_or(arg_sorts[0], "sort argument").map(|_| arg_sorts[0].clone())
|
||||
}
|
||||
|
||||
/// Evaluate [super::ExtOp::Sort].
|
||||
pub fn eval(args: &[&Value]) -> Value {
|
||||
let sort = args[0].sort();
|
||||
let (key_sort, value_sort, _) = sort.as_array();
|
||||
let mut values: Vec<Value> = args[0].as_array().values();
|
||||
values.sort();
|
||||
Value::Array(Array::from_vec(
|
||||
key_sort.clone(),
|
||||
value_sort.clone(),
|
||||
values,
|
||||
))
|
||||
}
|
||||
151
src/ir/term/ext/test.rs
Normal file
151
src/ir/term/ext/test.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
//! Contains a number of constant terms for testing
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn op_sort_eval() {
|
||||
let t = text::parse_term(b"(declare () (sort (#l (mod 11) ((#t true true) (#t true false)))))");
|
||||
let actual_output = eval(&t, &Default::default());
|
||||
let expected_output = text::parse_value_map(
|
||||
b"(let ((output (#l (mod 11) ((#t true false) (#t true true))))) false)",
|
||||
);
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
|
||||
let t = text::parse_term(b"(declare () (sort (#l (mod 11) (#x0 #xf #x4 #x3))))");
|
||||
let actual_output = eval(&t, &Default::default());
|
||||
let expected_output =
|
||||
text::parse_value_map(b"(let ((output (#l (mod 11) (#x0 #x3 #x4 #xf)))) false)");
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
}
|
||||
|
||||
#[cfg(feature = "poly")]
|
||||
#[test]
|
||||
fn uniq_deri_gcd_eval() {
|
||||
let t = text::parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
|
||||
)
|
||||
(uniq_deri_gcd pairs))",
|
||||
);
|
||||
|
||||
let inputs = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
let actual_output = eval(&t, &inputs);
|
||||
let expected_output = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage
|
||||
(#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage
|
||||
))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
|
||||
let inputs = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
let actual_output = eval(&t, &inputs);
|
||||
let expected_output = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage
|
||||
(#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage
|
||||
))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn persistent_ram_split_eval() {
|
||||
let t = text::parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
|
||||
(indices (array (mod 17) (mod 17) 3))
|
||||
)
|
||||
(persistent_ram_split entries indices))",
|
||||
);
|
||||
|
||||
let inputs = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(entries (#l (mod 17) ( (#t #f0 #f1) (#t #f1 #f1) (#t #f2 #f3) (#t #f3 #f4) (#t #f4 #f4) )))
|
||||
(indices (#l (mod 17) (#f0 #f2 #f3)))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
let actual_output = eval(&t, &inputs);
|
||||
let expected_output = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( (#t #f1 #f1) (#t #f4 #f4) )) ; untouched
|
||||
(#l (mod 17) ( (#t #f0 #f0) (#t #f2 #f2) (#t #f3 #f3) )) ; init_reads
|
||||
(#l (mod 17) ( (#t #f0 #f1) (#t #f2 #f3) (#t #f3 #f4) )) ; fin_writes
|
||||
))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
dbg!(&actual_output.as_tuple()[0].as_array().default);
|
||||
dbg!(
|
||||
&expected_output.get("output").unwrap().as_tuple()[0]
|
||||
.as_array()
|
||||
.default
|
||||
);
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
|
||||
// duplicates
|
||||
let inputs = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(entries (#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) (#t #f3 #f3) (#t #f4 #f4) )))
|
||||
(indices (#l (mod 17) (#f1 #f1 #f1)))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
let actual_output = eval(&t, &inputs);
|
||||
let expected_output = text::parse_value_map(
|
||||
b"
|
||||
(set_default_modulus 17
|
||||
(let
|
||||
(
|
||||
(output (#t
|
||||
(#l (mod 17) ( (#t #f3 #f3) (#t #f4 #f4) )) ; untouched
|
||||
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f1) (#t #f2 #f2) )) ; init_reads
|
||||
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) )) ; fin_writes
|
||||
))
|
||||
) false))
|
||||
",
|
||||
);
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Machinery for formatting IR types
|
||||
use super::{
|
||||
Array, ComputationMetadata, Node, Op, PartyId, PostOrderIter, Sort, Term, TermMap, Value,
|
||||
ext, Array, ComputationMetadata, Node, Op, PartyId, PostOrderIter, Sort, Term, TermMap, Value,
|
||||
VariableMetadata,
|
||||
};
|
||||
use crate::cfg::{cfg, is_cfg_set};
|
||||
@@ -186,8 +186,22 @@ impl DisplayIr for Op {
|
||||
Op::IntBinPred(a) => write!(f, "{a}"),
|
||||
Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()),
|
||||
Op::PfChallenge(n, m) => write!(f, "(challenge {} {})", n, m.modulus()),
|
||||
Op::PfFitsInBits(n) => write!(f, "(pf_fits_in_bits {})", n),
|
||||
Op::Select => write!(f, "select"),
|
||||
Op::Store => write!(f, "store"),
|
||||
Op::CStore => write!(f, "cstore"),
|
||||
Op::Fill(key_sort, size) => {
|
||||
write!(f, "(fill ")?;
|
||||
key_sort.ir_fmt(f)?;
|
||||
write!(f, " {})", *size)
|
||||
}
|
||||
Op::Array(k, v) => {
|
||||
write!(f, "(array ")?;
|
||||
k.ir_fmt(f)?;
|
||||
write!(f, " ")?;
|
||||
v.ir_fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Op::Tuple => write!(f, "tuple"),
|
||||
Op::Field(i) => write!(f, "(field {i})"),
|
||||
Op::Update(i) => write!(f, "(update {i})"),
|
||||
@@ -198,6 +212,18 @@ impl DisplayIr for Op {
|
||||
}
|
||||
Op::Call(name, _, _) => write!(f, "fn:{name}"),
|
||||
Op::Rot(i) => write!(f, "(rot {i})"),
|
||||
Op::PfToBoolTrusted => write!(f, "pf2bool_trusted"),
|
||||
Op::ExtOp(o) => o.ir_fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplayIr for ext::ExtOp {
|
||||
fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult {
|
||||
match self {
|
||||
ext::ExtOp::PersistentRamSplit => write!(f, "persistent_ram_split"),
|
||||
ext::ExtOp::UniqDeriGcd => write!(f, "uniq_deri_gcd"),
|
||||
ext::ExtOp::Sort => write!(f, "sort"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,17 +26,17 @@ use circ_fields::{FieldT, FieldV};
|
||||
pub use circ_hc::{Node, Table, Weak};
|
||||
use circ_opt::FieldToBv;
|
||||
use fxhash::{FxHashMap, FxHashSet};
|
||||
use log::debug;
|
||||
use log::{debug, trace};
|
||||
use rug::Integer;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::borrow::Borrow;
|
||||
use std::cell::Cell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod bv;
|
||||
pub mod dist;
|
||||
pub mod ext;
|
||||
pub mod extras;
|
||||
pub mod fmt;
|
||||
pub mod lin;
|
||||
@@ -46,6 +46,7 @@ pub mod text;
|
||||
pub mod ty;
|
||||
|
||||
pub use bv::BitVector;
|
||||
pub use ext::ExtOp;
|
||||
pub use ty::{check, check_rec, TypeError, TypeErrorReason};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
@@ -132,6 +133,8 @@ pub enum Op {
|
||||
///
|
||||
/// In IR evaluation, we sample deterministically based on a hash of the name.
|
||||
PfChallenge(String, FieldT),
|
||||
/// Requires the input pf element to fit in this many (unsigned) bits.
|
||||
PfFitsInBits(usize),
|
||||
|
||||
/// Integer n-ary operator
|
||||
IntNaryOp(IntNaryOp),
|
||||
@@ -146,6 +149,15 @@ pub enum Op {
|
||||
///
|
||||
/// Makes an array equal to `array`, but with `value` at `index`.
|
||||
Store,
|
||||
/// Quad-operator, with arguments (array, index, value, cond).
|
||||
///
|
||||
/// If `cond`, outputs an array equal to `array`, but with `value` at `index`.
|
||||
/// Otherwise, oupputs `array`.
|
||||
CStore,
|
||||
/// Makes an array of the indicated key sort with the indicated size, filled with the argument.
|
||||
Fill(Sort, usize),
|
||||
/// Create an array from (contiguous) values.
|
||||
Array(Sort, Sort),
|
||||
|
||||
/// Assemble n things into a tuple
|
||||
Tuple,
|
||||
@@ -163,6 +175,12 @@ pub enum Op {
|
||||
/// Cyclic right rotation of an array
|
||||
/// i.e. (Rot(1) [1,2,3,4]) --> ([4,1,2,3])
|
||||
Rot(usize),
|
||||
|
||||
/// Assume that the field element is 0 or 1, and treat it as a boolean.
|
||||
PfToBoolTrusted,
|
||||
|
||||
/// Extension operators. Used in compilation, but not externally supported
|
||||
ExtOp(ext::ExtOp),
|
||||
}
|
||||
|
||||
/// Boolean AND
|
||||
@@ -280,17 +298,23 @@ impl Op {
|
||||
Op::PfUnOp(_) => Some(1),
|
||||
Op::PfNaryOp(_) => None,
|
||||
Op::PfChallenge(_, _) => None,
|
||||
Op::PfFitsInBits(..) => Some(1),
|
||||
Op::IntNaryOp(_) => None,
|
||||
Op::IntBinPred(_) => Some(2),
|
||||
Op::UbvToPf(_) => Some(1),
|
||||
Op::Select => Some(2),
|
||||
Op::Store => Some(3),
|
||||
Op::CStore => Some(4),
|
||||
Op::Fill(..) => Some(1),
|
||||
Op::Array(..) => None,
|
||||
Op::Tuple => None,
|
||||
Op::Field(_) => Some(1),
|
||||
Op::Update(_) => Some(2),
|
||||
Op::Map(op) => op.arity(),
|
||||
Op::Call(_, args, _) => Some(args.len()),
|
||||
Op::Rot(_) => Some(1),
|
||||
Op::ExtOp(o) => o.arity(),
|
||||
Op::PfToBoolTrusted => Some(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -843,9 +867,9 @@ impl Sort {
|
||||
|
||||
#[track_caller]
|
||||
/// Unwrap the modulus of this prime field, panicking otherwise.
|
||||
pub fn as_pf(&self) -> Arc<Integer> {
|
||||
pub fn as_pf(&self) -> &FieldT {
|
||||
if let Sort::Field(fty) = self {
|
||||
fty.modulus_arc()
|
||||
fty
|
||||
} else {
|
||||
panic!("{} is not a field", self)
|
||||
}
|
||||
@@ -861,6 +885,42 @@ impl Sort {
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
/// Unwrap the constituent sorts of this array, panicking otherwise.
|
||||
pub fn as_array(&self) -> (&Sort, &Sort, usize) {
|
||||
if let Sort::Array(k, v, s) = self {
|
||||
(k, v, *s)
|
||||
} else {
|
||||
panic!("{} is not an array", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Is this an array?
|
||||
pub fn is_array(&self) -> bool {
|
||||
matches!(self, Sort::Array(..))
|
||||
}
|
||||
|
||||
/// The nth element of this sort.
|
||||
/// Only defined for booleans, bit-vectors, and field elements.
|
||||
#[track_caller]
|
||||
pub fn nth_elem(&self, n: usize) -> Term {
|
||||
match self {
|
||||
Sort::Bool => {
|
||||
assert!(n < 2);
|
||||
bool_lit([false, true][n])
|
||||
}
|
||||
Sort::BitVector(w) => {
|
||||
assert!(n < (1 << w));
|
||||
bv_lit(n, *w)
|
||||
}
|
||||
Sort::Field(f) => {
|
||||
debug_assert!(&Integer::from(n) < f.modulus());
|
||||
pf_lit(f.new_v(n))
|
||||
}
|
||||
_ => panic!("Can't get nth element of sort {}", self),
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over the elements of this sort (as IR Terms).
|
||||
/// Only defined for booleans, bit-vectors, and field elements.
|
||||
#[track_caller]
|
||||
@@ -1238,8 +1298,15 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
|
||||
/// Helper function for eval function. Handles a single term
|
||||
fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, t: Term) -> Value {
|
||||
let args: Vec<&Value> = t.cs().iter().map(|c| vs.get(c).unwrap()).collect();
|
||||
trace!("Eval {} on {:?}", t.op(), args);
|
||||
let v = eval_op(t.op(), &args, h);
|
||||
debug!("Eval {}\nAs {}", t, v);
|
||||
trace!("=> {}", v);
|
||||
if let Value::Bool(false) = &v {
|
||||
trace!("term {}", t);
|
||||
for v in extras::free_variables(t.clone()) {
|
||||
trace!(" {} = {}", v, h.get(&v).unwrap());
|
||||
}
|
||||
}
|
||||
vs.insert(t, v.clone());
|
||||
v
|
||||
}
|
||||
@@ -1393,6 +1460,9 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
|
||||
}),
|
||||
Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())),
|
||||
Op::PfChallenge(name, field) => Value::Field(pf_challenge(name, field)),
|
||||
Op::PfFitsInBits(n_bits) => {
|
||||
Value::Bool(args[0].as_pf().i().signed_bits() <= *n_bits as u32)
|
||||
}
|
||||
// tuple
|
||||
Op::Tuple => Value::Tuple(args.iter().map(|a| (*a).clone()).collect()),
|
||||
Op::Field(i) => {
|
||||
@@ -1415,6 +1485,31 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
|
||||
let v = args[2].clone();
|
||||
Value::Array(a.store(i, v))
|
||||
}
|
||||
Op::CStore => {
|
||||
let a = args[0].as_array().clone();
|
||||
let i = args[1].clone();
|
||||
let v = args[2].clone();
|
||||
let c = args[3].as_bool();
|
||||
if c {
|
||||
Value::Array(a.store(i, v))
|
||||
} else {
|
||||
Value::Array(a)
|
||||
}
|
||||
}
|
||||
Op::Fill(key_sort, size) => {
|
||||
let v = args[0].clone();
|
||||
Value::Array(Array::new(
|
||||
key_sort.clone(),
|
||||
Box::new(v),
|
||||
Default::default(),
|
||||
*size,
|
||||
))
|
||||
}
|
||||
Op::Array(key, value) => Value::Array(Array::from_vec(
|
||||
key.clone(),
|
||||
value.clone(),
|
||||
args.iter().cloned().cloned().collect(),
|
||||
)),
|
||||
Op::Select => {
|
||||
let a = args[0].as_array().clone();
|
||||
let i = args[1];
|
||||
@@ -1476,6 +1571,12 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
|
||||
}
|
||||
Value::Array(res)
|
||||
}
|
||||
Op::PfToBoolTrusted => {
|
||||
let v = args[0].as_pf().i();
|
||||
assert!(v == 0 || v == 1);
|
||||
Value::Bool(v == 1)
|
||||
}
|
||||
Op::ExtOp(o) => o.eval(args),
|
||||
|
||||
o => unimplemented!("eval: {:?}", o),
|
||||
}
|
||||
@@ -1488,10 +1589,22 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
|
||||
/// * a key sort, as all arrays do. This sort must be iterable (i.e., bool, int, bit-vector, or field).
|
||||
/// * a value sort, for the array's default
|
||||
pub fn make_array(key_sort: Sort, value_sort: Sort, i: Vec<Term>) -> Term {
|
||||
let d = Sort::Array(Box::new(key_sort.clone()), Box::new(value_sort), i.len()).default_term();
|
||||
i.into_iter()
|
||||
.zip(key_sort.elems_iter())
|
||||
.fold(d, |arr, (val, idx)| term(Op::Store, vec![arr, idx, val]))
|
||||
term(Op::Array(key_sort, value_sort), i)
|
||||
}
|
||||
|
||||
/// Make a sequence of terms from an array.
|
||||
///
|
||||
/// Requires
|
||||
///
|
||||
/// * an array term
|
||||
pub fn unmake_array(a: Term) -> Vec<Term> {
|
||||
let sort = check(&a);
|
||||
let (key_sort, _, size) = sort.as_array();
|
||||
key_sort
|
||||
.elems_iter()
|
||||
.take(size)
|
||||
.map(|idx| term(Op::Select, vec![a.clone(), idx]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Make a term with no arguments, just an operator.
|
||||
@@ -1547,6 +1660,29 @@ macro_rules! term {
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
/// Make a term, with clones.
|
||||
///
|
||||
/// Syntax:
|
||||
///
|
||||
/// * without children: `term![OP]`
|
||||
/// * with children: `term![OP; ARG0, ARG1, ... ]`
|
||||
/// * Note the semi-colon
|
||||
macro_rules! term_c {
|
||||
($x:expr; $($y:expr),+) => {
|
||||
{
|
||||
let mut args = Vec::new();
|
||||
#[allow(clippy::vec_init_then_push)]
|
||||
{
|
||||
$(
|
||||
args.push(($y).clone());
|
||||
)+
|
||||
}
|
||||
term($x, args)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Map from terms
|
||||
pub type TermMap<T> = FxHashMap<Term, T>;
|
||||
/// LRU cache of terms (like TermMap, but limited size)
|
||||
@@ -1703,14 +1839,27 @@ impl ComputationMetadata {
|
||||
self.vars.insert(metadata.name.clone(), metadata);
|
||||
}
|
||||
|
||||
/// Lookup metadata
|
||||
#[track_caller]
|
||||
fn lookup<Q: std::borrow::Borrow<str> + ?Sized>(&self, name: &Q) -> &VariableMetadata {
|
||||
pub fn lookup<Q: std::borrow::Borrow<str> + ?Sized>(&self, name: &Q) -> &VariableMetadata {
|
||||
let n = name.borrow();
|
||||
self.vars
|
||||
.get(n)
|
||||
.unwrap_or_else(|| panic!("Missing input {} in inputs{:#?}", n, self.vars))
|
||||
}
|
||||
|
||||
/// Lookup metadata
|
||||
#[track_caller]
|
||||
pub fn lookup_mut<Q: std::borrow::Borrow<str> + ?Sized>(
|
||||
&mut self,
|
||||
name: &Q,
|
||||
) -> &mut VariableMetadata {
|
||||
let n = name.borrow();
|
||||
self.vars
|
||||
.get_mut(n)
|
||||
.unwrap_or_else(|| panic!("Missing input {}", n))
|
||||
}
|
||||
|
||||
/// Returns None if the value is public. Otherwise, the unique party that knows it.
|
||||
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
|
||||
self.lookup(input_name).vis
|
||||
@@ -1890,6 +2039,11 @@ pub struct Computation {
|
||||
pub metadata: ComputationMetadata,
|
||||
/// Pre-computations
|
||||
pub precomputes: precomp::PreComp,
|
||||
/// Persistent Arrays. [(name, term)]
|
||||
/// where:
|
||||
/// * name: a variable name (array type) indicating the input state
|
||||
/// * name: a term indicating the output state
|
||||
pub persistent_arrays: Vec<(String, Term)>,
|
||||
}
|
||||
|
||||
impl Computation {
|
||||
@@ -1911,7 +2065,13 @@ impl Computation {
|
||||
party: Option<PartyId>,
|
||||
precompute: Option<Term>,
|
||||
) -> Term {
|
||||
debug!("Var: {} : {} (visibility: {:?})", name, s, party);
|
||||
debug!(
|
||||
"Var: {} : {} (visibility: {:?}) (precompute: {})",
|
||||
name,
|
||||
s,
|
||||
party,
|
||||
precompute.is_some()
|
||||
);
|
||||
self.metadata.new_input(name.to_owned(), party, s.clone());
|
||||
if let Some(p) = precompute {
|
||||
assert_eq!(&s, &check(&p));
|
||||
@@ -1920,6 +2080,31 @@ impl Computation {
|
||||
leaf_term(Op::Var(name.to_owned(), s))
|
||||
}
|
||||
|
||||
/// Create a new variable with the given metadata.
|
||||
///
|
||||
/// If `precompute` is set, that precomputation is added to give a value for this variable.
|
||||
/// Otherwise, the variable is assumed to be an input.
|
||||
pub fn new_var_metadata(
|
||||
&mut self,
|
||||
metadata: VariableMetadata,
|
||||
precompute: Option<Term>,
|
||||
) -> Term {
|
||||
debug!(
|
||||
"Var: {} : {:?} (precompute {})",
|
||||
metadata.name,
|
||||
metadata,
|
||||
precompute.is_some()
|
||||
);
|
||||
let sort = metadata.sort.clone();
|
||||
let name = metadata.name.clone();
|
||||
self.metadata.new_input_from_meta(metadata);
|
||||
if let Some(p) = precompute {
|
||||
assert_eq!(&sort, &check(&p));
|
||||
self.precomputes.add_output(name.clone(), p);
|
||||
}
|
||||
leaf_term(Op::Var(name, sort))
|
||||
}
|
||||
|
||||
/// Add a new input `new_input_var` to this computation,
|
||||
/// whose value is determined by `precomp`: a term over existing inputs.
|
||||
///
|
||||
@@ -1946,6 +2131,49 @@ impl Computation {
|
||||
self.new_var(&new_input_var, sort, vis, Some(precomp));
|
||||
}
|
||||
|
||||
/// Intialize a new persistent array.
|
||||
pub fn start_persistent_array(
|
||||
&mut self,
|
||||
var: &str,
|
||||
size: usize,
|
||||
field: FieldT,
|
||||
party: PartyId,
|
||||
) -> Term {
|
||||
let f = Sort::Field(field);
|
||||
let s = Sort::Array(Box::new(f.clone()), Box::new(f), size);
|
||||
let md = VariableMetadata {
|
||||
name: var.to_owned(),
|
||||
vis: Some(party),
|
||||
sort: s,
|
||||
committed: true,
|
||||
..Default::default()
|
||||
};
|
||||
let term = self.new_var_metadata(md, None);
|
||||
|
||||
// we'll replace dummy later
|
||||
let dummy = bool_lit(true);
|
||||
self.persistent_arrays.push((var.into(), dummy));
|
||||
term
|
||||
}
|
||||
|
||||
/// Record the final state of a persistent array. Should be called once per array, with the
|
||||
/// same name as [Computation::start_persistent_array].
|
||||
pub fn end_persistent_array(&mut self, var: &str, final_state: Term) {
|
||||
for (name, t) in &mut self.persistent_arrays {
|
||||
if name == var {
|
||||
assert_eq!(*t, bool_lit(true));
|
||||
*t = final_state;
|
||||
return;
|
||||
}
|
||||
}
|
||||
panic!("No existing persistent memory {}", var)
|
||||
}
|
||||
|
||||
/// Make a vector of existing variables a commitment.
|
||||
pub fn add_commitment(&mut self, names: Vec<String>) {
|
||||
self.metadata.add_commitment(names);
|
||||
}
|
||||
|
||||
/// Change the sort of a variables
|
||||
pub fn remove_var(&mut self, var: &str) {
|
||||
self.metadata.remove_var(var);
|
||||
@@ -1964,6 +2192,7 @@ impl Computation {
|
||||
outputs: Vec::new(),
|
||||
metadata: ComputationMetadata::default(),
|
||||
precomputes: Default::default(),
|
||||
persistent_arrays: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1992,6 +2221,29 @@ impl Computation {
|
||||
terms.pop();
|
||||
terms.into_iter()
|
||||
}
|
||||
|
||||
/// Evaluate the precompute, then this computation.
|
||||
pub fn eval_all(&self, values: &FxHashMap<String, Value>) -> Vec<Value> {
|
||||
let mut values = values.clone();
|
||||
|
||||
// set all challenges to 1.
|
||||
for v in self.metadata.vars.values() {
|
||||
if v.random {
|
||||
let field = v.sort.as_pf();
|
||||
let value = Value::Field(pf_challenge(&v.name, field));
|
||||
values.insert(v.name.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
values = self.precomputes.eval(&values);
|
||||
|
||||
let mut cache = Default::default();
|
||||
let mut outputs = Vec::new();
|
||||
for o in &self.outputs {
|
||||
outputs.push(eval_cached(o, &values, &mut cache).clone());
|
||||
}
|
||||
outputs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
|
||||
@@ -7,6 +7,8 @@ use fxhash::{FxHashMap, FxHashSet};
|
||||
|
||||
use crate::ir::term::*;
|
||||
|
||||
use log::trace;
|
||||
|
||||
/// A "precomputation".
|
||||
///
|
||||
/// Expresses a computation to be run in advance by a single party.
|
||||
@@ -47,6 +49,10 @@ impl PreComp {
|
||||
let old = self.outputs.insert(name, value);
|
||||
assert!(old.is_none());
|
||||
}
|
||||
/// Overwrite a step
|
||||
pub fn change_output(&mut self, name: &str, value: Term) {
|
||||
*self.outputs.get_mut(name).unwrap() = value;
|
||||
}
|
||||
/// Retain only the parts of this precomputation that can be evaluated from
|
||||
/// the `known` inputs.
|
||||
pub fn restrict_to_inputs(&mut self, known: FxHashSet<String>) {
|
||||
@@ -84,7 +90,9 @@ impl PreComp {
|
||||
for (o_name, _o_sort) in &self.sequence {
|
||||
let o = self.outputs.get(o_name).unwrap();
|
||||
eval_cached(o, &env, &mut value_cache);
|
||||
env.insert(o_name.clone(), value_cache.get(o).unwrap().clone());
|
||||
let value = value_cache.get(o).unwrap().clone();
|
||||
trace!("pre {o_name} => {value}");
|
||||
env.insert(o_name.clone(), value);
|
||||
}
|
||||
env
|
||||
}
|
||||
|
||||
@@ -626,3 +626,11 @@ pub fn bv_sub_tests() -> Vec<Term> {
|
||||
],
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pf2bool_eval() {
|
||||
let t = text::parse_term(b"(declare () (pf2bool_trusted (ite false #f1m11 #f0m11)))");
|
||||
let actual_output = eval(&t, &Default::default());
|
||||
let expected_output = text::parse_value_map(b"(let ((output false)) false)");
|
||||
assert_eq!(&actual_output, expected_output.get("output").unwrap());
|
||||
}
|
||||
|
||||
@@ -28,6 +28,11 @@
|
||||
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
|
||||
//! * OUTPUTS is `((X1 S1) .. (Xn Sn))`
|
||||
//! * TUPLE_TERM is a tuple of the same arity as the output
|
||||
//! * ARRAYS (optional): `(persistent_arrays ARRAY*)`:
|
||||
//! * ARRAY is `(X S T)`
|
||||
//! * X is the name of the inital state
|
||||
//! * S is the size
|
||||
//! * T is the state (final)
|
||||
//! * Sort `S`:
|
||||
//! * `bool`
|
||||
//! * `f32`
|
||||
@@ -286,12 +291,22 @@ impl<'src> IrInterp<'src> {
|
||||
Leaf(Ident, b"intmul") => Ok(INT_MUL),
|
||||
Leaf(Ident, b"select") => Ok(Op::Select),
|
||||
Leaf(Ident, b"store") => Ok(Op::Store),
|
||||
Leaf(Ident, b"cstore") => Ok(Op::CStore),
|
||||
Leaf(Ident, b"tuple") => Ok(Op::Tuple),
|
||||
Leaf(Ident, b"pf2bool_trusted") => Ok(Op::PfToBoolTrusted),
|
||||
Leaf(Ident, bytes) => {
|
||||
if let Some(e) = ext::ExtOp::parse(bytes) {
|
||||
Ok(Op::ExtOp(e))
|
||||
} else {
|
||||
todo!("Unparsed op: {}", tt)
|
||||
}
|
||||
}
|
||||
List(tts) => match &tts[..] {
|
||||
[Leaf(Ident, b"extract"), a, b] => Ok(Op::BvExtract(self.usize(a), self.usize(b))),
|
||||
[Leaf(Ident, b"uext"), a] => Ok(Op::BvUext(self.usize(a))),
|
||||
[Leaf(Ident, b"sext"), a] => Ok(Op::BvSext(self.usize(a))),
|
||||
[Leaf(Ident, b"pf2bv"), a] => Ok(Op::PfToBv(self.usize(a))),
|
||||
[Leaf(Ident, b"pf_fits_in_bits"), a] => Ok(Op::PfFitsInBits(self.usize(a))),
|
||||
[Leaf(Ident, b"bit"), a] => Ok(Op::BvBit(self.usize(a))),
|
||||
[Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))),
|
||||
[Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))),
|
||||
@@ -300,9 +315,13 @@ impl<'src> IrInterp<'src> {
|
||||
self.ident_string(name),
|
||||
FieldT::from(self.int(field)),
|
||||
)),
|
||||
[Leaf(Ident, b"array"), k, v] => Ok(Op::Array(self.sort(k), self.sort(v))),
|
||||
[Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))),
|
||||
[Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))),
|
||||
[Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))),
|
||||
[Leaf(Ident, b"fill"), key_sort, size] => {
|
||||
Ok(Op::Fill(self.sort(key_sort), self.usize(size)))
|
||||
}
|
||||
_ => todo!("Unparsed op: {}", tt),
|
||||
},
|
||||
_ => todo!("Unparsed op: {}", tt),
|
||||
@@ -691,13 +710,31 @@ impl<'src> IrInterp<'src> {
|
||||
assert!(tts.len() >= 3);
|
||||
let (metadata, input_names) = self.metadata(&tts[0]);
|
||||
let precomputes = self.precompute(&tts[1]);
|
||||
let iter = tts.iter().skip(2);
|
||||
let mut persistent_arrays = Vec::new();
|
||||
let mut skip_one = false;
|
||||
if let List(tts_inner) = &tts[2] {
|
||||
if tts_inner[0] == Leaf(Token::Ident, b"persistent_arrays") {
|
||||
skip_one = true;
|
||||
for tti in tts_inner.iter().skip(1) {
|
||||
let ttis = self.unwrap_list(tti, "persistent_arrays");
|
||||
let id = self.ident_string(&ttis[0]);
|
||||
let _size = self.usize(&ttis[1]);
|
||||
let term = self.term(&ttis[2]);
|
||||
persistent_arrays.push((id, term));
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut iter = tts.iter().skip(2);
|
||||
if skip_one {
|
||||
iter.next();
|
||||
}
|
||||
let outputs = iter.map(|tti| self.term(tti)).collect();
|
||||
self.unbind(input_names);
|
||||
Computation {
|
||||
outputs,
|
||||
metadata,
|
||||
precomputes,
|
||||
persistent_arrays,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -816,6 +853,14 @@ pub fn serialize_computation(c: &Computation) -> String {
|
||||
let mut out = String::new();
|
||||
writeln!(&mut out, "(computation \n{}", c.metadata).unwrap();
|
||||
writeln!(&mut out, "{}", serialize_precompute(&c.precomputes)).unwrap();
|
||||
if !c.persistent_arrays.is_empty() {
|
||||
writeln!(&mut out, "(persistent_arrays").unwrap();
|
||||
for (name, term) in &c.persistent_arrays {
|
||||
let size = check(term).as_array().2;
|
||||
writeln!(&mut out, " ({name} {size} {})", serialize_term(term)).unwrap();
|
||||
}
|
||||
writeln!(&mut out, "\n)").unwrap();
|
||||
}
|
||||
for o in &c.outputs {
|
||||
writeln!(&mut out, "\n {}", serialize_term(o)).unwrap();
|
||||
}
|
||||
@@ -907,6 +952,43 @@ mod test {
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arr_cstore_roundtrip() {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(a bool)
|
||||
(b bool)
|
||||
(c bool)
|
||||
(A (array bool bool 1))
|
||||
)
|
||||
(let (
|
||||
(B (cstore A a b c))
|
||||
) (xor (select B a)
|
||||
(select (#a (bv 4) false 4 ((#b0000 true))) #b0000))))",
|
||||
);
|
||||
let s = serialize_term(&t);
|
||||
let t2 = parse_term(s.as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arr_op_roundtrip() {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(a bool)
|
||||
(b bool)
|
||||
(A (array bool bool 1))
|
||||
)
|
||||
(= A ((array bool bool) a))
|
||||
)",
|
||||
);
|
||||
let s = serialize_term(&t);
|
||||
let t2 = parse_term(s.as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tup_roundtrip() {
|
||||
let t = parse_term(
|
||||
@@ -1008,6 +1090,38 @@ mod test {
|
||||
assert_eq!(c, c2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn computation_roundtrip_persistent_arrays() {
|
||||
let c = parse_computation(
|
||||
b"
|
||||
(computation
|
||||
(metadata
|
||||
(parties P V)
|
||||
(inputs
|
||||
(a bool (party 0) (random) (round 1))
|
||||
(b bool)
|
||||
(A (tuple bool bool))
|
||||
(x bool (party 0) (committed))
|
||||
)
|
||||
(commitments)
|
||||
)
|
||||
(precompute
|
||||
((c bool) (d bool))
|
||||
((a bool))
|
||||
(tuple (not (and c d)))
|
||||
)
|
||||
(persistent_arrays (AA 2 (#a (bv 4) false 4 ((#b0000 true)))))
|
||||
(let (
|
||||
(B ((update 1) A b))
|
||||
) (xor ((field 1) B)
|
||||
((field 0) (#t false false #b0000 true))))
|
||||
)",
|
||||
);
|
||||
let s = serialize_computation(&c);
|
||||
let c2 = parse_computation(s.as_bytes());
|
||||
assert_eq!(c, c2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn challenge_roundtrip() {
|
||||
let t = parse_term(b"(declare ((a bool) (b bool)) ((challenge hithere 17) a b))");
|
||||
@@ -1015,4 +1129,73 @@ mod test {
|
||||
let t2 = parse_term(s.as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn persistent_ram_split_roundtrip() {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
|
||||
(indices (array (mod 17) (mod 17) 3))
|
||||
)
|
||||
(persistent_ram_split entries indices))",
|
||||
);
|
||||
let s = serialize_term(&t);
|
||||
println!("{s}");
|
||||
let t2 = parse_term(s.as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_value_equiv_to_array() {
|
||||
let t_array = parse_term(b"(declare () (#a (bv 4) #x0 3 ((#x0 #x0) (#x1 #x1) (#x2 #x4))))");
|
||||
let t_list = parse_term(b"(declare () (#l (bv 4) (#x0 #x1 #x4)))");
|
||||
assert_eq!(t_array, t_list);
|
||||
let s = serialize_term(&t_array);
|
||||
let t_roundtripped = parse_term(s.as_bytes());
|
||||
assert_eq!(t_array, t_roundtripped);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pf2bool_trusted_rountrip() {
|
||||
let t = parse_term(b"(declare ((a bool)) (pf2bool_trusted (ite a #f1m11 #f0m11)))");
|
||||
let t2 = parse_term(serialize_term(&t).as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn op_sort_rountrip() {
|
||||
let t = parse_term(b"(declare () (sort (#l (mod 3) ((#t true true) (#t true false)))))");
|
||||
let t2 = parse_term(serialize_term(&t).as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fill_roundtrip() {
|
||||
let t = parse_term(b"(declare () ((fill (bv 4) 3) #x00))");
|
||||
let t2 = parse_term(serialize_term(&t).as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pf_fits_in_bits_rountrip() {
|
||||
let t = parse_term(b"(declare ((a bool)) ((pf_fits_in_bits 4) (ite a #f1m11 #f0m11)))");
|
||||
let t2 = parse_term(serialize_term(&t).as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uniq_deri_gcd_roundtrip() {
|
||||
let t = parse_term(
|
||||
b"
|
||||
(declare (
|
||||
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
|
||||
)
|
||||
(uniq_deri_gcd pairs))",
|
||||
);
|
||||
let s = serialize_term(&t);
|
||||
println!("{s}");
|
||||
let t2 = parse_term(s.as_bytes());
|
||||
assert_eq!(t, t2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,14 +60,20 @@ fn check_dependencies(t: &Term) -> Vec<Term> {
|
||||
Op::IntBinPred(_) => Vec::new(),
|
||||
Op::UbvToPf(_) => Vec::new(),
|
||||
Op::PfChallenge(_, _) => Vec::new(),
|
||||
Op::PfFitsInBits(_) => Vec::new(),
|
||||
Op::Select => vec![t.cs()[0].clone()],
|
||||
Op::Store => vec![t.cs()[0].clone()],
|
||||
Op::Array(..) => Vec::new(),
|
||||
Op::CStore => vec![t.cs()[0].clone()],
|
||||
Op::Fill(..) => vec![t.cs()[0].clone()],
|
||||
Op::Tuple => t.cs().to_vec(),
|
||||
Op::Field(_) => vec![t.cs()[0].clone()],
|
||||
Op::Update(_i) => vec![t.cs()[0].clone()],
|
||||
Op::Map(_) => t.cs().to_vec(),
|
||||
Op::Call(_, _, _) => Vec::new(),
|
||||
Op::Rot(_) => vec![t.cs()[0].clone()],
|
||||
Op::PfToBoolTrusted => Vec::new(),
|
||||
Op::ExtOp(o) => o.check_dependencies(t),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,8 +136,20 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
|
||||
Op::IntBinPred(_) => Ok(Sort::Bool),
|
||||
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
|
||||
Op::PfChallenge(_, m) => Ok(Sort::Field(m.clone())),
|
||||
Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v)| v.clone()),
|
||||
Op::PfFitsInBits(_) => Ok(Sort::Bool),
|
||||
Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v, _)| v.clone()),
|
||||
Op::Store => Ok(get_ty(&t.cs()[0]).clone()),
|
||||
Op::Array(k, v) => Ok(Sort::Array(
|
||||
Box::new(k.clone()),
|
||||
Box::new(v.clone()),
|
||||
t.cs().len(),
|
||||
)),
|
||||
Op::CStore => Ok(get_ty(&t.cs()[0]).clone()),
|
||||
Op::Fill(key_sort, size) => Ok(Sort::Array(
|
||||
Box::new(key_sort.clone()),
|
||||
Box::new(get_ty(&t.cs()[0]).clone()),
|
||||
*size,
|
||||
)),
|
||||
Op::Tuple => Ok(Sort::Tuple(t.cs().iter().map(get_ty).cloned().collect())),
|
||||
Op::Field(i) => {
|
||||
let sort = get_ty(&t.cs()[0]);
|
||||
@@ -165,7 +183,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
|
||||
|
||||
for i in 0..arg_cnt {
|
||||
match array_or(get_ty(&t.cs()[i]), "map inputs") {
|
||||
Ok((_, v)) => {
|
||||
Ok((_, v, _)) => {
|
||||
arg_sorts_to_inner_op.push(v);
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -183,6 +201,11 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
|
||||
}
|
||||
Op::Call(_, _, ret) => Ok(ret.clone()),
|
||||
Op::Rot(_) => Ok(get_ty(&t.cs()[0]).clone()),
|
||||
Op::PfToBoolTrusted => Ok(Sort::Bool),
|
||||
Op::ExtOp(o) => {
|
||||
let args_sorts: Vec<&Sort> = t.cs().iter().map(|c| get_ty(c)).collect();
|
||||
o.check(&args_sorts)
|
||||
}
|
||||
o => Err(TypeErrorReason::Custom(format!("other operator: {o}"))),
|
||||
}
|
||||
}
|
||||
@@ -335,6 +358,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
|
||||
}
|
||||
(Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())),
|
||||
(Op::PfChallenge(_, m), _) => Ok(Sort::Field(m.clone())),
|
||||
(Op::PfFitsInBits(_), &[a]) => pf_or(a, "pf fits in bits").map(|_| Sort::Bool),
|
||||
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
|
||||
(Op::IntNaryOp(_), a) => {
|
||||
let ctx = "int nary op";
|
||||
@@ -349,6 +373,21 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
|
||||
(Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store")
|
||||
.and_then(|_| eq_or(v, b, "store"))
|
||||
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
|
||||
(Op::CStore, &[Sort::Array(k, v, n), a, b, c]) => eq_or(k, a, "cstore")
|
||||
.and_then(|_| eq_or(v, b, "cstore"))
|
||||
.and_then(|_| bool_or(c, "cstore"))
|
||||
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
|
||||
(Op::Fill(key_sort, size), &[v]) => Ok(Sort::Array(
|
||||
Box::new(key_sort.clone()),
|
||||
Box::new(v.clone()),
|
||||
*size,
|
||||
)),
|
||||
(Op::Array(k, v), a) => {
|
||||
let ctx = "array op";
|
||||
a.iter()
|
||||
.try_fold((), |(), ai| eq_or(v, ai, ctx))
|
||||
.map(|_| Sort::Array(Box::new(k.clone()), Box::new(v.clone()), a.len()))
|
||||
}
|
||||
(Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())),
|
||||
(Op::Field(i), &[a]) => tuple_or(a, "tuple field access").and_then(|t| {
|
||||
if i < &t.len() {
|
||||
@@ -419,6 +458,8 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
|
||||
(Op::Rot(_), &[Sort::Array(k, v, n)]) => bv_or(k, "rot key")
|
||||
.and_then(|_| bv_or(v, "rot val"))
|
||||
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
|
||||
(Op::PfToBoolTrusted, &[k]) => pf_or(k, "pf to bool argument").map(|_| Sort::Bool),
|
||||
(Op::ExtOp(o), _) => o.check(a),
|
||||
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
|
||||
}
|
||||
}
|
||||
@@ -525,9 +566,12 @@ fn int_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReaso
|
||||
}
|
||||
}
|
||||
|
||||
fn array_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<(&'a Sort, &'a Sort), TypeErrorReason> {
|
||||
if let Sort::Array(k, v, _) = a {
|
||||
Ok((k, v))
|
||||
pub(super) fn array_or<'a>(
|
||||
a: &'a Sort,
|
||||
ctx: &'static str,
|
||||
) -> Result<(&'a Sort, &'a Sort, usize), TypeErrorReason> {
|
||||
if let Sort::Array(k, v, size) = a {
|
||||
Ok((k, v, *size))
|
||||
} else {
|
||||
Err(TypeErrorReason::ExpectedArray(a.clone(), ctx))
|
||||
}
|
||||
@@ -559,21 +603,21 @@ fn fp_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
|
||||
}
|
||||
}
|
||||
|
||||
fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
|
||||
pub(super) fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
|
||||
match a {
|
||||
Sort::Field(_) => Ok(a),
|
||||
_ => Err(TypeErrorReason::ExpectedPf(a.clone(), ctx)),
|
||||
}
|
||||
}
|
||||
|
||||
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], TypeErrorReason> {
|
||||
pub(super) fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], TypeErrorReason> {
|
||||
match a {
|
||||
Sort::Tuple(a) => Ok(a),
|
||||
_ => Err(TypeErrorReason::ExpectedTuple(ctx)),
|
||||
}
|
||||
}
|
||||
|
||||
fn eq_or(a: &Sort, b: &Sort, ctx: &'static str) -> Result<(), TypeErrorReason> {
|
||||
pub(super) fn eq_or(a: &Sort, b: &Sort, ctx: &'static str) -> Result<(), TypeErrorReason> {
|
||||
if a == b {
|
||||
Ok(())
|
||||
} else {
|
||||
|
||||
@@ -224,12 +224,11 @@ mod tests {
|
||||
);
|
||||
let costs = CostModel::from_opa_cost_file(&p);
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![BV_MUL;
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
|
||||
leaf_term(Op::Var("b".to_owned(), Sort::BitVector(32)))
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let _assignment = build_ilp(&cs, &costs);
|
||||
}
|
||||
@@ -242,7 +241,6 @@ mod tests {
|
||||
);
|
||||
let costs = CostModel::from_opa_cost_file(&p);
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![Op::Eq;
|
||||
term![BV_MUL;
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
|
||||
@@ -268,7 +266,7 @@ mod tests {
|
||||
],
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32)))
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let assignment = build_ilp(&cs, &costs);
|
||||
// Big enough to do the math with arith
|
||||
@@ -288,7 +286,6 @@ mod tests {
|
||||
);
|
||||
let costs = CostModel::from_opa_cost_file(&p);
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![Op::Eq;
|
||||
term![BV_MUL;
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
|
||||
@@ -305,7 +302,7 @@ mod tests {
|
||||
],
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32)))
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let assignment = build_ilp(&cs, &costs);
|
||||
// All yao
|
||||
|
||||
@@ -708,8 +708,7 @@ mod test {
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::Bool)),
|
||||
leaf_term(Op::Var("b".to_owned(), Sort::Bool))],
|
||||
],
|
||||
metadata: ComputationMetadata::default(),
|
||||
precomputes: Default::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let r = ilp.solve(default_solver).unwrap().1;
|
||||
@@ -850,8 +849,7 @@ mod test {
|
||||
fn trivial_bv_opt() {
|
||||
let cs = Computation {
|
||||
outputs: vec![leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)))],
|
||||
metadata: ComputationMetadata::default(),
|
||||
precomputes: Default::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let (max, vars) = ilp.solve(default_solver).unwrap();
|
||||
@@ -866,8 +864,7 @@ mod test {
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
|
||||
bv_lit(1,4)
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
precomputes: Default::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let (max, vars) = ilp.solve(default_solver).unwrap();
|
||||
@@ -877,12 +874,11 @@ mod test {
|
||||
#[test]
|
||||
fn mul2_bv_opt() {
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![BV_MUL;
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
|
||||
bv_lit(2,4)
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let (max, _vars) = ilp.solve(default_solver).unwrap();
|
||||
@@ -891,7 +887,6 @@ mod test {
|
||||
#[test]
|
||||
fn mul2_plus_bv_opt() {
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![BV_ADD;
|
||||
term![BV_MUL;
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
|
||||
@@ -900,7 +895,7 @@ mod test {
|
||||
|
||||
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)))
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let (max, vars) = ilp.solve(default_solver).unwrap();
|
||||
@@ -912,12 +907,11 @@ mod test {
|
||||
let a = leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)));
|
||||
let c = leaf_term(Op::Var("c".to_owned(), Sort::Bool));
|
||||
let cs = Computation {
|
||||
precomputes: Default::default(),
|
||||
outputs: vec![term![BV_ADD;
|
||||
term![ITE; c, bv_lit(2,4), bv_lit(1,4)],
|
||||
term![BV_MUL; a, bv_lit(2,4)]
|
||||
]],
|
||||
metadata: ComputationMetadata::default(),
|
||||
..Default::default()
|
||||
};
|
||||
let ilp = to_ilp(cs);
|
||||
let (max, vars) = ilp.solve(default_solver).unwrap();
|
||||
|
||||
2
util.py
2
util.py
@@ -5,7 +5,7 @@ from os import path
|
||||
feature_path = ".features.txt"
|
||||
mode_path = ".mode.txt"
|
||||
cargo_features = {"aby", "c", "lp", "r1cs", "kahip", "kahypar",
|
||||
"smt", "zok", "datalog", "bellman", "spartan"}
|
||||
"smt", "zok", "datalog", "bellman", "spartan", "poly"}
|
||||
|
||||
# Environment variables
|
||||
ABY_SOURCE = "./../ABY"
|
||||
|
||||
Reference in New Issue
Block a user