Add extension operators and new operators (#150)

Modifies some optimization passes to handle the new operators.
Author: Alex Ozdemir
Date: 2023-03-14 01:09:46 -07:00
Committed by: GitHub
Parent: 155794c2bf
Commit: 450c37b896
24 changed files with 1229 additions and 72 deletions

Cargo.lock (generated)

@@ -291,6 +291,7 @@ dependencies = [
"rand_chacha 0.3.1",
"rsmt2",
"rug",
"rug-polynomial",
"serde",
"serde_bytes",
"serde_json",
@@ -655,6 +656,17 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "flint-sys"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd39a21fba2a1dd249af52d118f5b6cc8f1a98c87732de386dc5e8e53d22a24b"
dependencies = [
"gmp-mpfr-sys",
"libc",
"winapi",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -735,12 +747,12 @@ checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]]
name = "gmp-mpfr-sys"
version = "1.5.1"
version = "1.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a32868092c26fb25bb33c5ca7d8a937647979dfaa12f1f4f464beb57d726662c"
checksum = "781d258743e6e07d38c5bcd7468cfaa068d56789d10e8de0d3836d5bb338b5d7"
dependencies = [
"libc",
"windows-sys 0.42.0",
"winapi",
]
[[package]]
@@ -1403,9 +1415,9 @@ dependencies = [
[[package]]
name = "rug"
version = "1.19.1"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a465f6576b9f0844bd35749197576d68e3db169120532c4e0f868ecccad3d449"
checksum = "55313a5bab6820d1439c0266db37f8084e50618d35d16d81f692eee14d8b01b9"
dependencies = [
"az",
"gmp-mpfr-sys",
@@ -1413,6 +1425,28 @@ dependencies = [
"serde",
]
[[package]]
name = "rug-fft"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "333f0e04ec25ccaee0885a94377702d00d6fa631023464ff38bec2efe7247d90"
dependencies = [
"rug",
]
[[package]]
name = "rug-polynomial"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78837d2c84bd38a3c9e666f802e21837df9f46d8413f4787a823c9b448d392d"
dependencies = [
"flint-sys",
"gmp-mpfr-sys",
"rug",
"rug-fft",
"serde",
]
[[package]]
name = "rustc-demangle"
version = "0.1.21"


@@ -23,6 +23,7 @@ typed-arena = { version = "2.0", optional = true }
log = "0.4"
thiserror = "1.0"
bellman = { git = "https://github.com/alex-ozdemir/bellman.git", branch = "mirage", optional = true }
rug-polynomial = { version = "0.2.5", optional = true }
ff = { version = "0.12", optional = true }
fxhash = "0.2"
good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false, optional = true }
@@ -70,6 +71,7 @@ aby = ["lp"]
kahip = ["aby"]
kahypar = ["aby"]
r1cs = []
poly = ["rug-polynomial"]
spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"]
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys", "byteorder"]


@@ -27,8 +27,7 @@ fn main() {
}
let cs = Computation {
outputs: vec![term![Op::Eq; t, v]],
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
..Default::default()
};
let _assignment = ilp::assign(&cs, "hycc");
//dbg!(&assignment);


@@ -310,6 +310,26 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T
_ => None,
}
}
Op::Array(k, v) => t
.cs()
.iter()
.map(|c| c_get(c).as_value_opt().cloned())
.collect::<Option<_>>()
.map(|cs| {
leaf_term(Op::Const(Value::Array(Array::from_vec(
k.clone(),
v.clone(),
cs,
))))
}),
Op::Fill(k, s) => c_get(&t.cs()[0]).as_value_opt().map(|v| {
leaf_term(Op::Const(Value::Array(Array::new(
k.clone(),
Box::new(v.clone()),
Default::default(),
*s,
))))
}),
Op::Select => match (get(0).as_array_opt(), get(1).as_value_opt()) {
(Some(arr), Some(idx)) => Some(leaf_term(Op::Const(arr.select(idx)))),
_ => None,


@@ -37,6 +37,11 @@ impl RewritePass for Linearizer {
computation.extend_precomputation(new_name.clone(), precomp);
Some(leaf_term(Op::Var(new_name, new_sort)))
}
Op::Array(..) => Some(term(Op::Tuple, rewritten_children())),
Op::Fill(_, size) => Some(term(
Op::Tuple,
vec![rewritten_children().pop().unwrap(); *size],
)),
Op::Select => {
let cs = rewritten_children();
let idx = &cs[1];


@@ -117,7 +117,7 @@ impl NonOblivComputer {
impl ProgressAnalysisPass for NonOblivComputer {
fn visit(&mut self, term: &Term) -> bool {
match &term.op() {
Op::Store => {
Op::Store | Op::CStore => {
let a = &term.cs()[0];
let i = &term.cs()[1];
let v = &term.cs()[2];
@@ -138,6 +138,32 @@ impl ProgressAnalysisPass for NonOblivComputer {
}
progress
}
Op::Array(..) => {
let mut progress = false;
if !term.cs().is_empty() {
if let Sort::Array(..) = check(&term.cs()[0]) {
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
for i in 0..term.cs().len() - 1 {
progress =
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
}
for i in (0..term.cs().len() - 1).rev() {
progress =
self.bi_implicate(&term.cs()[i], &term.cs()[i + 1]) || progress;
}
progress = self.bi_implicate(term, &term.cs()[0]) || progress;
}
}
progress
}
Op::Fill(..) => {
let v = &term.cs()[0];
if let Sort::Array(..) = check(v) {
self.bi_implicate(term, &term.cs()[0])
} else {
false
}
}
Op::Select => {
// Even though the selected value may not have array sort, we still flag it as
// non-oblivious so we know whether to replace it or not.
@@ -153,6 +179,13 @@ impl ProgressAnalysisPass for NonOblivComputer {
progress = self.bi_implicate(term, a) || progress;
progress
}
Op::Var(..) => {
if let Sort::Array(..) = check(term) {
self.mark(term)
} else {
false
}
}
Op::Ite => {
let t = &term.cs()[1];
let f = &term.cs()[2];
@@ -269,6 +302,32 @@ impl RewritePass for Replacer {
None
}
}
Op::CStore => {
if self.should_replace(orig) {
let mut cs = get_cs();
debug_assert_eq!(cs.len(), 4);
let cond = cs.remove(3);
let k_const = get_const(&cs.remove(1));
let orig = cs[0].clone();
Some(term![ITE; cond, term(Op::Update(k_const), cs), orig])
} else {
None
}
}
Op::Array(..) => {
if self.should_replace(orig) {
Some(term(Op::Tuple, get_cs()))
} else {
None
}
}
Op::Fill(_, size) => {
if self.should_replace(orig) {
Some(term(Op::Tuple, vec![get_cs().pop().unwrap(); *size]))
} else {
None
}
}
Op::Ite => {
if self.should_replace(orig) {
Some(term(Op::Ite, get_cs()))
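For orientation, an illustrative sketch of the shape the Op::CStore arm above produces once an array has been made non-oblivious (tuple-sorted); `arr_tuple`, `val`, and `cond` are placeholder rewritten children and `k_const` is the constant index pulled out of the key child:

// Hedged sketch of the cstore rewrite: an ITE over a tuple update.
let replaced = term![ITE;
    cond,
    term(Op::Update(k_const), vec![arr_tuple.clone(), val]),
    arr_tuple];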


@@ -68,12 +68,18 @@ impl RewritePass for Pass {
_rewritten_children: F,
) -> Option<Term> {
if let Op::Var(name, sort) = &orig.op() {
let mut new_var_reqs = Vec::new();
let new = create_vars(name, orig.clone(), sort, &mut new_var_reqs, true);
for (name, term) in new_var_reqs {
computation.extend_precomputation(name, term);
debug!("Considering var: {}", name);
if !computation.metadata.lookup(name).committed {
let mut new_var_reqs = Vec::new();
let new = create_vars(name, orig.clone(), sort, &mut new_var_reqs, true);
for (name, term) in new_var_reqs {
computation.extend_precomputation(name, term);
}
Some(new)
} else {
debug!("Skipping b/c it is commited.");
None
}
Some(new)
} else {
None
}
@@ -89,25 +95,28 @@ pub fn scalarize_inputs(cs: &mut Computation) {
remove_non_scalar_vars_from_main_computation(cs);
}
/// Check that every variables is a scalar.
/// Check that every variable is a scalar (or committed).
pub fn assert_all_vars_are_scalars(cs: &Computation) {
for t in cs.terms_postorder() {
if let Op::Var(_name, sort) = &t.op() {
match sort {
Sort::Array(..) | Sort::Tuple(..) => {
panic!("Variable {} is non-scalar", t);
if let Op::Var(name, sort) = &t.op() {
if !cs.metadata.lookup(name).committed {
match sort {
Sort::Array(..) | Sort::Tuple(..) => {
panic!("Variable {} is non-scalar", t);
}
_ => {}
}
_ => {}
}
}
}
}
/// Check that every variables is a scalar.
/// Remove all variables that are non-scalar (and not committed).
fn remove_non_scalar_vars_from_main_computation(cs: &mut Computation) {
for input in cs.metadata.ordered_inputs() {
if !check(&input).is_scalar() {
cs.metadata.remove_var(input.as_var_name());
let name = input.as_var_name();
if !check(&input).is_scalar() && !cs.metadata.lookup(&name).committed {
cs.metadata.remove_var(name);
}
}
assert_all_vars_are_scalars(cs);


@@ -106,6 +106,17 @@ impl TupleTree {
fn bimap(&self, mut f: impl FnMut(Term, Term) -> Term, other: &Self) -> Self {
self.structure(itertools::zip_eq(self.flatten(), other.flatten()).map(|(a, b)| f(a, b)))
}
fn transpose_map(vs: Vec<Self>, f: impl FnMut(Vec<Term>) -> Term) -> Self {
assert!(!vs.is_empty());
let n = vs[0].flatten().count();
let mut ts = vec![Vec::new(); n];
for v in &vs {
for (i, t) in v.flatten().enumerate() {
ts[i].push(t);
}
}
vs[0].structure(ts.into_iter().map(f))
}
fn get(&self, i: usize) -> Self {
match self {
TupleTree::NonTuple(cs) => {
@@ -250,6 +261,15 @@ pub fn eliminate_tuples(cs: &mut Computation) {
debug_assert!(cs.is_empty());
a.bimap(|a, v| term![Op::Store; a, i.clone(), v], &v)
}
Op::Array(k, _v) => TupleTree::transpose_map(cs, |children| {
assert!(!children.is_empty());
let v_s = check(&children[0]);
term(Op::Array(k.clone(), v_s), children)
}),
Op::Fill(key_sort, size) => {
let values = cs.pop().unwrap();
values.map(|v| term![Op::Fill(key_sort.clone(), *size); v])
}
Op::Select => {
let i = cs.pop().unwrap().unwrap_non_tuple();
let a = cs.pop().unwrap();
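Concretely (an illustrative example): an array literal of sort (array k (tuple bool (bv 4))) with elements (a0, b0) and (a1, b1) arrives here as two per-element tuple trees; transpose_map groups the matching leaves, so the Op::Array case yields the tuple (((array k bool) a0 a1), ((array k (bv 4)) b0 b1)). An array of tuples becomes a tuple of arrays.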


@@ -1,5 +1,7 @@
use crate::ir::term::*;
use log::trace;
/// A rewriting pass.
pub trait RewritePass {
/// Visit (and possibly rewrite) a term.
@@ -11,11 +13,49 @@ pub trait RewritePass {
orig: &Term,
rewritten_children: F,
) -> Option<Term>;
/// Visit (and possibly rewrite) a term, taking its rewritten children from `cache`.
fn visit_cache(
&mut self,
computation: &mut Computation,
orig: &Term,
cache: &TermMap<Term>,
) -> Option<Term> {
let get_children = || -> Vec<Term> {
orig.cs()
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect()
};
self.visit(computation, orig, get_children)
}
fn traverse(&mut self, computation: &mut Computation) {
self.traverse_full(computation, false, true);
}
fn traverse_full(
&mut self,
computation: &mut Computation,
precompute: bool,
persistent_arrays: bool,
) {
let mut cache = TermMap::<Term>::default();
let mut children_added = TermSet::default();
let mut stack = Vec::new();
if persistent_arrays {
stack.extend(
computation
.persistent_arrays
.iter()
.map(|(_name, final_term)| final_term.clone()),
);
}
stack.extend(computation.outputs.iter().cloned());
if precompute {
stack.extend(computation.precomputes.outputs().values().cloned());
}
while let Some(top) = stack.pop() {
if !cache.contains_key(&top) {
// was it missing?
@@ -23,24 +63,40 @@ pub trait RewritePass {
stack.push(top.clone());
stack.extend(top.cs().iter().filter(|c| !cache.contains_key(c)).cloned());
} else {
let get_children = || -> Vec<Term> {
top.cs()
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect()
};
let new_t_opt = self.visit(computation, &top, get_children);
let new_t = new_t_opt.unwrap_or_else(|| term(top.op().clone(), get_children()));
let new_t_opt = self.visit_cache(computation, &top, &cache);
let new_t = new_t_opt.unwrap_or_else(|| {
term(
top.op().clone(),
top.cs()
.iter()
.map(|c| cache.get(c).unwrap())
.cloned()
.collect(),
)
});
cache.insert(top.clone(), new_t);
}
}
}
if persistent_arrays {
for (_name, final_term) in &mut computation.persistent_arrays {
let new_final_term = cache.get(final_term).unwrap().clone();
trace!("Array {} -> {}", final_term, new_final_term);
*final_term = new_final_term;
}
}
computation.outputs = computation
.outputs
.iter()
.map(|o| cache.get(o).unwrap().clone())
.collect();
if precompute {
let os = computation.precomputes.outputs().clone();
for (name, old_term) in os {
let new_term = cache.get(&old_term).unwrap().clone();
computation.precomputes.change_output(&name, new_term);
}
}
}
}
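For reference, a minimal sketch of a pass driven through the new traverse_full entry point (the visit signature is assumed to be the generic-closure form used by the passes above):

// Hedged sketch: a pass that rewrites nothing; returning None makes traverse_full
// rebuild each term from its (cached) rewritten children.
struct IdentityPass;
impl RewritePass for IdentityPass {
    fn visit<F: Fn() -> Vec<Term>>(
        &mut self,
        _computation: &mut Computation,
        _orig: &Term,
        _rewritten_children: F,
    ) -> Option<Term> {
        None
    }
}
// IdentityPass.traverse_full(&mut comp, /* precompute: */ true, /* persistent_arrays: */ true);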


@@ -76,6 +76,7 @@ impl Constraints for Computation {
outputs: assertions,
metadata,
precomputes: Default::default(),
persistent_arrays: Default::default(),
}
}
}


@@ -0,0 +1,66 @@
//! IR extensions
use super::ty::TypeErrorReason;
use super::{Sort, Term, Value};
use circ_hc::Node;
use serde::{Deserialize, Serialize};
mod poly;
mod ram;
mod sort;
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// An extension operator. Not externally supported.
///
/// Often evaluatable, but not compilable.
pub enum ExtOp {
/// See [ram::eval]
PersistentRamSplit,
/// Given an array of tuples, returns a reordering such that the result is sorted.
Sort,
/// See [poly].
UniqDeriGcd,
}
impl ExtOp {
/// Its arity
pub fn arity(&self) -> Option<usize> {
match self {
ExtOp::PersistentRamSplit => Some(2),
ExtOp::Sort => Some(1),
ExtOp::UniqDeriGcd => Some(1),
}
}
/// Type-check, given argument sorts
pub fn check(&self, arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
match self {
ExtOp::PersistentRamSplit => ram::check(arg_sorts),
ExtOp::Sort => sort::check(arg_sorts),
ExtOp::UniqDeriGcd => poly::check(arg_sorts),
}
}
/// Evaluate, given argument values
pub fn eval(&self, args: &[&Value]) -> Value {
match self {
ExtOp::PersistentRamSplit => ram::eval(args),
ExtOp::Sort => sort::eval(args),
ExtOp::UniqDeriGcd => poly::eval(args),
}
}
/// Indicate which children of `t` must be typed to type `t`.
pub fn check_dependencies(&self, t: &Term) -> Vec<Term> {
t.cs().to_vec()
}
/// Parse, from bytes.
pub fn parse(bytes: &[u8]) -> Option<Self> {
match bytes {
b"persistent_ram_split" => Some(ExtOp::PersistentRamSplit),
b"uniq_deri_gcd" => Some(ExtOp::UniqDeriGcd),
b"sort" => Some(ExtOp::Sort),
_ => None,
}
}
}
#[cfg(test)]
mod test;
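A small usage sketch, mirroring the tests in ext/test.rs below (`arr` is assumed to be a constant array term, e.g. parsed from IR text): extension operators are built like any other Op and go through the normal IR evaluator.

// Hedged sketch: sort a constant array via the extension operator.
let sorted = term![Op::ExtOp(ExtOp::Sort); arr];
let value = eval(&sorted, &Default::default());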

src/ir/term/ext/poly.rs

@@ -0,0 +1,75 @@
//! Operator UniqDeriGcd
//!
//! Given an array of (root, cond) tuples (root is a field element, cond is a boolean),
//! define f(X) = prod_i (cond_i ? X - root_i : 1)
//!
//! Compute f'(X) and s,t s.t. fs + f't = 1. Return an array of coefficients for s and one for t
//! (as a tuple).
use crate::ir::term::ty::*;
use crate::ir::term::*;
/// Type-check [super::ExtOp::UniqDeriGcd].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[pairs] = arg_sorts {
let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?;
let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?;
let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?;
if let &[root, cond] = &value_tup {
eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?;
eq_or(
cond,
&Sort::Bool,
"UniqDeriGcd pairs: second element must be a bool",
)?;
let box_f = Box::new(f.clone());
let arr = Sort::Array(box_f.clone(), box_f, size);
Ok(Sort::Tuple(Box::new([arr.clone(), arr])))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
"UniqDeriGcd: pairs value must be a pair".into(),
))
}
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
}
/// Evaluate [super::ExtOp::UniqDeriGcd].
#[cfg(feature = "poly")]
pub fn eval(args: &[&Value]) -> Value {
use rug_polynomial::ModPoly;
let sort = args[0].sort().as_array().0.clone();
let field = sort.as_pf().clone();
let mut roots: Vec<Integer> = Vec::new();
let deg = args[0].as_array().size;
for t in args[0].as_array().values() {
let tuple = t.as_tuple();
let cond = tuple[1].as_bool();
if cond {
roots.push(tuple[0].as_pf().i());
}
}
let p = ModPoly::with_roots(roots, field.modulus());
let dp = p.derivative();
let (g, s, t) = p.xgcd(&dp);
assert_eq!(g.len(), 1);
assert_eq!(g.get_coefficient(0), 1);
let coeff_arr = |s: ModPoly| {
let v: Vec<Value> = (0..deg)
.map(|i| Value::Field(field.new_v(s.get_coefficient(i))))
.collect();
Value::Array(Array::from_vec(sort.clone(), sort.clone(), v))
};
let s_cs = coeff_arr(s);
let t_cs = coeff_arr(t);
Value::Tuple(Box::new([s_cs, t_cs]))
}
/// Evaluate [super::ExtOp::UniqDeriGcd].
#[cfg(not(feature = "poly"))]
pub fn eval(_args: &[&Value]) -> Value {
panic!("Cannot evalute Op::UniqDeriGcd without 'poly' feature")
}
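As a sanity check on the algebra, the first uniq_deri_gcd_eval test in ext/test.rs below (modulus 17, conditioned roots {2, 4}, expected values labelled "from sage") works out as:

f(X) = (X - 2)(X - 4) = X^2 - 6X + 8,    f'(X) = 2X - 6
s(X) = 16 ≡ -1,    t(X) = 9X + 7
s·f + t·f' = -(X^2 - 6X + 8) + (9X + 7)(2X - 6) = 17X^2 - 34X - 50 ≡ 1 (mod 17)

The returned coefficient arrays are padded with zeros up to the input array's size (here 5).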

src/ir/term/ext/ram.rs

@@ -0,0 +1,126 @@
//! Operator PersistentRamSplit
use crate::ir::term::ty::*;
use crate::ir::term::*;
use fxhash::FxHashSet as HashSet;
/// Type-check [super::ExtOp::PersistentRamSplit].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[entries, indices] = arg_sorts {
let (key, value, size) = ty::array_or(entries, "PersistentRamSplit entries")?;
let f = pf_or(key, "PersistentRamSplit entries: indices must be field")?;
let value_tup = ty::tuple_or(value, "PersistentRamSplit entries: value must be a tuple")?;
if let &[old, new] = &value_tup {
eq_or(
f,
old,
"PersistentRamSplit entries: value must be a field pair",
)?;
eq_or(
f,
new,
"PersistentRamSplit entries: value must be a field pair",
)?;
let (i_key, i_value, i_size) = ty::array_or(indices, "PersistentRamSplit indices")?;
eq_or(f, i_key, "PersistentRamSplit indices: key must be a field")?;
eq_or(
f,
i_value,
"PersistentRamSplit indices: value must be a field",
)?;
let n_touched = i_size.min(size);
let n_ignored = size - n_touched;
let box_f = Box::new(f.clone());
let f_pair = Sort::Tuple(Box::new([f.clone(), f.clone()]));
let ignored_entries_sort =
Sort::Array(box_f.clone(), Box::new(f_pair.clone()), n_ignored);
let selected_entries_sort = Sort::Array(box_f, Box::new(f_pair), n_touched);
Ok(Sort::Tuple(Box::new([
ignored_entries_sort,
selected_entries_sort.clone(),
selected_entries_sort,
])))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
"PersistentRamSplit: entries value must be a pair".into(),
))
}
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
}
/// Evaluate [super::ExtOp::PersistentRamSplit].
///
/// Takes two arguments:
///
/// * entries: [(val_init_i, val_fin_i)] (len E)
/// * indices: [idx_i] (len I)
///
/// assumes I < E and 0 <= idx_i < E.
///
/// Let dedup_i be idx_i with duplicates removed
/// Let ext_i be dedup_i padded up to length I. The added elements are each i in 0.. (so long as i hasn't occurred in ext_i already).
/// Let sorted_i be ext_i, sorted in increasing order.
///
/// Returns:
/// * a bunch of sequences of index-value pairs:
/// * untouched_entries (array field (tuple (field field)) (length E - I))
/// * init_reads (array field (tuple (field field)) (length I))
/// * fin_writes (array field (tuple (field field)) (length I))
pub fn eval(args: &[&Value]) -> Value {
let entries = &args[0].as_array().values();
let (init_vals, fin_vals): (Vec<Value>, Vec<Value>) = entries
.iter()
.map(|t| (t.as_tuple()[0].clone(), t.as_tuple()[1].clone()))
.unzip();
let indices = &args[1].as_array().values();
let num_accesses = indices.len();
let field = args[0].as_array().key_sort.as_pf();
let uniq_indices = {
let mut uniq_indices = Vec::<usize>::new();
let mut used_indices = HashSet::<usize>::default();
for i in indices.iter().map(|i| i.as_usize().unwrap()).chain(0..) {
if uniq_indices.len() == num_accesses {
break;
}
if !used_indices.contains(&i) {
uniq_indices.push(i);
used_indices.insert(i);
}
}
uniq_indices.sort();
uniq_indices
};
let mut init_reads = Vec::new();
let mut fin_writes = Vec::new();
let mut untouched_entries = Vec::new();
let mut j = 0;
for (i, (init_val, fin_val)) in init_vals.into_iter().zip(fin_vals).enumerate() {
if j < uniq_indices.len() && uniq_indices[j] == i {
init_reads.push((i, init_val));
fin_writes.push((i, fin_val));
j += 1;
} else {
untouched_entries.push((i, init_val));
}
}
let key_sort = Sort::Field(field.clone());
let entry_to_vals =
|e: (usize, Value)| Value::Tuple(Box::new([Value::Field(field.new_v(e.0)), e.1]));
let vec_to_arr = |v: Vec<(usize, Value)>| {
let vals: Vec<Value> = v.into_iter().map(entry_to_vals).collect();
Value::Array(Array::from_vec(
key_sort.clone(),
vals.first().unwrap().sort(),
vals,
))
};
let init_reads = vec_to_arr(init_reads);
let untouched_entries = vec_to_arr(untouched_entries);
let fin_writes = vec_to_arr(fin_writes);
Value::Tuple(vec![untouched_entries, init_reads, fin_writes].into_boxed_slice())
}
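Worked example (the "duplicates" case in the tests below): with E = 5 entries and indices = [1, 1, 1], deduplication gives [1]; padding with 0, 1, 2, ... (skipping values already present) extends it to [1, 0, 2], which sorts to [0, 1, 2]. So init_reads and fin_writes cover indices 0, 1, 2, and untouched_entries holds indices 3 and 4.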

src/ir/term/ext/sort.rs

@@ -0,0 +1,22 @@
//! Sort operator
use crate::ir::term::ty::*;
use crate::ir::term::*;
/// Type-check [super::ExtOp::Sort].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
array_or(arg_sorts[0], "sort argument").map(|_| arg_sorts[0].clone())
}
/// Evaluate [super::ExtOp::Sort].
pub fn eval(args: &[&Value]) -> Value {
let sort = args[0].sort();
let (key_sort, value_sort, _) = sort.as_array();
let mut values: Vec<Value> = args[0].as_array().values();
values.sort();
Value::Array(Array::from_vec(
key_sort.clone(),
value_sort.clone(),
values,
))
}

src/ir/term/ext/test.rs

@@ -0,0 +1,151 @@
//! Contains a number of constant terms for testing
#![allow(missing_docs)]
use super::super::*;
#[test]
fn op_sort_eval() {
let t = text::parse_term(b"(declare () (sort (#l (mod 11) ((#t true true) (#t true false)))))");
let actual_output = eval(&t, &Default::default());
let expected_output = text::parse_value_map(
b"(let ((output (#l (mod 11) ((#t true false) (#t true true))))) false)",
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
let t = text::parse_term(b"(declare () (sort (#l (mod 11) (#x0 #xf #x4 #x3))))");
let actual_output = eval(&t, &Default::default());
let expected_output =
text::parse_value_map(b"(let ((output (#l (mod 11) (#x0 #x3 #x4 #xf)))) false)");
assert_eq!(&actual_output, expected_output.get("output").unwrap());
}
#[cfg(feature = "poly")]
#[test]
fn uniq_deri_gcd_eval() {
let t = text::parse_term(
b"
(declare (
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
)
(uniq_deri_gcd pairs))",
);
let inputs = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
) false))
",
);
let actual_output = eval(&t, &inputs);
let expected_output = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(output (#t
(#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage
))
) false))
",
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
let inputs = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
) false))
",
);
let actual_output = eval(&t, &inputs);
let expected_output = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(output (#t
(#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage
))
) false))
",
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
}
#[test]
fn persistent_ram_split_eval() {
let t = text::parse_term(
b"
(declare (
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
(indices (array (mod 17) (mod 17) 3))
)
(persistent_ram_split entries indices))",
);
let inputs = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(entries (#l (mod 17) ( (#t #f0 #f1) (#t #f1 #f1) (#t #f2 #f3) (#t #f3 #f4) (#t #f4 #f4) )))
(indices (#l (mod 17) (#f0 #f2 #f3)))
) false))
",
);
let actual_output = eval(&t, &inputs);
let expected_output = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(output (#t
(#l (mod 17) ( (#t #f1 #f1) (#t #f4 #f4) )) ; untouched
(#l (mod 17) ( (#t #f0 #f0) (#t #f2 #f2) (#t #f3 #f3) )) ; init_reads
(#l (mod 17) ( (#t #f0 #f1) (#t #f2 #f3) (#t #f3 #f4) )) ; fin_writes
))
) false))
",
);
dbg!(&actual_output.as_tuple()[0].as_array().default);
dbg!(
&expected_output.get("output").unwrap().as_tuple()[0]
.as_array()
.default
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
// duplicates
let inputs = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(entries (#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) (#t #f3 #f3) (#t #f4 #f4) )))
(indices (#l (mod 17) (#f1 #f1 #f1)))
) false))
",
);
let actual_output = eval(&t, &inputs);
let expected_output = text::parse_value_map(
b"
(set_default_modulus 17
(let
(
(output (#t
(#l (mod 17) ( (#t #f3 #f3) (#t #f4 #f4) )) ; untouched
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f1) (#t #f2 #f2) )) ; init_reads
(#l (mod 17) ( (#t #f0 #f0) (#t #f1 #f2) (#t #f2 #f2) )) ; fin_writes
))
) false))
",
);
assert_eq!(&actual_output, expected_output.get("output").unwrap());
}


@@ -1,6 +1,6 @@
//! Machinery for formatting IR types
use super::{
Array, ComputationMetadata, Node, Op, PartyId, PostOrderIter, Sort, Term, TermMap, Value,
ext, Array, ComputationMetadata, Node, Op, PartyId, PostOrderIter, Sort, Term, TermMap, Value,
VariableMetadata,
};
use crate::cfg::{cfg, is_cfg_set};
@@ -186,8 +186,22 @@ impl DisplayIr for Op {
Op::IntBinPred(a) => write!(f, "{a}"),
Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()),
Op::PfChallenge(n, m) => write!(f, "(challenge {} {})", n, m.modulus()),
Op::PfFitsInBits(n) => write!(f, "(pf_fits_in_bits {})", n),
Op::Select => write!(f, "select"),
Op::Store => write!(f, "store"),
Op::CStore => write!(f, "cstore"),
Op::Fill(key_sort, size) => {
write!(f, "(fill ")?;
key_sort.ir_fmt(f)?;
write!(f, " {})", *size)
}
Op::Array(k, v) => {
write!(f, "(array ")?;
k.ir_fmt(f)?;
write!(f, " ")?;
v.ir_fmt(f)?;
write!(f, ")")
}
Op::Tuple => write!(f, "tuple"),
Op::Field(i) => write!(f, "(field {i})"),
Op::Update(i) => write!(f, "(update {i})"),
@@ -198,6 +212,18 @@ impl DisplayIr for Op {
}
Op::Call(name, _, _) => write!(f, "fn:{name}"),
Op::Rot(i) => write!(f, "(rot {i})"),
Op::PfToBoolTrusted => write!(f, "pf2bool_trusted"),
Op::ExtOp(o) => o.ir_fmt(f),
}
}
}
impl DisplayIr for ext::ExtOp {
fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult {
match self {
ext::ExtOp::PersistentRamSplit => write!(f, "persistent_ram_split"),
ext::ExtOp::UniqDeriGcd => write!(f, "uniq_deri_gcd"),
ext::ExtOp::Sort => write!(f, "sort"),
}
}
}


@@ -26,17 +26,17 @@ use circ_fields::{FieldT, FieldV};
pub use circ_hc::{Node, Table, Weak};
use circ_opt::FieldToBv;
use fxhash::{FxHashMap, FxHashSet};
use log::debug;
use log::{debug, trace};
use rug::Integer;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
use std::cell::Cell;
use std::collections::BTreeMap;
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::sync::Arc;
pub mod bv;
pub mod dist;
pub mod ext;
pub mod extras;
pub mod fmt;
pub mod lin;
@@ -46,6 +46,7 @@ pub mod text;
pub mod ty;
pub use bv::BitVector;
pub use ext::ExtOp;
pub use ty::{check, check_rec, TypeError, TypeErrorReason};
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -132,6 +133,8 @@ pub enum Op {
///
/// In IR evaluation, we sample deterministically based on a hash of the name.
PfChallenge(String, FieldT),
/// Requires the input pf element to fit in this many (unsigned) bits.
PfFitsInBits(usize),
/// Integer n-ary operator
IntNaryOp(IntNaryOp),
@@ -146,6 +149,15 @@ pub enum Op {
///
/// Makes an array equal to `array`, but with `value` at `index`.
Store,
/// Quad-operator, with arguments (array, index, value, cond).
///
/// If `cond`, outputs an array equal to `array`, but with `value` at `index`.
/// Otherwise, outputs `array`.
CStore,
/// Makes an array of the indicated key sort with the indicated size, filled with the argument.
Fill(Sort, usize),
/// Create an array from (contiguous) values.
Array(Sort, Sort),
/// Assemble n things into a tuple
Tuple,
@@ -163,6 +175,12 @@ pub enum Op {
/// Cyclic right rotation of an array
/// i.e. (Rot(1) [1,2,3,4]) --> ([4,1,2,3])
Rot(usize),
/// Assume that the field element is 0 or 1, and treat it as a boolean.
PfToBoolTrusted,
/// Extension operators. Used in compilation, but not externally supported
ExtOp(ext::ExtOp),
}
/// Boolean AND
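For quick reference, the shapes of the new operators, as a hedged sketch using the existing term constructors (`arr`, `idx`, `val`, `cond`, `x`, `y`, `pf_bit`, and `pf_val` are placeholder terms of the appropriate sorts):

let cs = term![Op::CStore; arr, idx, val, cond];                  // arr, but with val at idx when cond holds
let fill = term![Op::Fill(Sort::BitVector(4), 3); val];           // size-3 array, every entry equal to val
let packed = term(Op::Array(Sort::Bool, Sort::Bool), vec![x, y]); // array built from contiguous values
let b = term![Op::PfToBoolTrusted; pf_bit];                       // pf element assumed to be 0 or 1
let fits = term![Op::PfFitsInBits(4); pf_val];                    // checks that pf_val fits in 4 bits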
@@ -280,17 +298,23 @@ impl Op {
Op::PfUnOp(_) => Some(1),
Op::PfNaryOp(_) => None,
Op::PfChallenge(_, _) => None,
Op::PfFitsInBits(..) => Some(1),
Op::IntNaryOp(_) => None,
Op::IntBinPred(_) => Some(2),
Op::UbvToPf(_) => Some(1),
Op::Select => Some(2),
Op::Store => Some(3),
Op::CStore => Some(4),
Op::Fill(..) => Some(1),
Op::Array(..) => None,
Op::Tuple => None,
Op::Field(_) => Some(1),
Op::Update(_) => Some(2),
Op::Map(op) => op.arity(),
Op::Call(_, args, _) => Some(args.len()),
Op::Rot(_) => Some(1),
Op::ExtOp(o) => o.arity(),
Op::PfToBoolTrusted => Some(1),
}
}
}
@@ -843,9 +867,9 @@ impl Sort {
#[track_caller]
/// Unwrap the modulus of this prime field, panicking otherwise.
pub fn as_pf(&self) -> Arc<Integer> {
pub fn as_pf(&self) -> &FieldT {
if let Sort::Field(fty) = self {
fty.modulus_arc()
fty
} else {
panic!("{} is not a field", self)
}
@@ -861,6 +885,42 @@ impl Sort {
}
}
#[track_caller]
/// Unwrap the constituent sorts of this array, panicking otherwise.
pub fn as_array(&self) -> (&Sort, &Sort, usize) {
if let Sort::Array(k, v, s) = self {
(k, v, *s)
} else {
panic!("{} is not an array", self)
}
}
/// Is this an array?
pub fn is_array(&self) -> bool {
matches!(self, Sort::Array(..))
}
/// The nth element of this sort.
/// Only defined for booleans, bit-vectors, and field elements.
#[track_caller]
pub fn nth_elem(&self, n: usize) -> Term {
match self {
Sort::Bool => {
assert!(n < 2);
bool_lit([false, true][n])
}
Sort::BitVector(w) => {
assert!(n < (1 << w));
bv_lit(n, *w)
}
Sort::Field(f) => {
debug_assert!(&Integer::from(n) < f.modulus());
pf_lit(f.new_v(n))
}
_ => panic!("Can't get nth element of sort {}", self),
}
}
/// An iterator over the elements of this sort (as IR Terms).
/// Only defined for booleans, bit-vectors, and field elements.
#[track_caller]
@@ -1238,8 +1298,15 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
/// Helper function for eval function. Handles a single term
fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, t: Term) -> Value {
let args: Vec<&Value> = t.cs().iter().map(|c| vs.get(c).unwrap()).collect();
trace!("Eval {} on {:?}", t.op(), args);
let v = eval_op(t.op(), &args, h);
debug!("Eval {}\nAs {}", t, v);
trace!("=> {}", v);
if let Value::Bool(false) = &v {
trace!("term {}", t);
for v in extras::free_variables(t.clone()) {
trace!(" {} = {}", v, h.get(&v).unwrap());
}
}
vs.insert(t, v.clone());
v
}
@@ -1393,6 +1460,9 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
}),
Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())),
Op::PfChallenge(name, field) => Value::Field(pf_challenge(name, field)),
Op::PfFitsInBits(n_bits) => {
Value::Bool(args[0].as_pf().i().signed_bits() <= *n_bits as u32)
}
// tuple
Op::Tuple => Value::Tuple(args.iter().map(|a| (*a).clone()).collect()),
Op::Field(i) => {
@@ -1415,6 +1485,31 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
let v = args[2].clone();
Value::Array(a.store(i, v))
}
Op::CStore => {
let a = args[0].as_array().clone();
let i = args[1].clone();
let v = args[2].clone();
let c = args[3].as_bool();
if c {
Value::Array(a.store(i, v))
} else {
Value::Array(a)
}
}
Op::Fill(key_sort, size) => {
let v = args[0].clone();
Value::Array(Array::new(
key_sort.clone(),
Box::new(v),
Default::default(),
*size,
))
}
Op::Array(key, value) => Value::Array(Array::from_vec(
key.clone(),
value.clone(),
args.iter().cloned().cloned().collect(),
)),
Op::Select => {
let a = args[0].as_array().clone();
let i = args[1];
@@ -1476,6 +1571,12 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
}
Value::Array(res)
}
Op::PfToBoolTrusted => {
let v = args[0].as_pf().i();
assert!(v == 0 || v == 1);
Value::Bool(v == 1)
}
Op::ExtOp(o) => o.eval(args),
o => unimplemented!("eval: {:?}", o),
}
@@ -1488,10 +1589,22 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
/// * a key sort, as all arrays do. This sort must be iterable (i.e., bool, int, bit-vector, or field).
/// * a value sort, for the array's default
pub fn make_array(key_sort: Sort, value_sort: Sort, i: Vec<Term>) -> Term {
let d = Sort::Array(Box::new(key_sort.clone()), Box::new(value_sort), i.len()).default_term();
i.into_iter()
.zip(key_sort.elems_iter())
.fold(d, |arr, (val, idx)| term(Op::Store, vec![arr, idx, val]))
term(Op::Array(key_sort, value_sort), i)
}
/// Make a sequence of terms from an array.
///
/// Requires
///
/// * an array term
pub fn unmake_array(a: Term) -> Vec<Term> {
let sort = check(&a);
let (key_sort, _, size) = sort.as_array();
key_sort
.elems_iter()
.take(size)
.map(|idx| term(Op::Select, vec![a.clone(), idx]))
.collect()
}
/// Make a term with no arguments, just an operator.
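A short usage sketch (x and y are placeholder boolean terms): make_array now builds a single Op::Array node instead of a chain of stores, and unmake_array recovers the per-index selects.

let arr = make_array(Sort::BitVector(2), Sort::Bool, vec![x.clone(), y.clone()]);
let elems = unmake_array(arr);
// elems[0] is (select arr #b00), elems[1] is (select arr #b01)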
@@ -1547,6 +1660,29 @@ macro_rules! term {
};
}
#[macro_export]
/// Make a term, with clones.
///
/// Syntax:
///
/// * without children: `term![OP]`
/// * with children: `term![OP; ARG0, ARG1, ... ]`
/// * Note the semi-colon
macro_rules! term_c {
($x:expr; $($y:expr),+) => {
{
let mut args = Vec::new();
#[allow(clippy::vec_init_then_push)]
{
$(
args.push(($y).clone());
)+
}
term($x, args)
}
};
}
/// Map from terms
pub type TermMap<T> = FxHashMap<Term, T>;
/// LRU cache of terms (like TermMap, but limited size)
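Usage sketch: term_c! clones each argument, so the operands remain usable afterwards.

let sum = term_c![BV_ADD; a, b];    // equivalent to term![BV_ADD; a.clone(), b.clone()]
let prod = term_c![BV_MUL; sum, a]; // `a` and `sum` are still live here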
@@ -1703,14 +1839,27 @@ impl ComputationMetadata {
self.vars.insert(metadata.name.clone(), metadata);
}
/// Lookup metadata
#[track_caller]
fn lookup<Q: std::borrow::Borrow<str> + ?Sized>(&self, name: &Q) -> &VariableMetadata {
pub fn lookup<Q: std::borrow::Borrow<str> + ?Sized>(&self, name: &Q) -> &VariableMetadata {
let n = name.borrow();
self.vars
.get(n)
.unwrap_or_else(|| panic!("Missing input {} in inputs{:#?}", n, self.vars))
}
/// Lookup metadata
#[track_caller]
pub fn lookup_mut<Q: std::borrow::Borrow<str> + ?Sized>(
&mut self,
name: &Q,
) -> &mut VariableMetadata {
let n = name.borrow();
self.vars
.get_mut(n)
.unwrap_or_else(|| panic!("Missing input {}", n))
}
/// Returns None if the value is public. Otherwise, the unique party that knows it.
pub fn get_input_visibility(&self, input_name: &str) -> Option<PartyId> {
self.lookup(input_name).vis
@@ -1890,6 +2039,11 @@ pub struct Computation {
pub metadata: ComputationMetadata,
/// Pre-computations
pub precomputes: precomp::PreComp,
/// Persistent Arrays. [(name, term)]
/// where:
/// * name: a variable name (array type) indicating the input state
/// * term: a term indicating the output state
pub persistent_arrays: Vec<(String, Term)>,
}
impl Computation {
@@ -1911,7 +2065,13 @@ impl Computation {
party: Option<PartyId>,
precompute: Option<Term>,
) -> Term {
debug!("Var: {} : {} (visibility: {:?})", name, s, party);
debug!(
"Var: {} : {} (visibility: {:?}) (precompute: {})",
name,
s,
party,
precompute.is_some()
);
self.metadata.new_input(name.to_owned(), party, s.clone());
if let Some(p) = precompute {
assert_eq!(&s, &check(&p));
@@ -1920,6 +2080,31 @@ impl Computation {
leaf_term(Op::Var(name.to_owned(), s))
}
/// Create a new variable with the given metadata.
///
/// If `precompute` is set, that precomputation is added to give a value for this variable.
/// Otherwise, the variable is assumed to be an input.
pub fn new_var_metadata(
&mut self,
metadata: VariableMetadata,
precompute: Option<Term>,
) -> Term {
debug!(
"Var: {} : {:?} (precompute {})",
metadata.name,
metadata,
precompute.is_some()
);
let sort = metadata.sort.clone();
let name = metadata.name.clone();
self.metadata.new_input_from_meta(metadata);
if let Some(p) = precompute {
assert_eq!(&sort, &check(&p));
self.precomputes.add_output(name.clone(), p);
}
leaf_term(Op::Var(name, sort))
}
/// Add a new input `new_input_var` to this computation,
/// whose value is determined by `precomp`: a term over existing inputs.
///
@@ -1946,6 +2131,49 @@ impl Computation {
self.new_var(&new_input_var, sort, vis, Some(precomp));
}
/// Initialize a new persistent array.
pub fn start_persistent_array(
&mut self,
var: &str,
size: usize,
field: FieldT,
party: PartyId,
) -> Term {
let f = Sort::Field(field);
let s = Sort::Array(Box::new(f.clone()), Box::new(f), size);
let md = VariableMetadata {
name: var.to_owned(),
vis: Some(party),
sort: s,
committed: true,
..Default::default()
};
let term = self.new_var_metadata(md, None);
// we'll replace dummy later
let dummy = bool_lit(true);
self.persistent_arrays.push((var.into(), dummy));
term
}
/// Record the final state of a persistent array. Should be called once per array, with the
/// same name as [Computation::start_persistent_array].
pub fn end_persistent_array(&mut self, var: &str, final_state: Term) {
for (name, t) in &mut self.persistent_arrays {
if name == var {
assert_eq!(*t, bool_lit(true));
*t = final_state;
return;
}
}
panic!("No existing persistent memory {}", var)
}
/// Make a vector of existing variables a commitment.
pub fn add_commitment(&mut self, names: Vec<String>) {
self.metadata.add_commitment(names);
}
/// Change the sort of a variable
pub fn remove_var(&mut self, var: &str) {
self.metadata.remove_var(var);
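A hedged lifecycle sketch for the persistent-array API added above (the name, field, party id, and update terms are placeholders): the committed input is declared up front and its final state is registered once the circuit has been built.

let init = comp.start_persistent_array("ram", 4, field.clone(), prover_id);
// ... build the computation, updating the array along the way ...
let fin = term![Op::Store; init, addr, new_val];
comp.end_persistent_array("ram", fin);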
@@ -1964,6 +2192,7 @@ impl Computation {
outputs: Vec::new(),
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
persistent_arrays: Default::default(),
}
}
@@ -1992,6 +2221,29 @@ impl Computation {
terms.pop();
terms.into_iter()
}
/// Evaluate the precompute, then this computation.
pub fn eval_all(&self, values: &FxHashMap<String, Value>) -> Vec<Value> {
let mut values = values.clone();
// set all challenges to 1.
for v in self.metadata.vars.values() {
if v.random {
let field = v.sort.as_pf();
let value = Value::Field(pf_challenge(&v.name, field));
values.insert(v.name.clone(), value);
}
}
values = self.precomputes.eval(&values);
let mut cache = Default::default();
let mut outputs = Vec::new();
for o in &self.outputs {
outputs.push(eval_cached(o, &values, &mut cache).clone());
}
outputs
}
}
#[derive(Clone, Debug, Default)]


@@ -7,6 +7,8 @@ use fxhash::{FxHashMap, FxHashSet};
use crate::ir::term::*;
use log::trace;
/// A "precomputation".
///
/// Expresses a computation to be run in advance by a single party.
@@ -47,6 +49,10 @@ impl PreComp {
let old = self.outputs.insert(name, value);
assert!(old.is_none());
}
/// Overwrite a step
pub fn change_output(&mut self, name: &str, value: Term) {
*self.outputs.get_mut(name).unwrap() = value;
}
/// Retain only the parts of this precomputation that can be evaluated from
/// the `known` inputs.
pub fn restrict_to_inputs(&mut self, known: FxHashSet<String>) {
@@ -84,7 +90,9 @@ impl PreComp {
for (o_name, _o_sort) in &self.sequence {
let o = self.outputs.get(o_name).unwrap();
eval_cached(o, &env, &mut value_cache);
env.insert(o_name.clone(), value_cache.get(o).unwrap().clone());
let value = value_cache.get(o).unwrap().clone();
trace!("pre {o_name} => {value}");
env.insert(o_name.clone(), value);
}
env
}


@@ -626,3 +626,11 @@ pub fn bv_sub_tests() -> Vec<Term> {
],
]
}
#[test]
fn pf2bool_eval() {
let t = text::parse_term(b"(declare () (pf2bool_trusted (ite false #f1m11 #f0m11)))");
let actual_output = eval(&t, &Default::default());
let expected_output = text::parse_value_map(b"(let ((output false)) false)");
assert_eq!(&actual_output, expected_output.get("output").unwrap());
}


@@ -28,6 +28,11 @@
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
//! * OUTPUTS is `((X1 S1) .. (Xn Sn))`
//! * TUPLE_TERM is a tuple of the same arity as the output
//! * ARRAYS (optional): `(persistent_arrays ARRAY*)`:
//! * ARRAY is `(X S T)`
//! * X is the name of the initial state
//! * S is the size
//! * T is the final state
//! * Sort `S`:
//! * `bool`
//! * `f32`
@@ -286,12 +291,22 @@ impl<'src> IrInterp<'src> {
Leaf(Ident, b"intmul") => Ok(INT_MUL),
Leaf(Ident, b"select") => Ok(Op::Select),
Leaf(Ident, b"store") => Ok(Op::Store),
Leaf(Ident, b"cstore") => Ok(Op::CStore),
Leaf(Ident, b"tuple") => Ok(Op::Tuple),
Leaf(Ident, b"pf2bool_trusted") => Ok(Op::PfToBoolTrusted),
Leaf(Ident, bytes) => {
if let Some(e) = ext::ExtOp::parse(bytes) {
Ok(Op::ExtOp(e))
} else {
todo!("Unparsed op: {}", tt)
}
}
List(tts) => match &tts[..] {
[Leaf(Ident, b"extract"), a, b] => Ok(Op::BvExtract(self.usize(a), self.usize(b))),
[Leaf(Ident, b"uext"), a] => Ok(Op::BvUext(self.usize(a))),
[Leaf(Ident, b"sext"), a] => Ok(Op::BvSext(self.usize(a))),
[Leaf(Ident, b"pf2bv"), a] => Ok(Op::PfToBv(self.usize(a))),
[Leaf(Ident, b"pf_fits_in_bits"), a] => Ok(Op::PfFitsInBits(self.usize(a))),
[Leaf(Ident, b"bit"), a] => Ok(Op::BvBit(self.usize(a))),
[Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))),
[Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))),
@@ -300,9 +315,13 @@ impl<'src> IrInterp<'src> {
self.ident_string(name),
FieldT::from(self.int(field)),
)),
[Leaf(Ident, b"array"), k, v] => Ok(Op::Array(self.sort(k), self.sort(v))),
[Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))),
[Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))),
[Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))),
[Leaf(Ident, b"fill"), key_sort, size] => {
Ok(Op::Fill(self.sort(key_sort), self.usize(size)))
}
_ => todo!("Unparsed op: {}", tt),
},
_ => todo!("Unparsed op: {}", tt),
@@ -691,13 +710,31 @@ impl<'src> IrInterp<'src> {
assert!(tts.len() >= 3);
let (metadata, input_names) = self.metadata(&tts[0]);
let precomputes = self.precompute(&tts[1]);
let iter = tts.iter().skip(2);
let mut persistent_arrays = Vec::new();
let mut skip_one = false;
if let List(tts_inner) = &tts[2] {
if tts_inner[0] == Leaf(Token::Ident, b"persistent_arrays") {
skip_one = true;
for tti in tts_inner.iter().skip(1) {
let ttis = self.unwrap_list(tti, "persistent_arrays");
let id = self.ident_string(&ttis[0]);
let _size = self.usize(&ttis[1]);
let term = self.term(&ttis[2]);
persistent_arrays.push((id, term));
}
}
}
let mut iter = tts.iter().skip(2);
if skip_one {
iter.next();
}
let outputs = iter.map(|tti| self.term(tti)).collect();
self.unbind(input_names);
Computation {
outputs,
metadata,
precomputes,
persistent_arrays,
}
}
@@ -816,6 +853,14 @@ pub fn serialize_computation(c: &Computation) -> String {
let mut out = String::new();
writeln!(&mut out, "(computation \n{}", c.metadata).unwrap();
writeln!(&mut out, "{}", serialize_precompute(&c.precomputes)).unwrap();
if !c.persistent_arrays.is_empty() {
writeln!(&mut out, "(persistent_arrays").unwrap();
for (name, term) in &c.persistent_arrays {
let size = check(term).as_array().2;
writeln!(&mut out, " ({name} {size} {})", serialize_term(term)).unwrap();
}
writeln!(&mut out, "\n)").unwrap();
}
for o in &c.outputs {
writeln!(&mut out, "\n {}", serialize_term(o)).unwrap();
}
@@ -907,6 +952,43 @@ mod test {
assert_eq!(t, t2);
}
#[test]
fn arr_cstore_roundtrip() {
let t = parse_term(
b"
(declare (
(a bool)
(b bool)
(c bool)
(A (array bool bool 1))
)
(let (
(B (cstore A a b c))
) (xor (select B a)
(select (#a (bv 4) false 4 ((#b0000 true))) #b0000))))",
);
let s = serialize_term(&t);
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
#[test]
fn arr_op_roundtrip() {
let t = parse_term(
b"
(declare (
(a bool)
(b bool)
(A (array bool bool 1))
)
(= A ((array bool bool) a))
)",
);
let s = serialize_term(&t);
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
#[test]
fn tup_roundtrip() {
let t = parse_term(
@@ -1008,6 +1090,38 @@ mod test {
assert_eq!(c, c2);
}
#[test]
fn computation_roundtrip_persistent_arrays() {
let c = parse_computation(
b"
(computation
(metadata
(parties P V)
(inputs
(a bool (party 0) (random) (round 1))
(b bool)
(A (tuple bool bool))
(x bool (party 0) (committed))
)
(commitments)
)
(precompute
((c bool) (d bool))
((a bool))
(tuple (not (and c d)))
)
(persistent_arrays (AA 2 (#a (bv 4) false 4 ((#b0000 true)))))
(let (
(B ((update 1) A b))
) (xor ((field 1) B)
((field 0) (#t false false #b0000 true))))
)",
);
let s = serialize_computation(&c);
let c2 = parse_computation(s.as_bytes());
assert_eq!(c, c2);
}
#[test]
fn challenge_roundtrip() {
let t = parse_term(b"(declare ((a bool) (b bool)) ((challenge hithere 17) a b))");
@@ -1015,4 +1129,73 @@ mod test {
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
#[test]
fn persistent_ram_split_roundtrip() {
let t = parse_term(
b"
(declare (
(entries (array (mod 17) (tuple (mod 17) (mod 17)) 5))
(indices (array (mod 17) (mod 17) 3))
)
(persistent_ram_split entries indices))",
);
let s = serialize_term(&t);
println!("{s}");
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
#[test]
fn list_value_equiv_to_array() {
let t_array = parse_term(b"(declare () (#a (bv 4) #x0 3 ((#x0 #x0) (#x1 #x1) (#x2 #x4))))");
let t_list = parse_term(b"(declare () (#l (bv 4) (#x0 #x1 #x4)))");
assert_eq!(t_array, t_list);
let s = serialize_term(&t_array);
let t_roundtripped = parse_term(s.as_bytes());
assert_eq!(t_array, t_roundtripped);
}
#[test]
fn pf2bool_trusted_rountrip() {
let t = parse_term(b"(declare ((a bool)) (pf2bool_trusted (ite a #f1m11 #f0m11)))");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn op_sort_rountrip() {
let t = parse_term(b"(declare () (sort (#l (mod 3) ((#t true true) (#t true false)))))");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn fill_roundtrip() {
let t = parse_term(b"(declare () ((fill (bv 4) 3) #x00))");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn pf_fits_in_bits_rountrip() {
let t = parse_term(b"(declare ((a bool)) ((pf_fits_in_bits 4) (ite a #f1m11 #f0m11)))");
let t2 = parse_term(serialize_term(&t).as_bytes());
assert_eq!(t, t2);
}
#[test]
fn uniq_deri_gcd_roundtrip() {
let t = parse_term(
b"
(declare (
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
)
(uniq_deri_gcd pairs))",
);
let s = serialize_term(&t);
println!("{s}");
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
}


@@ -60,14 +60,20 @@ fn check_dependencies(t: &Term) -> Vec<Term> {
Op::IntBinPred(_) => Vec::new(),
Op::UbvToPf(_) => Vec::new(),
Op::PfChallenge(_, _) => Vec::new(),
Op::PfFitsInBits(_) => Vec::new(),
Op::Select => vec![t.cs()[0].clone()],
Op::Store => vec![t.cs()[0].clone()],
Op::Array(..) => Vec::new(),
Op::CStore => vec![t.cs()[0].clone()],
Op::Fill(..) => vec![t.cs()[0].clone()],
Op::Tuple => t.cs().to_vec(),
Op::Field(_) => vec![t.cs()[0].clone()],
Op::Update(_i) => vec![t.cs()[0].clone()],
Op::Map(_) => t.cs().to_vec(),
Op::Call(_, _, _) => Vec::new(),
Op::Rot(_) => vec![t.cs()[0].clone()],
Op::PfToBoolTrusted => Vec::new(),
Op::ExtOp(o) => o.check_dependencies(t),
}
}
@@ -130,8 +136,20 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
Op::IntBinPred(_) => Ok(Sort::Bool),
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
Op::PfChallenge(_, m) => Ok(Sort::Field(m.clone())),
Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v)| v.clone()),
Op::PfFitsInBits(_) => Ok(Sort::Bool),
Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v, _)| v.clone()),
Op::Store => Ok(get_ty(&t.cs()[0]).clone()),
Op::Array(k, v) => Ok(Sort::Array(
Box::new(k.clone()),
Box::new(v.clone()),
t.cs().len(),
)),
Op::CStore => Ok(get_ty(&t.cs()[0]).clone()),
Op::Fill(key_sort, size) => Ok(Sort::Array(
Box::new(key_sort.clone()),
Box::new(get_ty(&t.cs()[0]).clone()),
*size,
)),
Op::Tuple => Ok(Sort::Tuple(t.cs().iter().map(get_ty).cloned().collect())),
Op::Field(i) => {
let sort = get_ty(&t.cs()[0]);
@@ -165,7 +183,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
for i in 0..arg_cnt {
match array_or(get_ty(&t.cs()[i]), "map inputs") {
Ok((_, v)) => {
Ok((_, v, _)) => {
arg_sorts_to_inner_op.push(v);
}
Err(e) => {
@@ -183,6 +201,11 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
}
Op::Call(_, _, ret) => Ok(ret.clone()),
Op::Rot(_) => Ok(get_ty(&t.cs()[0]).clone()),
Op::PfToBoolTrusted => Ok(Sort::Bool),
Op::ExtOp(o) => {
let args_sorts: Vec<&Sort> = t.cs().iter().map(|c| get_ty(c)).collect();
o.check(&args_sorts)
}
o => Err(TypeErrorReason::Custom(format!("other operator: {o}"))),
}
}
@@ -335,6 +358,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
}
(Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())),
(Op::PfChallenge(_, m), _) => Ok(Sort::Field(m.clone())),
(Op::PfFitsInBits(_), &[a]) => pf_or(a, "pf fits in bits").map(|_| Sort::Bool),
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
(Op::IntNaryOp(_), a) => {
let ctx = "int nary op";
@@ -349,6 +373,21 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
(Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store")
.and_then(|_| eq_or(v, b, "store"))
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
(Op::CStore, &[Sort::Array(k, v, n), a, b, c]) => eq_or(k, a, "cstore")
.and_then(|_| eq_or(v, b, "cstore"))
.and_then(|_| bool_or(c, "cstore"))
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
(Op::Fill(key_sort, size), &[v]) => Ok(Sort::Array(
Box::new(key_sort.clone()),
Box::new(v.clone()),
*size,
)),
(Op::Array(k, v), a) => {
let ctx = "array op";
a.iter()
.try_fold((), |(), ai| eq_or(v, ai, ctx))
.map(|_| Sort::Array(Box::new(k.clone()), Box::new(v.clone()), a.len()))
}
(Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())),
(Op::Field(i), &[a]) => tuple_or(a, "tuple field access").and_then(|t| {
if i < &t.len() {
@@ -419,6 +458,8 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
(Op::Rot(_), &[Sort::Array(k, v, n)]) => bv_or(k, "rot key")
.and_then(|_| bv_or(v, "rot val"))
.map(|_| Sort::Array(k.clone(), v.clone(), *n)),
(Op::PfToBoolTrusted, &[k]) => pf_or(k, "pf to bool argument").map(|_| Sort::Bool),
(Op::ExtOp(o), _) => o.check(a),
(_, _) => Err(TypeErrorReason::Custom("other".to_string())),
}
}
@@ -525,9 +566,12 @@ fn int_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReaso
}
}
fn array_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<(&'a Sort, &'a Sort), TypeErrorReason> {
if let Sort::Array(k, v, _) = a {
Ok((k, v))
pub(super) fn array_or<'a>(
a: &'a Sort,
ctx: &'static str,
) -> Result<(&'a Sort, &'a Sort, usize), TypeErrorReason> {
if let Sort::Array(k, v, size) = a {
Ok((k, v, *size))
} else {
Err(TypeErrorReason::ExpectedArray(a.clone(), ctx))
}
@@ -559,21 +603,21 @@ fn fp_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
}
}
fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
pub(super) fn pf_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
match a {
Sort::Field(_) => Ok(a),
_ => Err(TypeErrorReason::ExpectedPf(a.clone(), ctx)),
}
}
fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], TypeErrorReason> {
pub(super) fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], TypeErrorReason> {
match a {
Sort::Tuple(a) => Ok(a),
_ => Err(TypeErrorReason::ExpectedTuple(ctx)),
}
}
fn eq_or(a: &Sort, b: &Sort, ctx: &'static str) -> Result<(), TypeErrorReason> {
pub(super) fn eq_or(a: &Sort, b: &Sort, ctx: &'static str) -> Result<(), TypeErrorReason> {
if a == b {
Ok(())
} else {


@@ -224,12 +224,11 @@ mod tests {
);
let costs = CostModel::from_opa_cost_file(&p);
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![BV_MUL;
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
leaf_term(Op::Var("b".to_owned(), Sort::BitVector(32)))
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let _assignment = build_ilp(&cs, &costs);
}
@@ -242,7 +241,6 @@ mod tests {
);
let costs = CostModel::from_opa_cost_file(&p);
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![Op::Eq;
term![BV_MUL;
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
@@ -268,7 +266,7 @@ mod tests {
],
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32)))
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let assignment = build_ilp(&cs, &costs);
// Big enough to do the math with arith
@@ -288,7 +286,6 @@ mod tests {
);
let costs = CostModel::from_opa_cost_file(&p);
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![Op::Eq;
term![BV_MUL;
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))),
@@ -305,7 +302,7 @@ mod tests {
],
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32)))
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let assignment = build_ilp(&cs, &costs);
// All yao


@@ -708,8 +708,7 @@ mod test {
leaf_term(Op::Var("a".to_owned(), Sort::Bool)),
leaf_term(Op::Var("b".to_owned(), Sort::Bool))],
],
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let r = ilp.solve(default_solver).unwrap().1;
@@ -850,8 +849,7 @@ mod test {
fn trivial_bv_opt() {
let cs = Computation {
outputs: vec![leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)))],
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let (max, vars) = ilp.solve(default_solver).unwrap();
@@ -866,8 +864,7 @@ mod test {
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
bv_lit(1,4)
]],
metadata: ComputationMetadata::default(),
precomputes: Default::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let (max, vars) = ilp.solve(default_solver).unwrap();
@@ -877,12 +874,11 @@ mod test {
#[test]
fn mul2_bv_opt() {
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![BV_MUL;
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
bv_lit(2,4)
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let (max, _vars) = ilp.solve(default_solver).unwrap();
@@ -891,7 +887,6 @@ mod test {
#[test]
fn mul2_plus_bv_opt() {
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![BV_ADD;
term![BV_MUL;
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))),
@@ -900,7 +895,7 @@ mod test {
leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)))
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let (max, vars) = ilp.solve(default_solver).unwrap();
@@ -912,12 +907,11 @@ mod test {
let a = leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)));
let c = leaf_term(Op::Var("c".to_owned(), Sort::Bool));
let cs = Computation {
precomputes: Default::default(),
outputs: vec![term![BV_ADD;
term![ITE; c, bv_lit(2,4), bv_lit(1,4)],
term![BV_MUL; a, bv_lit(2,4)]
]],
metadata: ComputationMetadata::default(),
..Default::default()
};
let ilp = to_ilp(cs);
let (max, vars) = ilp.solve(default_solver).unwrap();


@@ -5,7 +5,7 @@ from os import path
feature_path = ".features.txt"
mode_path = ".mode.txt"
cargo_features = {"aby", "c", "lp", "r1cs", "kahip", "kahypar",
"smt", "zok", "datalog", "bellman", "spartan"}
"smt", "zok", "datalog", "bellman", "spartan", "poly"}
# Environment variables
ABY_SOURCE = "./../ABY"