optimize bellman ZK prover backend (#182)

three changes:
* faster integer->ff conversion
* parallel construction of bellman LCs
* parallel R1CS checking (on the CirC side)

The first change is the most important, by far. Our previous integer->ff
conversion was very slow.
This commit is contained in:
Alex Ozdemir
2024-01-03 16:39:05 -08:00
committed by GitHub
parent 6133414b44
commit cafb02b848
7 changed files with 37 additions and 28 deletions

1
Cargo.lock generated
View File

@@ -347,6 +347,7 @@ dependencies = [
"quickcheck_macros",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rayon",
"rsmt2",
"rug",
"rug-polynomial",

View File

@@ -25,6 +25,7 @@ log = "0.4"
thiserror = "1.0"
bellman = { git = "https://github.com/alex-ozdemir/bellman.git", branch = "mirage", optional = true }
rug-polynomial = { version = "0.2.5", optional = true }
rayon = { version = "1", optional = true }
ff = { version = "0.12", optional = true }
fxhash = "0.2"
good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false, optional = true }
@@ -69,10 +70,10 @@ datalog = ["pest", "pest-ast", "pest_derive", "from-pest", "lazy_static"]
smt = ["rsmt2", "ieee754"]
lp = ["good_lp", "lp-solvers"]
aby = ["lp"]
r1cs = ["bincode"]
r1cs = ["bincode", "rayon"]
poly = ["rug-polynomial"]
spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"]
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys", "byteorder"]
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys", "byteorder", "rayon"]
[[example]]
name = "circ"

View File

@@ -1225,7 +1225,7 @@ impl CGen {
}
/// Returns whether this was a builtin, and thus has been handled.
fn maybe_handle_builtins(&mut self, name: &String, args: &Vec<CTerm>) -> Option<CTerm> {
fn maybe_handle_builtins(&mut self, name: &String, args: &[CTerm]) -> Option<CTerm> {
if self.sv_functions && (name == "__VERIFIER_assert" || name == "__VERIFIER_assume") {
assert!(args.len() == 1);
let bool_arg = cast_to_bool(args[0].clone());
@@ -1241,7 +1241,7 @@ impl CGen {
}
}
fn fn_call(&mut self, name: &String, arg_sorts: &Vec<Sort>, rets: &Sort) {
fn fn_call(&mut self, name: &String, arg_sorts: &[Sort], rets: &Sort) {
debug!("Call: {}", name);
// Get function types

View File

@@ -71,7 +71,7 @@ impl CTermData {
pub fn term(&self, ctx: &CirCtx) -> Term {
let ts = self.terms(ctx);
assert!(ts.len() == 1);
ts.get(0).unwrap().clone()
ts.first().unwrap().clone()
}
pub fn simple_term(&self) -> Term {

View File

@@ -166,9 +166,9 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap {
ilp.maximize(
-conv_vars
.values()
.map(|(a, b)| (a, b))
.chain(term_vars.values().map(|(a, b, _)| (a, b)))
.fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost),
.cloned()
.chain(term_vars.values().map(|(a, b, _)| (*a, *b)))
.fold(0.0.into(), |acc: Expression, (v, cost)| acc + v * cost),
);
let (_opt, solution) = ilp.default_solve().unwrap();

View File

@@ -2,10 +2,10 @@
use ::bellman::{groth16, Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable};
use ff::{Field, PrimeField, PrimeFieldBits};
use fxhash::FxHashMap;
use gmp_mpfr_sys::gmp::limb_t;
use group::WnafGroup;
use log::debug;
use pairing::{Engine, MultiMillerLoop};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
@@ -23,15 +23,12 @@ use crate::ir::term::Value;
/// Convert a (rug) integer to a prime field element.
pub(super) fn int_to_ff<F: PrimeField>(i: Integer) -> F {
let mut accumulator = F::from(0);
let limb_bits = (std::mem::size_of::<limb_t>() as u64) << 3;
let limb_base = F::from(2).pow_vartime([limb_bits]);
// as_ref yeilds a least-significant-first array.
for digit in i.as_ref().iter().rev() {
accumulator *= limb_base;
accumulator += F::from(*digit);
}
accumulator
assert!(i >= 0);
let digits: Vec<u8> = i.to_digits(rug::integer::Order::LsfLe);
let mut repr = F::Repr::default();
assert!(digits.len() <= repr.as_ref().len());
repr.as_mut()[..digits.len()].copy_from_slice(&digits);
F::from_repr_vartime(repr).unwrap()
}
/// Convert one our our linear combinations to a bellman linear combination.
@@ -130,13 +127,22 @@ impl<'a, F: PrimeField> Circuit<F> for SynthInput<'a> {
};
vars.insert(var, v);
}
for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() {
cs.enforce(
|| format!("con{i}"),
|z| lc_to_bellman::<F, CS>(&vars, a, z),
|z| lc_to_bellman::<F, CS>(&vars, b, z),
|z| lc_to_bellman::<F, CS>(&vars, c, z),
);
let bellman_lcs: Vec<(_, _, _)> = self
.0
.r1cs
.constraints
.par_iter()
.map(|(a, b, c)| {
(
lc_to_bellman::<F, CS>(&vars, a, LinearCombination::zero()),
lc_to_bellman::<F, CS>(&vars, b, LinearCombination::zero()),
lc_to_bellman::<F, CS>(&vars, c, LinearCombination::zero()),
)
})
.collect();
for (i, (a, b, c)) in bellman_lcs.into_iter().enumerate() {
cs.enforce(|| format!("con{i}"), |_| a, |_| b, |_| c);
}
debug!(
"done with synth: {} vars {} cs",

View File

@@ -4,6 +4,7 @@ use circ_fields::{FieldT, FieldV};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use log::{debug, trace};
use paste::paste;
use rayon::prelude::*;
use rug::Integer;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
@@ -456,9 +457,9 @@ impl R1csFinal {
/// Check all assertions
fn check_all(&self, values: &HashMap<Var, FieldV>) {
for (a, b, c) in &self.constraints {
self.check(a, b, c, values)
}
self.constraints
.par_iter()
.for_each(|(a, b, c)| self.check(a, b, c, values));
}
}