Better serde for Computation, PreComp, R1cs (#120)

The previous implementation for all 3 structures would separately serialize each term in a precomputation. This was very bad, because precomputations can be quite large, and the different terms typically share large numbers of sub-terms. This can cause quadratic blowup in the serialized size.

Change list:

    add input/output type metadata to PreComp (this makes the type of variables in a PreComp clear)
    add textual serialization and parsing for PreComp (and include this in Computation)
    add serde implementations for Computation and PreComp based on the text formats.
    modify the R1cs serde derive so that the Vec<Term> is tuplized before serde

This reduces our prover-key size for verifyEddsa.zok from wildly large (>450GB) to 171MB.

Collin and Anna first observed this bug. Collin's PR #118 works around it (among other improvements).
This commit is contained in:
Alex Ozdemir
2022-11-08 15:47:51 -08:00
committed by GitHub
parent 25773910e5
commit 051809a554
9 changed files with 294 additions and 28 deletions

View File

@@ -0,0 +1,12 @@
// Multiplies two private 3x3 field matrices: AB[i][j] = sum over k of A[i][k] * B[k][j].
def main(private field[3][3] A, private field[3][3] B) -> field[3][3]:
// Accumulator initialized to the zero matrix.
field [3][3] AB = [[0; 3]; 3]
for field i in 0..3 do
for field j in 0..3 do
for field k in 0..3 do
AB[i][j] = AB[i][j] + A[i][k] * B[k][j]
endfor
endfor
endfor
return AB

View File

@@ -0,0 +1,12 @@
// Multiplies two private 4x4 field matrices: AB[i][j] = sum over k of A[i][k] * B[k][j].
def main(private field[4][4] A, private field[4][4] B) -> field[4][4]:
// Accumulator initialized to the zero matrix.
field [4][4] AB = [[0; 4]; 4]
for field i in 0..4 do
for field j in 0..4 do
for field k in 0..4 do
AB[i][j] = AB[i][j] + A[i][k] * B[k][j]
endfor
endfor
endfor
return AB

View File

@@ -0,0 +1,12 @@
// Multiplies two private 5x5 field matrices: AB[i][j] = sum over k of A[i][k] * B[k][j].
def main(private field[5][5] A, private field[5][5] B) -> field[5][5]:
// Accumulator initialized to the zero matrix.
field [5][5] AB = [[0; 5]; 5]
for field i in 0..5 do
for field j in 0..5 do
for field k in 0..5 do
AB[i][j] = AB[i][j] + A[i][k] * B[k][j]
endfor
endfor
endfor
return AB

View File

@@ -116,6 +116,7 @@ mod test {
b"
(computation
(metadata () ((a (bv 4)) (b (bv 4)) (c (bv 4))) ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #x0 4 ()))
@@ -138,6 +139,7 @@ mod test {
b"
(computation
(metadata () ((a (mod 5)) (b (mod 5)) (c (mod 5))) ())
(precompute () () (#t ))
(let
(
(c_array (#a (mod 5) #f1m5 4 ()))
@@ -160,6 +162,7 @@ mod test {
b"
(computation
(metadata () ((a (bv 4)) (b (bv 4)) (c (bv 4))) ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #x0 4 ()))
@@ -182,6 +185,7 @@ mod test {
b"
(computation
(metadata () ((a (mod 5)) (b (mod 5)) (c (mod 5))) ())
(precompute () () (#t ))
(let
(
(c_array (#a (mod 5) #f1m5 4 ()))

View File

@@ -331,6 +331,7 @@ mod test {
b"
(computation
(metadata () () ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #x0 4 ()))
@@ -355,6 +356,7 @@ mod test {
b"
(computation
(metadata () () ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #b000 4 ()))
@@ -392,6 +394,7 @@ mod test {
b"
(computation
(metadata () ((a bool)) ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #b000 4 ()))
@@ -430,6 +433,7 @@ mod test {
b"
(computation
(metadata () ((a bool)) ())
(precompute () () (#t ))
(let
(
(c_array (#a (bv 4) #b000 4 ()))
@@ -469,6 +473,7 @@ mod test {
b"
(computation
(metadata () ((a bool)) ())
(precompute () () (#t ))
(let
(
; connected component 0: simple store chain

View File

@@ -1982,7 +1982,7 @@ impl Display for ComputationMetadata {
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
/// An IR computation.
pub struct Computation {
/// The outputs of the computation.
@@ -2095,6 +2095,42 @@ impl Computation {
}
}
impl Serialize for Computation {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let bytes = text::serialize_computation(self);
serializer.serialize_str(&bytes)
}
}
struct ComputationDeserVisitor;
impl<'de> Visitor<'de> for ComputationDeserVisitor {
type Value = Computation;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a string (that textually defines a term)")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: std::error::Error,
{
Ok(text::parse_computation(v.as_bytes()))
}
}
impl<'de> Deserialize<'de> for Computation {
    // Deserialize from the textual IR string produced by `Serialize`,
    // delegating the actual parsing to `ComputationDeserVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Computation, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(ComputationDeserVisitor)
    }
}
#[derive(Clone, Debug, Default)]
/// A map of IR computations.
pub struct Computations {

View File

@@ -10,11 +10,12 @@ use crate::ir::term::*;
/// A "precomputation".
///
/// Expresses a computation to be run in advance by a single party.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PreComp {
/// A map from output names to the terms that compute them.
outputs: FxHashMap<String, Term>,
sequence: Vec<String>,
sequence: Vec<(String, Sort)>,
inputs: FxHashSet<(String, Sort)>,
}
impl PreComp {
@@ -26,9 +27,22 @@ impl PreComp {
pub fn outputs(&self) -> &FxHashMap<String, Term> {
&self.outputs
}
/// Immutable access to the output sequence: (name, sort) pairs in the order
/// the outputs were added.
pub fn sequence(&self) -> &[(String, Sort)] {
    &self.sequence
}
/// Immutable access to the inputs: the set of (name, sort) pairs.
pub fn inputs(&self) -> &FxHashSet<(String, Sort)> {
    &self.inputs
}
/// Add an input: a variable `name` of sort `sort` that this precomputation reads.
pub fn add_input(&mut self, name: String, sort: Sort) {
    self.inputs.insert((name, sort));
}
/// Add a new output variable to the precomputation. `value` is the term that computes its value.
///
/// # Panics
///
/// Panics if an output with the same `name` was already added.
pub fn add_output(&mut self, name: String, value: Term) {
    // Record the output's sort alongside its name; the sequence preserves
    // insertion order for evaluation and serialization.
    let sort = check(&value);
    self.sequence.push((name.clone(), sort));
    let old = self.outputs.insert(name, value);
    assert!(old.is_none());
}
@@ -49,15 +63,16 @@ impl PreComp {
}
}
seq.retain(|s| {
let o = os.get(s).unwrap();
seq.retain(|(name, _sort)| {
let o = os.get(name).unwrap();
let drop = to_remove.contains(o);
if drop {
os.remove(s);
os.remove(name);
}
!drop
});
}
/// Evaluate the precomputation.
///
/// Requires an input environment that binds all inputs for the underlying computation.
@@ -65,40 +80,86 @@ impl PreComp {
let mut value_cache: TermMap<Value> = TermMap::new();
let mut env = env.clone();
// iterate over all terms, evaluating them using the cache.
for o_name in &self.sequence {
for (o_name, _o_sort) in &self.sequence {
let o = self.outputs.get(o_name).unwrap();
eval_cached(o, &env, &mut value_cache);
env.insert(o_name.clone(), value_cache.get(o).unwrap().clone());
}
env
}
/// Compute the inputs for this precomputation
pub fn inputs_to_terms(&self) -> FxHashMap<String, Term> {
PostOrderIter::new(term(Op::Tuple, self.outputs.values().cloned().collect()))
.filter_map(|t| match &t.op {
Op::Var(name, _) => Some((name.clone(), t.clone())),
_ => None,
})
.collect()
/// Get all outputs, in sequence order, gathered into a single tuple term.
pub fn tuple(&self) -> Term {
    term(
        Op::Tuple,
        self.sequence
            .iter()
            // Look each sequenced name up in the output map.
            .map(|o| self.outputs.get(&o.0).unwrap())
            .cloned()
            .collect(),
    )
}
/// Compute the inputs for this precomputation
pub fn inputs(&self) -> FxHashSet<String> {
self.inputs_to_terms().into_keys().collect()
/// Recompute the inputs.
///
/// Rebuilds `self.inputs` as the set of free variables (name, sort) appearing
/// anywhere in the output terms.
fn recompute_inputs(&mut self) {
    self.inputs = PostOrderIter::new(self.tuple())
        .filter_map(|t| match &t.op {
            Op::Var(name, sort) => Some((name.clone(), sort.clone())),
            _ => None,
        })
        .collect();
}
/// Bind the outputs of `self` to the inputs of `other`.
pub fn sequential_compose(mut self, other: &PreComp) -> PreComp {
for o_name in &other.sequence {
for (o_name, o_sort) in &other.sequence {
let o = other.outputs.get(o_name).unwrap().clone();
assert!(!self.outputs.contains_key(o_name));
self.outputs.insert(o_name.clone(), o);
self.sequence.push(o_name.clone());
self.sequence.push((o_name.clone(), o_sort.clone()));
}
self.recompute_inputs();
self
}
}
impl Serialize for PreComp {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let bytes = text::serialize_precompute(self);
serializer.serialize_str(&bytes)
}
}
struct PreCompDeserVisitor;
impl<'de> Visitor<'de> for PreCompDeserVisitor {
type Value = PreComp;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a string (that textually defines a term)")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: std::error::Error,
{
Ok(text::parse_precompute(v.as_bytes()))
}
}
impl<'de> Deserialize<'de> for PreComp {
    // Deserialize from the textual IR string produced by `Serialize`,
    // delegating the actual parsing to `PreCompDeserVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<PreComp, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(PreCompDeserVisitor)
    }
}
#[cfg(test)]
mod test {
use super::*;
@@ -122,12 +183,18 @@ mod test {
let mut p_with_a = p.clone();
p_with_a.restrict_to_inputs(vec!["a".into()].into_iter().collect());
assert_eq!(p_with_a.sequence, vec!["out1"]);
assert_eq!(
p_with_a.sequence.iter().map(|(n, _)| n).collect::<Vec<_>>(),
vec!["out1"]
);
assert_eq!(p_with_a.outputs.len(), 1);
let mut p_with_b = p.clone();
p_with_b.restrict_to_inputs(vec!["b".into()].into_iter().collect());
assert_eq!(p_with_b.sequence, vec!["out2"]);
assert_eq!(
p_with_b.sequence.iter().map(|(n, _)| n).collect::<Vec<_>>(),
vec!["out2"]
);
assert_eq!(p_with_b.outputs.len(), 1);
let mut p_both = p.clone();

View File

@@ -6,6 +6,8 @@
//!
//! Includes a parser ([parse_computation]) and serializer ([serialize_computation]) for [Computation]s.
//!
//! Includes a parser ([parse_precompute]) and serializer ([serialize_precompute]) for [precomp::PreComp]s.
//!
//!
//! * IR Textual format
//! * It's s-expressions.
@@ -13,11 +15,15 @@
//! * `I`: integer (arbitrary-precision)
//! * `X`: identifier
//! * regex: `[^()0-9#; \t\n\f][^(); \t\n\f#]*`
//! * Computation `C`: `(computation M T)`
//! * Computation `C`: `(computation M P T)`
//! * Metadata `M`: `(metadata PARTIES INPUTS VISIBILITIES)`
//! * PARTIES is `(X1 .. Xn)`
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
//! * VISIBILITIES is `((X_INPUT_1 X_PARTY_1) .. (X_INPUT_n X_PARTY_n))`
//! * Precompute `P`: `(precompute INPUTS OUTPUTS TUPLE_TERM)`
//! * INPUTS is `((X1 S1) .. (Xn Sn))`
//! * OUTPUTS is `((X1 S1) .. (Xn Sn))`
//! * TUPLE_TERM is a tuple of the same arity as the output
//! * Sort `S`:
//! * `bool`
//! * `f32`
@@ -571,19 +577,65 @@ impl<'src> IrInterp<'src> {
tt
)
}
assert!(tts.len() > 2);
assert!(tts.len() > 3);
let (metadata, input_names) = self.metadata(&tts[1]);
let outputs = tts[2..].iter().map(|tti| self.term(tti)).collect();
let precomputes = self.precompute(&tts[2]);
let outputs = tts[3..].iter().map(|tti| self.term(tti)).collect();
self.unbind(input_names);
Computation {
outputs,
metadata,
precomputes: Default::default(),
precomputes,
}
} else {
panic!("Expected computation, found {}", tt)
}
}
/// Parse a declaration list `((X1 S1) .. (Xn Sn))` into (name, sort) pairs.
///
/// Delegates binding to `decl_list`, then recovers each declared name's sort
/// by checking the binding it created.
fn var_decl_list(&mut self, tt: &TokTree<'src>) -> Vec<(String, Sort)> {
    let input_names = self.decl_list(tt);
    input_names
        .iter()
        .map(|i| (from_utf8(i).unwrap().into(), check(self.get_binding(i))))
        .collect()
}
/// Parse a pre-computation.
///
/// Expects `(precompute INPUTS OUTPUTS TUPLE_TERM)`, where INPUTS and OUTPUTS
/// are declaration lists `((X1 S1) .. (Xn Sn))` and TUPLE_TERM is a tuple term
/// with one child per declared output.
///
/// # Panics
///
/// Panics if the token tree is not a well-formed `precompute` form, or if an
/// output's declared sort disagrees with the sort of its tuple entry.
pub fn precompute(&mut self, tt: &TokTree<'src>) -> precomp::PreComp {
    let mut p = precomp::PreComp::new();
    if let List(tts) = tt {
        if tts.is_empty() || tts[0] != Leaf(Token::Ident, b"precompute") {
            panic!(
                "Expected precompute, but list did not start with 'precompute': {}",
                tt
            )
        }
        // 4 children: the keyword, INPUTS, OUTPUTS, TUPLE_TERM.
        assert!(
            tts.len() == 4,
            "precompute should have 4 children, but has {}",
            tts.len()
        );
        let inputs = self.var_decl_list(&tts[1]);
        let outputs = self.var_decl_list(&tts[2]);
        let tuple_term = self.term(&tts[3]);
        assert!(
            outputs.len() == tuple_term.cs.len(),
            "output list has {} items, tuple has {}",
            outputs.len(),
            tuple_term.cs.len()
        );
        for (n, s) in inputs {
            p.add_input(n, s);
        }
        // Each output's declared sort must match the sort of its tuple entry.
        for ((n, s), t) in outputs.into_iter().zip(&tuple_term.cs) {
            assert_eq!(s, check(t));
            p.add_output(n, t.clone());
        }
        p
    } else {
        // This error message previously said "computation", which was misleading.
        panic!("Expected precompute, found {}", tt)
    }
}
}
/// Parse a term.
@@ -684,6 +736,7 @@ pub fn parse_computation(src: &[u8]) -> Computation {
pub fn serialize_computation(c: &Computation) -> String {
let mut out = String::new();
writeln!(&mut out, "(computation \n{}", c.metadata).unwrap();
writeln!(&mut out, "{}", serialize_precompute(&c.precomputes)).unwrap();
for o in &c.outputs {
writeln!(&mut out, "\n {}", serialize_term(o)).unwrap();
}
@@ -691,6 +744,30 @@ pub fn serialize_computation(c: &Computation) -> String {
out
}
/// Serialize a pre-computation.
///
/// Inputs are emitted in name-sorted order so the output is deterministic:
/// `FxHashSet` iteration order depends on insertion history, which would
/// otherwise make the serialized text non-canonical. Outputs follow the
/// precomputation's own sequence order, and the values are written as one
/// tuple term.
pub fn serialize_precompute(p: &precomp::PreComp) -> String {
    let mut out = String::new();
    writeln!(&mut out, "(precompute (").unwrap();
    // Sort by name (names are unique, so this is a total order).
    let mut inputs: Vec<_> = p.inputs().iter().collect();
    inputs.sort_by(|a, b| a.0.cmp(&b.0));
    for (name, sort) in inputs {
        writeln!(&mut out, " ({} {})", name, sort).unwrap();
    }
    writeln!(&mut out, ")(").unwrap();
    for (name, sort) in p.sequence() {
        writeln!(&mut out, " ({} {})", name, sort).unwrap();
    }
    writeln!(&mut out, ")").unwrap();
    writeln!(&mut out, "\n {}", serialize_term(&p.tuple())).unwrap();
    writeln!(&mut out, "\n)").unwrap();
    out
}
/// Parse a pre-computation.
///
/// `src` is the textual IR form `(precompute INPUTS OUTPUTS TUPLE_TERM)`;
/// tokenizes it and hands the tree to [IrInterp::precompute].
pub fn parse_precompute(src: &[u8]) -> precomp::PreComp {
    let tree = parse_tok_tree(src);
    let mut i = IrInterp::new();
    i.precompute(&tree)
}
#[cfg(test)]
mod test {
use super::*;
@@ -810,6 +887,11 @@ mod test {
((a bool) (b bool) (A (tuple bool bool)))
((a P))
)
(precompute
((c bool) (d bool))
((a bool))
(tuple (not (and c d)))
)
(let (
(B ((update 1) A b))
) (xor ((field 1) B)
@@ -826,4 +908,19 @@ mod test {
let c2 = parse_computation(s.as_bytes());
assert_eq!(c, c2);
}
#[test]
fn precompute_roundtrip() {
    // Parse a small precompute, serialize it, and re-parse: the result must
    // equal the original.
    let original = parse_precompute(
        b"
(precompute
((c bool) (d bool))
((a bool) (b bool))
(tuple (not (and c d)) (not a))
)",
    );
    let reserialized = serialize_precompute(&original);
    let reparsed = parse_precompute(reserialized.as_bytes());
    assert_eq!(original, reparsed);
}
}

View File

@@ -5,7 +5,7 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use log::debug;
use paste::paste;
use rug::Integer;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::hash_map::Entry;
use std::fmt::Display;
use std::hash::Hash;
@@ -28,9 +28,30 @@ pub struct R1cs<S: Hash + Eq> {
next_idx: usize,
public_idxs: HashSet<usize>,
constraints: Vec<(Lc, Lc, Lc)>,
#[serde(
deserialize_with = "deserialize_term_vec",
serialize_with = "serialize_term_vec"
)]
terms: Vec<Term>,
}
// serde requires a specific signature here (`&Vec<Term>`, not `&[Term]`).
#[allow(clippy::ptr_arg)]
fn serialize_term_vec<S>(ts: &Vec<Term>, s: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // Gather the terms into one tuple term before serializing, so sub-terms
    // shared among the elements are written once rather than per element
    // (serializing each term independently can blow up quadratically).
    term(Op::Tuple, ts.clone()).serialize(s)
}
// Inverse of `serialize_term_vec`: the vector was encoded as a single tuple
// term, so recover the elements as that tuple's children.
fn deserialize_term_vec<'de, D>(d: D) -> Result<Vec<Term>, D::Error>
where
    D: Deserializer<'de>,
{
    let tuple: Term = Deserialize::deserialize(d)?;
    Ok(tuple.cs.clone())
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A linear combination
pub struct Lc {