This commit is contained in:
Edward Chen
2023-02-27 21:58:21 -05:00
40 changed files with 2574 additions and 846 deletions

View File

@@ -1,43 +1,47 @@
name: Build & Test
on:
push:
branches: [ master, ci ]
pull_request:
branches: [ master, ci ]
push:
branches: [master, ci]
pull_request:
branches: [master, ci]
env:
CARGO_TERM_COLOR: always
ABY_SOURCE: "./../ABY"
CARGO_TERM_COLOR: always
ABY_SOURCE: "./../ABY"
KAHIP_SOURCE: "./../KaHIP"
KAHYPAR_SOURCE: "./../kahypar"
jobs:
build:
build:
runs-on: ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install dependencies
if: runner.os == 'Linux'
run: sudo apt-get update; sudo apt-get install zsh cvc4 libboost-all-dev libssl-dev coinor-cbc coinor-libcbc-dev
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
- uses: Swatinem/rust-cache@v1
- name: Set all features on
run: python3 driver.py --all_features
- name: Install third_party libraries
run: python3 driver.py --install
- name: Cache third_party build
uses: actions/cache@v2
with:
path: ${ABY_SOURCE}/build
key: ${{ runner.os }}
- name: Check
run: python3 driver.py --check
- name: Format
run: cargo fmt -- --check
- name: Lint
run: python3 driver.py --lint
- name: Build, then Test
run: python3 driver.py --test
steps:
- uses: actions/checkout@v3
- name: Install dependencies
if: runner.os == 'Linux'
run: sudo apt-get update; sudo apt-get install zsh cvc4 libboost-all-dev libssl-dev coinor-cbc coinor-libcbc-dev
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
- name: Set all features on
run: python3 driver.py --all_features
- name: Cache third_party libraries
uses: actions/cache@v3
with:
path: |
/home/runner/work/circ/ABY
/home/runner/work/circ/KaHIP
/home/runner/work/circ/kahypar
key: ${{ runner.os }}-third_party
- name: Install third_party libraries
run: python3 driver.py --install
- name: Check
run: python3 driver.py --check
- name: Format
run: cargo fmt -- --check
- name: Lint
run: python3 driver.py --lint
- name: Build, then Test
run: python3 driver.py --test

1
Cargo.lock generated
View File

@@ -278,6 +278,7 @@ dependencies = [
"quickcheck",
"quickcheck_macros",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rsmt2",
"rug",
"serde",

View File

@@ -14,6 +14,7 @@ rug = { version = "1.11", features = ["serde"] }
gmp-mpfr-sys = { version = "1.4", optional = true }
lazy_static = { version = "1.4", optional = true }
rand = "0.8"
rand_chacha = "0.3"
rsmt2 = { version = "0.14", optional = true }
ieee754 = { version = "0.2", optional = true}
zokrates_parser = { path = "third_party/ZoKrates/zokrates_parser", optional = true }
@@ -65,8 +66,8 @@ datalog = ["pest", "pest-ast", "pest_derive", "from-pest", "lazy_static"]
smt = ["rsmt2", "ieee754"]
lp = ["good_lp", "lp-solvers"]
aby = ["lp"]
kahip = ["lp"]
kahypar = ["lp"]
kahip = []
kahypar = []
r1cs = []
spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"]
bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys"]

View File

@@ -20,14 +20,14 @@ use ff::Field;
use paste::paste;
use rug::Integer;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display, Formatter};
use std::fmt::{self, Debug, Display, Formatter};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::sync::Arc;
// TODO: rework this using macros?
/// Field element type
#[derive(PartialEq, Eq, Clone, Debug, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[derive(PartialEq, Eq, Clone, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum FieldT {
/// BLS12-381 scalar field as `ff`
FBls12381,
@@ -47,6 +47,12 @@ impl Display for FieldT {
}
}
impl Debug for FieldT {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl From<Arc<Integer>> for FieldT {
fn from(m: Arc<Integer>) -> Self {
Self::as_ff_opt(m.as_ref()).unwrap_or(Self::IntField(m))
@@ -128,7 +134,7 @@ impl FieldT {
}
/// Field element value
#[derive(PartialEq, Eq, Clone, Debug, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[derive(PartialEq, Eq, Clone, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum FieldV {
/// BLS12-381 scalar field element as `ff`
FBls12381(FBls12381),
@@ -272,15 +278,37 @@ impl FieldV {
ty.new_v(self.i())
}
}
/// Get this field element as an SMT-LIB string.
#[inline]
pub fn as_smtlib(&self) -> String {
match self {
Self::FBls12381(pf) => format!("#f{}m{}", Integer::from(pf), &*F_BLS12381_FMOD),
Self::FBn254(pf) => format!("#f{}m{}", Integer::from(pf), &*F_BN254_FMOD),
Self::IntField(i) => format!("{}", i),
}
}
/// Get this field element as a signed integer. No particular cut-off is guaranteed.
#[inline]
pub fn signed_int(&self) -> Integer {
let mut i = self.i();
if i.significant_bits() >= self.ty().modulus().significant_bits() - 1 {
i -= self.ty().modulus();
}
i
}
}
impl Display for FieldV {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Self::FBls12381(pf) => write!(f, "#f{}m{}", Integer::from(pf), &*F_BLS12381_FMOD),
Self::FBn254(pf) => write!(f, "#f{}m{}", Integer::from(pf), &*F_BN254_FMOD),
Self::IntField(i) => i.fmt(f),
}
write!(f, "{}", self.signed_int())
}
}
impl Debug for FieldV {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{} (in {})", self.signed_int(), self.ty())
}
}

View File

@@ -49,6 +49,10 @@ impl crate::Table<u8> for Table {
"hashconsing"
}
fn for_each(f: impl FnMut(&u8, &[Self::Node])) {
panic!()
}
fn reserve(num_nodes: usize) {
FACTORY.reserve(num_nodes);
}

View File

@@ -52,6 +52,10 @@ macro_rules! generate_hashcons_hashconsing {
"hashconsing"
}
fn for_each(f: impl FnMut(&$Op, &[Self::Node])) {
panic!()
}
fn reserve(num_nodes: usize) {
FACTORY.reserve(num_nodes);
}

View File

@@ -50,6 +50,10 @@ impl crate::Table<TemplateOp> for Table {
"hashconsing"
}
fn for_each(f: impl FnMut(&TemplateOp, &[Self::Node])) {
panic!()
}
fn reserve(num_nodes: usize) {
FACTORY.reserve(num_nodes);
}

View File

@@ -39,6 +39,9 @@ pub trait Table<Op> {
/// The name of the implementation
fn name() -> &'static str;
/// Fun a function on every node
fn for_each(f: impl FnMut(&Op, &[Self::Node]));
/// When the table garbage-collects a node with ID `id`, it will call `f(id)`. `f(id)` might
/// clear caches, etc., but instead of droping `Node`s, it should return them.
///

View File

@@ -45,6 +45,10 @@ impl crate::Table<u8> for Table {
"raw"
}
fn for_each(mut f: impl FnMut(&u8, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -48,6 +48,10 @@ macro_rules! generate_hashcons_raw {
"raw"
}
fn for_each(mut f: impl FnMut(&$Op, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -45,6 +45,10 @@ impl crate::Table<TemplateOp> for Table {
"raw"
}
fn for_each(mut f: impl FnMut(&TemplateOp, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -69,6 +69,10 @@ impl crate::Table<u8> for Table {
"rc"
}
fn for_each(mut f: impl FnMut(&u8, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -72,6 +72,10 @@ macro_rules! generate_hashcons_rc {
"rc"
}
fn for_each(mut f: impl FnMut(&$Op, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -69,6 +69,10 @@ impl crate::Table<TemplateOp> for Table {
"rc"
}
fn for_each(mut f: impl FnMut(&TemplateOp, &[Self::Node])) {
MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs)));
}
fn reserve(num_nodes: usize) {
MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes))
}

View File

@@ -86,6 +86,13 @@ Options:
[default: true]
[possible values: true, false]
--fmt-hide-field <HIDE_FIELD>
Always hide the field
[env: FMT_HIDE_FIELD=]
[default: false]
[possible values: true, false]
--zsharp-isolate-asserts <ISOLATE_ASSERTS>
In Z#, "isolate" assertions. That is, assertions in if/then/else expressions only take effect if that branch is active.
@@ -144,6 +151,8 @@ Options:
Which field to use [env: IR_FIELD_TO_BV=] [default: wrap] [possible values: wrap, panic]
--fmt-use-default-field <USE_DEFAULT_FIELD>
Which field to use [env: FMT_USE_DEFAULT_FIELD=] [default: true] [possible values: true, false]
--fmt-hide-field <HIDE_FIELD>
Always hide the field [env: FMT_HIDE_FIELD=] [default: false] [possible values: true, false]
--zsharp-isolate-asserts <ISOLATE_ASSERTS>
In Z#, "isolate" assertions. That is, assertions in if/then/else expressions only take effect if that branch is active [env: ZSHARP_ISOLATE_ASSERTS=] [default: false] [possible values: true, false]
--datalog-rec-limit <N>
@@ -179,6 +188,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -216,6 +226,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -251,6 +262,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -286,6 +298,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -321,6 +334,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -356,6 +370,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -391,6 +406,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -426,6 +442,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -464,6 +481,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -500,6 +518,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -538,6 +557,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: true,
@@ -574,6 +594,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: true,
@@ -612,6 +633,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,
@@ -648,6 +670,7 @@ BinaryOpt {
},
fmt: FmtOpt {
use_default_field: true,
hide_field: false,
},
zsharp: ZsharpOpt {
isolate_asserts: false,

View File

@@ -198,12 +198,20 @@ pub struct FmtOpt {
action = ArgAction::Set,
default_value = "true")]
pub use_default_field: bool,
/// Always hide the field
#[arg(
long = "fmt-hide-field",
env = "FMT_HIDE_FIELD",
action = ArgAction::Set,
default_value = "false")]
pub hide_field: bool,
}
impl Default for FmtOpt {
fn default() -> Self {
Self {
use_default_field: true,
hide_field: false,
}
}
}

View File

@@ -39,17 +39,15 @@ def install(features):
subprocess.run(["./scripts/build_aby.zsh"])
if f == "kahip":
if verify_path_empty(KAHIP_SOURCE):
# TODO: we can pull from their main repository instead of fork also long as we
# remove their parallel (ParHIP) cmake dependency
subprocess.run(
["git", "clone", "https://github.com/edwjchen/KaHIP.git", KAHIP_SOURCE]
["git", "clone", "https://github.com/KaHIP/KaHIP.git", KAHIP_SOURCE]
)
subprocess.run(["./scripts/build_kahip.zsh"])
if f == "kahypar":
if verify_path_empty(KAHYPAR_SOURCE):
subprocess.run(
["git", "clone", "--depth=1", "--recursive",
"git@github.com:SebastianSchlag/kahypar.git", KAHYPAR_SOURCE]
"https://github.com/SebastianSchlag/kahypar.git", KAHYPAR_SOURCE]
)
subprocess.run(["./scripts/build_kahypar.zsh"])
@@ -107,7 +105,7 @@ def build(features):
cmd += ["--examples"]
if features:
cmd = cmd + ["--features"] + features
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
@@ -139,8 +137,8 @@ def test(features, extra_args):
test_cmd = ["cargo", "test"]
test_cmd_release = ["cargo", "test", "--release"]
if features:
test_cmd += ["--features"] + features
test_cmd_release += ["--features"] + features
test_cmd += ["--features"] + [",".join(features)]
test_cmd_release += ["--features"] + [",".join(features)]
if "ristretto255" in features:
test_cmd += ["--no-default-features"]
test_cmd_release += ["--no-default-features"]
@@ -152,7 +150,7 @@ def test(features, extra_args):
if load_mode() == "release":
log_run_check(test_cmd_release)
if "r1cs" in features and "smt" in features:
if "r1cs" in features and "smt" in features and "datalog" in features:
log_run_check(["./scripts/test_datalog.zsh"])
if "zok" in features and "smt" in features:
@@ -191,7 +189,7 @@ def benchmark(features):
cmd += ["--examples"]
if features:
cmd = cmd + ["--features"] + features
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
log_run_check(cmd)
@@ -215,7 +213,7 @@ def lint():
cmd = ["cargo", "clippy", "--tests", "--examples", "--benches", "--bins"]
if features:
cmd = cmd + ["--features"] + features
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
log_run_check(cmd)
@@ -224,7 +222,7 @@ def lint():
def flamegraph(features, extra):
cmd = ["cargo", "flamegraph"]
if features:
cmd = cmd + ["--features"] + features
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
cmd += extra

View File

@@ -34,10 +34,10 @@ use circ::ir::{
use circ::target::aby::trans::to_aby;
#[cfg(feature = "lp")]
use circ::target::ilp::{assignment_to_values, trans::to_ilp};
#[cfg(feature = "bellman")]
use circ::target::r1cs::bellman::gen_params;
#[cfg(feature = "spartan")]
use circ::target::r1cs::spartan::write_data;
#[cfg(feature = "bellman")]
use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem};
#[cfg(feature = "r1cs")]
use circ::target::r1cs::{opt::reduce_linearities, trans::to_r1cs};
#[cfg(feature = "smt")]
@@ -96,6 +96,8 @@ enum Backend {
lc_elimination_thresh: usize,
#[arg(long, default_value = "count")]
action: ProofAction,
#[arg(long, default_value = "groth16")]
proof_impl: ProofImpl,
},
Smt {},
Ilp {},
@@ -135,6 +137,12 @@ enum ProofAction {
SpartanSetup,
}
#[derive(PartialEq, Eq, Debug, Clone, ValueEnum)]
enum ProofImpl {
Groth16,
Mirage,
}
fn determine_language(l: &Language, input_path: &Path) -> DeterminedLanguage {
match *l {
Language::Datalog => DeterminedLanguage::Datalog,
@@ -275,25 +283,24 @@ fn main() {
..
} => {
println!("Converting to r1cs");
let (mut prover_data, verifier_data) = to_r1cs(cs.get("main").clone(), cfg());
let cs = cs.get("main");
let mut r1cs = to_r1cs(cs, cfg());
println!(
"Pre-opt R1cs size: {}",
prover_data.r1cs.constraints().len()
);
prover_data.r1cs = reduce_linearities(prover_data.r1cs, cfg());
println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
r1cs = reduce_linearities(r1cs, cfg());
println!("Final R1cs size: {}", prover_data.r1cs.constraints().len());
println!("Final R1cs size: {}", r1cs.constraints().len());
let (prover_data, verifier_data) = r1cs.finalize(cs);
match action {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
ProofAction::Setup => {
println!("Generating Parameters");
gen_params::<Bls12, _, _>(
Bellman::<Bls12>::setup_fs(
prover_data,
verifier_data,
prover_key,
verifier_key,
&prover_data,
&verifier_data,
)
.unwrap();
}

View File

@@ -1,16 +1,18 @@
use circ::cfg::clap::{self, Parser, ValueEnum};
use circ::cfg::{
clap::{self, Parser, ValueEnum},
CircOpt,
};
use std::path::PathBuf;
#[cfg(feature = "bellman")]
use bls12_381::Bls12;
#[cfg(feature = "bellman")]
use circ::target::r1cs::bellman;
use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem};
#[cfg(feature = "spartan")]
use circ::target::r1cs::spartan;
#[cfg(any(feature = "bellman", feature = "spartan"))]
use circ::ir::term::text::parse_value_map;
#[cfg(feature = "spartan")]
use circ::target::r1cs::spartan;
#[derive(Debug, Parser)]
#[command(name = "zk", about = "The CirC ZKP runner")]
@@ -27,8 +29,12 @@ struct Options {
pin: PathBuf,
#[arg(long, default_value = "vin")]
vin: PathBuf,
#[arg(long, default_value = "groth16")]
proof_impl: ProofImpl,
#[arg(long)]
action: ProofAction,
#[command(flatten)]
circ: CircOpt,
}
#[derive(PartialEq, Debug, Clone, ValueEnum)]
@@ -40,24 +46,33 @@ enum ProofAction {
Spartan,
}
#[derive(PartialEq, Debug, Clone, ValueEnum)]
/// Whether to use Groth16 or Mirage
enum ProofImpl {
Groth16,
Mirage,
}
fn main() {
env_logger::Builder::from_default_env()
.format_level(false)
.format_timestamp(None)
.init();
let opts = Options::parse();
circ::cfg::set(&opts.circ);
match opts.action {
#[cfg(feature = "bellman")]
ProofAction::Prove => {
let input_map = parse_value_map(&std::fs::read(opts.inputs).unwrap());
println!("Proving");
bellman::prove::<Bls12, _, _>(opts.prover_key, opts.proof, &input_map).unwrap();
Bellman::<Bls12>::prove_fs(opts.prover_key, opts.inputs, opts.proof).unwrap();
}
#[cfg(feature = "bellman")]
ProofAction::Verify => {
let input_map = parse_value_map(&std::fs::read(opts.inputs).unwrap());
println!("Verifying");
bellman::verify::<Bls12, _, _>(opts.verifier_key, opts.proof, &input_map).unwrap();
assert!(
Bellman::<Bls12>::verify_fs(opts.verifier_key, opts.inputs, opts.proof).unwrap(),
"invalid proof"
);
}
#[cfg(feature = "spartan")]
ProofAction::Spartan => {

View File

@@ -129,16 +129,16 @@ fn main() {
*/
println!("Converting to r1cs");
let (prover_data, _) = to_r1cs(cs.get("main").clone(), cfg());
let r1cs = to_r1cs(cs.get("main"), cfg());
let r1cs = if options.skip_linred {
println!("Skipping linearity reduction, as requested.");
prover_data.r1cs
r1cs
} else {
println!(
"R1cs size before linearity reduction: {}",
prover_data.r1cs.constraints().len()
r1cs.constraints().len()
);
reduce_linearities(prover_data.r1cs, cfg())
reduce_linearities(r1cs, cfg())
};
println!("Final R1cs size: {}", r1cs.constraints().len());
match action {

View File

@@ -1,6 +1,5 @@
import os
from subprocess import Popen, PIPE
import sys
from typing import List
from tqdm import tqdm

View File

@@ -2,7 +2,7 @@
if [[ ! -z ${KAHIP_SOURCE} ]]; then
cd ${KAHIP_SOURCE}
./compile_withcmake.sh
./compile_withcmake.sh -DCMAKE_BUILD_TYPE=Release -DPARHIP=off
else
echo "Missing KAHIP_SOURCE environment variable."
fi

View File

@@ -35,20 +35,26 @@ function r1cs_test_count {
# Test prove workflow, given an example name
function pf_test {
ex_name=$1
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify
rm -rf P V pi
for proof_impl in groth16 mirage
do
ex_name=$1
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup --proof-impl $proof_impl
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove --proof-impl $proof_impl
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify --proof-impl $proof_impl
rm -rf P V pi
done
}
# Test prove workflow with --z-isolate-asserts, given an example name
function pf_test_isolate {
ex_name=$1
$BIN --zsharp-isolate-asserts true examples/ZoKrates/pf/$ex_name.zok r1cs --action setup
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify
rm -rf P V pi
for proof_impl in groth16 mirage
do
ex_name=$1
$BIN --zsharp-isolate-asserts true examples/ZoKrates/pf/$ex_name.zok r1cs --action setup --proof-impl $proof_impl
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove --proof-impl $proof_impl
$ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify --proof-impl $proof_impl
rm -rf P V pi
done
}
r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120

View File

@@ -385,7 +385,6 @@ mod test {
assert_eq!(bv_lit(1, 3), rams[0].accesses[0].val);
assert_eq!(bv_lit(2, 3), rams[0].accesses[1].val);
assert!(rams[0].accesses[2].val.is_var());
dbg!(cs2);
}
#[test]

View File

@@ -105,7 +105,7 @@ pub fn assert_all_vars_are_scalars(cs: &Computation) {
/// Check that every variables is a scalar.
fn remove_non_scalar_vars_from_main_computation(cs: &mut Computation) {
for input in cs.metadata.ordered_public_inputs() {
for input in cs.metadata.ordered_inputs() {
if !check(&input).is_scalar() {
cs.metadata.remove_var(input.as_var_name());
}

View File

@@ -161,3 +161,35 @@ pub fn array_elements(t: &Term) -> Vec<Term> {
pub fn array_to_tuple(t: &Term) -> Term {
term(Op::Tuple, array_elements(t))
}
/// Print operator stats
pub fn dump_op_stats() {
use std::mem::size_of;
println!("Op size: {}bytes", size_of::<Op>());
let mut counts: FxHashMap<Op, usize> = FxHashMap::default();
let mut children: FxHashMap<Op, usize> = FxHashMap::default();
let add = |map: &mut FxHashMap<Op, usize>, key: &Op, value: usize| {
if let Some(present_value) = map.get_mut(key) {
*present_value += value;
} else {
map.insert(key.clone(), value);
}
};
hc::Table::for_each(|op, cs| {
add(&mut counts, op, 1);
add(&mut children, op, cs.len());
});
let mut vector = Vec::new();
for k in counts.keys() {
let ct = *counts.get(k).unwrap();
let cs_ct = *children.get(k).unwrap();
let ave = cs_ct as f64 / ct as f64;
vector.push((k.clone(), ct, cs_ct, ave));
}
vector.sort_by_key(|t| t.1);
for (k, ct, cs_ct, ave) in vector {
let mem = size_of::<Op>() * ct + size_of::<Vec<Term>>() * ct + size_of::<Term>() * cs_ct;
let s: String = format!("{k}");
println!("Op {s:>20}, Count {ct:>8}, Children {cs_ct:>8}, Ave {ave:>8.2}, Mem {mem:>20}");
}
}

View File

@@ -23,6 +23,8 @@ struct IrFormatter<'a, 'b> {
pub struct IrCfg {
/// Whether to introduce a default modulus.
pub use_default_field: bool,
/// Whether to show any moduli
pub hide_field: bool,
}
impl IrCfg {
@@ -33,6 +35,7 @@ impl IrCfg {
if is_cfg_set() {
Self {
use_default_field: cfg().fmt.use_default_field,
hide_field: cfg().fmt.hide_field,
}
} else {
Self::parseable()
@@ -42,6 +45,7 @@ impl IrCfg {
pub fn parseable() -> Self {
Self {
use_default_field: true,
hide_field: false,
}
}
}
@@ -181,6 +185,7 @@ impl DisplayIr for Op {
Op::IntNaryOp(a) => write!(f, "{a}"),
Op::IntBinPred(a) => write!(f, "{a}"),
Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()),
Op::PfChallenge(n, m) => write!(f, "(challenge {} {})", n, m.modulus()),
Op::Select => write!(f, "select"),
Op::Store => write!(f, "store"),
Op::Tuple => write!(f, "tuple"),
@@ -244,10 +249,10 @@ impl DisplayIr for Array {
impl DisplayIr for FieldV {
fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult {
let omit_field = f
.default_field
.as_ref()
.map_or(false, |field| field == &self.ty());
let omit_field = f.cfg.hide_field
|| f.default_field
.as_ref()
.map_or(false, |field| field == &self.ty());
let mut i = self.i();
let mod_bits = self.modulus().significant_bits();
if i.significant_bits() + 1 >= mod_bits {
@@ -294,6 +299,12 @@ impl DisplayIr for VariableMetadata {
if self.random {
write!(f, " (random)")?;
}
if 0 != self.round {
write!(f, " (round {})", self.round)?;
}
if self.random {
write!(f, " (random)")?;
}
write!(f, ")")
}
}
@@ -333,7 +344,7 @@ fn fmt_term_with_bindings(t: &Term, f: &mut IrFormatter) -> FmtResult {
}
})
.collect();
if fields.len() == 1 {
if fields.len() == 1 && !f.cfg.hide_field {
f.default_field = fields.into_iter().next();
let i = f.default_field.clone().unwrap();
writeln!(f, "(set-default-modulus {}", i.modulus())?;
@@ -408,26 +419,20 @@ impl Display for Term {
}
}
impl Display for Sort {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
}
macro_rules! fmt_impl {
($trait:ty, $ty:ty) => {
impl $trait for $ty {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
}
}
};
}
impl Display for Value {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
}
}
impl Display for Op {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
}
}
impl Display for ComputationMetadata {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg()))
}
}
fmt_impl!(Display, Sort);
fmt_impl!(Display, Value);
fmt_impl!(Display, Op);
fmt_impl!(Display, ComputationMetadata);
fmt_impl!(Debug, Value);
fmt_impl!(Debug, Sort);
fmt_impl!(Debug, Op);

View File

@@ -48,7 +48,7 @@ pub mod ty;
pub use bv::BitVector;
pub use ty::{check, check_rec, TypeError, TypeErrorReason};
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// An operator
pub enum Op {
/// a variable
@@ -125,6 +125,13 @@ pub enum Op {
///
/// Takes the modulus.
UbvToPf(FieldT),
/// A random value, sampled uniformly and independently of its arguments.
///
/// Takes a name (if deterministically sampled, challenges of different names are sampled
/// differentely) and a field to sample from.
///
/// In IR evaluation, we sample deterministically based on a hash of the name.
PfChallenge(String, FieldT),
/// Integer n-ary operator
IntNaryOp(IntNaryOp),
@@ -272,6 +279,7 @@ impl Op {
Op::FpToFp(_) => Some(1),
Op::PfUnOp(_) => Some(1),
Op::PfNaryOp(_) => None,
Op::PfChallenge(_, _) => None,
Op::IntNaryOp(_) => None,
Op::IntBinPred(_) => Some(2),
Op::UbvToPf(_) => Some(1),
@@ -635,7 +643,7 @@ impl<'de> Deserialize<'de> for Term {
}
}
#[derive(Clone, PartialEq, Debug, PartialOrd, Serialize, Deserialize)]
#[derive(Clone, PartialEq, PartialOrd, Serialize, Deserialize)]
/// An IR value (aka literal)
pub enum Value {
/// Bit-vector
@@ -776,7 +784,7 @@ impl std::hash::Hash for Value {
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
/// The "type" of an IR term
pub enum Sort {
/// bit-vectors of this width
@@ -1211,54 +1219,49 @@ pub fn eval(t: &Term, h: &FxHashMap<String, Value>) -> Value {
}
/// Helper function for eval function. Handles a single term
fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) -> Value {
let v = match &c.op() {
Op::Var(n, _) => h
fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, t: Term) -> Value {
let args: Vec<&Value> = t.cs().iter().map(|c| vs.get(c).unwrap()).collect();
let v = eval_op(t.op(), &args, h);
debug!("Eval {}\nAs {}", t, v);
vs.insert(t, v.clone());
v
}
/// Helper function for eval function. Handles a single op
#[allow(clippy::uninlined_format_args)]
pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> Value {
match op {
Op::Var(n, _) => var_vals
.get(n)
.unwrap_or_else(|| panic!("Missing var: {} in {:?}", n, h))
.unwrap_or_else(|| panic!("Missing var: {} in {:?}", n, var_vals))
.clone(),
Op::Eq => Value::Bool(vs.get(&c.cs()[0]).unwrap() == vs.get(&c.cs()[1]).unwrap()),
Op::Not => Value::Bool(!vs.get(&c.cs()[0]).unwrap().as_bool()),
Op::Implies => Value::Bool(
!vs.get(&c.cs()[0]).unwrap().as_bool() || vs.get(&c.cs()[1]).unwrap().as_bool(),
),
Op::BoolNaryOp(BoolNaryOp::Or) => {
Value::Bool(c.cs().iter().any(|c| vs.get(c).unwrap().as_bool()))
}
Op::BoolNaryOp(BoolNaryOp::And) => {
Value::Bool(c.cs().iter().all(|c| vs.get(c).unwrap().as_bool()))
}
Op::Eq => Value::Bool(args[0] == args[1]),
Op::Not => Value::Bool(!args[0].as_bool()),
Op::Implies => Value::Bool(!args[0].as_bool() || args[1].as_bool()),
Op::BoolNaryOp(BoolNaryOp::Or) => Value::Bool(args.iter().any(|a| a.as_bool())),
Op::BoolNaryOp(BoolNaryOp::And) => Value::Bool(args.iter().all(|a| a.as_bool())),
Op::BoolNaryOp(BoolNaryOp::Xor) => Value::Bool(
c.cs()
.iter()
.map(|c| vs.get(c).unwrap().as_bool())
args.iter()
.map(|a| a.as_bool())
.fold(false, std::ops::BitXor::bitxor),
),
Op::BvBit(i) => Value::Bool(
vs.get(&c.cs()[0])
.unwrap()
.as_bv()
.uint()
.get_bit(*i as u32),
),
Op::BvBit(i) => Value::Bool(args[0].as_bv().uint().get_bit(*i as u32)),
Op::BoolMaj => {
let c0 = vs.get(&c.cs()[0]).unwrap().as_bool() as u8;
let c1 = vs.get(&c.cs()[1]).unwrap().as_bool() as u8;
let c2 = vs.get(&c.cs()[2]).unwrap().as_bool() as u8;
let c0 = args[0].as_bool() as u8;
let c1 = args[1].as_bool() as u8;
let c2 = args[2].as_bool() as u8;
Value::Bool(c0 + c1 + c2 > 1)
}
Op::BvConcat => Value::BitVector({
let mut it = c.cs().iter().map(|c| vs.get(c).unwrap().as_bv().clone());
let mut it = args.iter().map(|a| a.as_bv().clone());
let f = it.next().unwrap();
it.fold(f, BitVector::concat)
}),
Op::BvExtract(h, l) => {
Value::BitVector(vs.get(&c.cs()[0]).unwrap().as_bv().clone().extract(*h, *l))
}
Op::BvExtract(h, l) => Value::BitVector(args[0].as_bv().clone().extract(*h, *l)),
Op::Const(v) => v.clone(),
Op::BvBinOp(o) => Value::BitVector({
let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone();
let b = vs.get(&c.cs()[1]).unwrap().as_bv().clone();
let a = args[0].as_bv().clone();
let b = args[1].as_bv().clone();
match o {
BvBinOp::Udiv => a / &b,
BvBinOp::Urem => a % &b,
@@ -1269,14 +1272,14 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
}
}),
Op::BvUnOp(o) => Value::BitVector({
let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone();
let a = args[0].as_bv().clone();
match o {
BvUnOp::Not => !a,
BvUnOp::Neg => -a,
}
}),
Op::BvNaryOp(o) => Value::BitVector({
let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_bv().clone());
let mut xs = args.iter().map(|a| a.as_bv().clone());
let f = xs.next().unwrap();
xs.fold(
f,
@@ -1290,13 +1293,13 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
)
}),
Op::BvSext(w) => Value::BitVector({
let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone();
let a = args[0].as_bv().clone();
let mask = ((Integer::from(1) << *w as u32) - 1)
* Integer::from(a.uint().get_bit(a.width() as u32 - 1));
BitVector::new(a.uint() | (mask << a.width() as u32), a.width() + w)
}),
Op::PfToBv(w) => Value::BitVector({
let i = vs.get(&c.cs()[0]).unwrap().as_pf().i();
let i = args[0].as_pf().i();
if let FieldToBv::Panic = cfg().ir.field_to_bv {
assert!(
(i.significant_bits() as usize) <= *w,
@@ -1307,19 +1310,13 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
BitVector::new(i % (Integer::from(1) << *w), *w)
}),
Op::BvUext(w) => Value::BitVector({
let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone();
let a = args[0].as_bv().clone();
BitVector::new(a.uint().clone(), a.width() + w)
}),
Op::Ite => if vs.get(&c.cs()[0]).unwrap().as_bool() {
vs.get(&c.cs()[1])
} else {
vs.get(&c.cs()[2])
}
.unwrap()
.clone(),
Op::Ite => args[if args[0].as_bool() { 1 } else { 2 }].clone(),
Op::BvBinPred(o) => Value::Bool({
let a = vs.get(&c.cs()[0]).unwrap().as_bv();
let b = vs.get(&c.cs()[1]).unwrap().as_bv();
let a = args[0].as_bv();
let b = args[1].as_bv();
match o {
BvBinPred::Sge => a.as_sint() >= b.as_sint(),
BvBinPred::Sgt => a.as_sint() > b.as_sint(),
@@ -1331,12 +1328,9 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
BvBinPred::Ult => a.uint() < b.uint(),
}
}),
Op::BoolToBv => Value::BitVector(BitVector::new(
Integer::from(vs.get(&c.cs()[0]).unwrap().as_bool()),
1,
)),
Op::BoolToBv => Value::BitVector(BitVector::new(Integer::from(args[0].as_bool()), 1)),
Op::PfUnOp(o) => Value::Field({
let a = vs.get(&c.cs()[0]).unwrap().as_pf().clone();
let a = args[0].as_pf().clone();
match o {
PfUnOp::Recip => {
if a.is_zero() {
@@ -1349,7 +1343,7 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
}
}),
Op::PfNaryOp(o) => Value::Field({
let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_pf().clone());
let mut xs = args.iter().map(|a| a.as_pf().clone());
let f = xs.next().unwrap();
xs.fold(
f,
@@ -1360,8 +1354,8 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
)
}),
Op::IntBinPred(o) => Value::Bool({
let a = vs.get(&c.cs()[0]).unwrap().as_int();
let b = vs.get(&c.cs()[1]).unwrap().as_int();
let a = args[0].as_int();
let b = args[1].as_int();
match o {
IntBinPred::Ge => a >= b,
IntBinPred::Gt => a > b,
@@ -1370,7 +1364,7 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
}
}),
Op::IntNaryOp(o) => Value::Int({
let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_int().clone());
let mut xs = args.iter().map(|a| a.as_int().clone());
let f = xs.next().unwrap();
xs.fold(
f,
@@ -1380,77 +1374,94 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
},
)
}),
Op::UbvToPf(fty) => Value::Field(fty.new_v(vs.get(&c.cs()[0]).unwrap().as_bv().uint())),
Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())),
Op::PfChallenge(name, field) => {
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use std::hash::{Hash, Hasher};
// hash the string
let mut hasher = fxhash::FxHasher::default();
name.hash(&mut hasher);
let hash: u64 = hasher.finish();
// seed ChaCha with the hash
let mut seed = [0u8; 32];
seed[0..8].copy_from_slice(&hash.to_le_bytes());
let mut rng = ChaChaRng::from_seed(seed);
// sample from ChaCha
Value::Field(field.random_v(&mut rng))
}
// tuple
Op::Tuple => Value::Tuple(c.cs().iter().map(|c| vs.get(c).unwrap().clone()).collect()),
Op::Tuple => Value::Tuple(args.iter().map(|a| (*a).clone()).collect()),
Op::Field(i) => {
let t = vs.get(&c.cs()[0]).unwrap().as_tuple();
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs()[0]);
let t = args[0].as_tuple();
assert!(i < &t.len(), "{} out of bounds for {} on {:?}", i, op, args);
t[*i].clone()
}
Op::Update(i) => {
let mut t = Vec::from(vs.get(&c.cs()[0]).unwrap().as_tuple()).into_boxed_slice();
assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs()[0]);
let e = vs.get(&c.cs()[1]).unwrap().clone();
let mut t = Vec::from(args[0].as_tuple()).into_boxed_slice();
assert!(i < &t.len(), "{} out of bounds for {} on {:?}", i, op, args);
let e = args[1].clone();
assert_eq!(t[*i].sort(), e.sort());
t[*i] = e;
Value::Tuple(t)
}
// array
Op::Store => {
let a = vs.get(&c.cs()[0]).unwrap().as_array().clone();
let i = vs.get(&c.cs()[1]).unwrap().clone();
let v = vs.get(&c.cs()[2]).unwrap().clone();
let a = args[0].as_array().clone();
let i = args[1].clone();
let v = args[2].clone();
Value::Array(a.store(i, v))
}
Op::Select => {
let a = vs.get(&c.cs()[0]).unwrap().as_array().clone();
let i = vs.get(&c.cs()[1]).unwrap();
let a = args[0].as_array().clone();
let i = args[1];
a.select(i)
}
Op::Map(op) => {
let arg_cnt = c.cs().len();
Op::Map(inner_op) => {
// term_vecs[i] will store a vector of all the i-th index entries of the array arguments
let mut term_vecs = vec![Vec::new(); vs.get(&c.cs()[0]).unwrap().as_array().size];
let mut arg_vecs: Vec<Vec<Value>> = vec![Vec::new(); args[0].as_array().size];
for i in 0..arg_cnt {
let arr = vs.get(&c.cs()[i]).unwrap().as_array().clone();
let iter = match check(&c.cs()[i]) {
for arg in args {
let arr = arg.as_array().clone();
let iter = match arg.sort() {
Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(),
_ => panic!("Input type should be Array"),
};
for (j, jval) in iter {
term_vecs[j].push(leaf_term(Op::Const(arr.clone().select(&jval))))
arg_vecs[j].push(arr.select(&jval))
}
}
let mut res = match check(&c) {
Sort::Array(k, v, n) => Array::default((*k).clone(), &v, n),
let term = term(
op.clone(),
args.iter()
.map(|a| leaf_term(Op::Const((*a).clone())))
.collect(),
);
let (mut res, iter) = match check(&term) {
Sort::Array(k, v, n) => (
Array::default((*k).clone(), &v, n),
(*k).clone().elems_iter_values().take(n).enumerate(),
),
_ => panic!("Output type of map should be array"),
};
let iter = match check(&c) {
Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(),
_ => panic!("Input type should be Array"),
};
for (i, idxval) in iter {
let t = term((**op).clone(), term_vecs[i].clone());
let val = eval_value(vs, h, t);
let args: Vec<&Value> = arg_vecs[i].iter().collect();
let val = eval_op(inner_op, &args, var_vals);
res.map.insert(idxval, val);
}
Value::Array(res)
}
Op::Rot(i) => {
let a = vs.get(&c.cs()[0]).unwrap().as_array().clone();
let iter = match check(&c.cs()[0]) {
Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(),
let a = args[0].as_array().clone();
let (mut res, iter, len) = match args[0].sort() {
Sort::Array(k, v, n) => (
Array::default((*k).clone(), &v, n),
(*k).clone().elems_iter_values().take(n).enumerate(),
n,
),
_ => panic!("Input type should be Array"),
};
let (mut res, len) = match check(&c.cs()[0]) {
Sort::Array(k, v, n) => (Array::default((*k).clone(), &v, n), n),
_ => panic!("Output type of rot should be Array"),
};
// calculate new rotation amount
let rot = *i % len;
@@ -1464,10 +1475,7 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
}
o => unimplemented!("eval: {:?}", o),
};
vs.insert(c.clone(), v.clone());
debug!("Eval {}\nAs {}", c, v);
v
}
}
/// Make an array from a sequence of terms.
@@ -1562,6 +1570,17 @@ impl PostOrderIter {
visited: TermSet::default(),
}
}
/// Make an iterator over the descendents of `roots`, stopping at `skips`.
pub fn from_roots_and_skips(roots: impl IntoIterator<Item = Term>, skips: TermSet) -> Self {
Self {
stack: roots
.into_iter()
.filter(|t| !skips.contains(t))
.map(|t| (false, t))
.collect(),
visited: skips,
}
}
}
impl std::iter::Iterator for PostOrderIter {
@@ -1709,6 +1728,47 @@ impl ComputationMetadata {
out
}
/// Get the interactive structure of the variables. See [InteractiveVars].
pub fn interactive_vars(&self) -> InteractiveVars {
let final_round = self.vars.values().map(|m| m.round).max().unwrap_or(0);
let mut instances = Vec::new();
let mut rounds = vec![RoundVars::default(); final_round as usize + 1];
for meta in self.vars.values() {
if meta.random {
// is this a challenge?, if so it must be public
assert!(meta.vis.is_none());
rounds[meta.round as usize].challenges.push(meta.term());
} else if meta.vis.is_none() {
// is it a public non-challenge? if so, it must be round 0
assert!(meta.round == 0);
instances.push(meta.term());
} else {
// this is a witness
rounds[meta.round as usize].witnesses.push(meta.term());
}
}
// If there no final challenges, distinguish the last round of witnesses
let final_witnesses = if rounds.last().unwrap().challenges.is_empty() {
rounds.pop().unwrap().witnesses
} else {
Vec::new()
};
let mut ret = InteractiveVars {
instances,
rounds,
final_witnesses,
};
// sort!
let cmp_name = |a: &Term, b: &Term| a.as_var_name().cmp(b.as_var_name());
ret.instances.sort_by(cmp_name);
for round in &mut ret.rounds {
round.witnesses.sort_by(cmp_name);
round.challenges.sort_by(cmp_name);
}
ret.final_witnesses.sort_by(cmp_name);
ret
}
/// Get a round after the rounds of these variables
pub fn future_round<'a, Q: Borrow<str> + 'a>(
&self,
@@ -1757,6 +1817,29 @@ impl ComputationMetadata {
}
}
/// A structured collection of variables that indicates the round structure: e.g., orderings,
/// challenges.
///
/// It represents the variables themselves as terms.
#[derive(Default)]
pub struct InteractiveVars {
/// Instance vars
pub instances: Vec<Term>,
/// Rounds
pub rounds: Vec<RoundVars>,
/// Final witnesses
pub final_witnesses: Vec<Term>,
}
/// Witnesses, followed by a challenge.
#[derive(Default, Clone)]
pub struct RoundVars {
/// witnesses
pub witnesses: Vec<Term>,
/// followed by challenges
pub challenges: Vec<Term>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
/// An IR computation.
pub struct Computation {

View File

@@ -123,6 +123,19 @@ impl PreComp {
self.recompute_inputs();
self
}
/// Reduce the precomputation to a single, step-less map.
pub fn flatten(self) -> FxHashMap<String, Term> {
let mut out: FxHashMap<String, Term> = Default::default();
let mut cache: TermMap<Term> = Default::default();
for (name, sort) in &self.sequence {
let term = extras::substitute_cache(self.outputs.get(name).unwrap(), &mut cache);
let var_term = leaf_term(Op::Var(name.clone(), sort.clone()));
out.insert(name.into(), term.clone());
cache.insert(var_term, term);
}
out
}
}
#[cfg(test)]

View File

@@ -288,6 +288,10 @@ impl<'src> IrInterp<'src> {
[Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))),
[Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))),
[Leaf(Ident, b"fp2fp"), a] => Ok(Op::FpToFp(self.usize(a))),
[Leaf(Ident, b"challenge"), name, field] => Ok(Op::PfChallenge(
self.ident_string(name),
FieldT::from(self.int(field)),
)),
[Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))),
[Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))),
[Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))),
@@ -688,6 +692,10 @@ impl<'src> IrInterp<'src> {
let inputs = self.var_decl_list(&tts[1]);
let outputs = self.var_decl_list(&tts[2]);
let tuple_term = self.term(&tts[3]);
assert!(
matches!(check(&tuple_term), Sort::Tuple(..)),
"precompute output term must be a tuple"
);
assert!(
outputs.len() == tuple_term.cs().len(),
"output list has {} items, tuple has {}",
@@ -958,4 +966,12 @@ mod test {
let c2 = parse_precompute(s.as_bytes());
assert_eq!(c, c2);
}
#[test]
fn challenge_roundtrip() {
let t = parse_term(b"(declare ((a bool) (b bool)) ((challenge hithere 17) a b))");
let s = serialize_term(&t);
let t2 = parse_term(s.as_bytes());
assert_eq!(t, t2);
}
}

View File

@@ -59,6 +59,7 @@ fn check_dependencies(t: &Term) -> Vec<Term> {
Op::IntNaryOp(_) => Vec::new(),
Op::IntBinPred(_) => Vec::new(),
Op::UbvToPf(_) => Vec::new(),
Op::PfChallenge(_, _) => Vec::new(),
Op::Select => vec![t.cs()[0].clone()],
Op::Store => vec![t.cs()[0].clone()],
Op::Tuple => t.cs().to_vec(),
@@ -128,6 +129,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
Op::IntNaryOp(_) => Ok(Sort::Int),
Op::IntBinPred(_) => Ok(Sort::Bool),
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
Op::PfChallenge(_, m) => Ok(Sort::Field(m.clone())),
Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v)| v.clone()),
Op::Store => Ok(get_ty(&t.cs()[0]).clone()),
Op::Tuple => Ok(Sort::Tuple(t.cs().iter().map(get_ty).cloned().collect())),
@@ -332,6 +334,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
.map(|a| a.clone())
}
(Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())),
(Op::PfChallenge(_, m), _) => Ok(Sort::Field(m.clone())),
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
(Op::IntNaryOp(_), a) => {
let ctx = "int nary op";

View File

@@ -1,31 +1,28 @@
//! Exporting our R1CS to bellman
use ::bellman::{
groth16::{
create_random_proof, generate_random_parameters, prepare_verifying_key, verify_proof,
Parameters, Proof, VerifyingKey,
},
Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable,
};
use bincode::{deserialize_from, serialize_into};
use ::bellman::{groth16, Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable};
use ff::{Field, PrimeField, PrimeFieldBits};
use fxhash::FxHashMap;
use gmp_mpfr_sys::gmp::limb_t;
use group::WnafGroup;
use log::debug;
use pairing::{Engine, MultiMillerLoop};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};
use std::io::{BufRead, BufReader};
use std::marker::PhantomData;
use std::path::Path;
use std::str::FromStr;
use rug::integer::{IsPrime, Order};
use rug::Integer;
use super::*;
use super::proof;
use super::{wit_comp::StagedWitCompEvaluator, Lc, ProverData, Var, VarType, VerifierData};
use crate::ir::term::Value;
/// Convert a (rug) integer to a prime field element.
fn int_to_ff<F: PrimeField>(i: Integer) -> F {
pub(super) fn int_to_ff<F: PrimeField>(i: Integer) -> F {
let mut accumulator = F::from(0);
let limb_bits = (std::mem::size_of::<limb_t>() as u64) << 3;
let limb_base = F::from(2).pow_vartime([limb_bits]);
@@ -39,8 +36,8 @@ fn int_to_ff<F: PrimeField>(i: Integer) -> F {
/// Convert one our our linear combinations to a bellman linear combination.
/// Takes a zero linear combination. We could build it locally, but bellman provides one, so...
fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
vars: &HashMap<usize, Variable>,
pub(super) fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
vars: &HashMap<Var, Variable>,
lc: &Lc,
zero_lc: LinearCombination<F>,
) -> LinearCombination<F> {
@@ -59,7 +56,7 @@ fn lc_to_bellman<F: PrimeField, CS: ConstraintSystem<F>>(
}
// hmmm... this should work essentially all the time, I think
fn get_modulus<F: Field + PrimeField>() -> Integer {
pub(super) fn get_modulus<F: Field + PrimeField>() -> Integer {
let neg_1_f = -F::one();
let p_lsf: Integer = Integer::from_digits(neg_1_f.to_repr().as_ref(), Order::Lsf) + 1;
let p_msf: Integer = Integer::from_digits(neg_1_f.to_repr().as_ref(), Order::Msf) + 1;
@@ -76,7 +73,7 @@ fn get_modulus<F: Field + PrimeField>() -> Integer {
///
/// Optionally contains a variable value map. This must be populated to use the
/// bellman prover.
pub struct SynthInput<'a>(&'a R1cs<String>, &'a Option<FxHashMap<String, Value>>);
pub struct SynthInput<'a>(&'a ProverData, Option<&'a FxHashMap<String, Value>>);
impl<'a, F: PrimeField> Circuit<F> for SynthInput<'a> {
#[track_caller]
@@ -86,55 +83,53 @@ impl<'a, F: PrimeField> Circuit<F> for SynthInput<'a> {
{
let f_mod = get_modulus::<F>();
assert_eq!(
self.0.modulus.modulus(),
self.0.r1cs.field.modulus(),
&f_mod,
"\nR1CS has modulus \n{},\n but Bellman CS expects \n{}",
self.0.modulus,
self.0.r1cs.field,
f_mod
);
let mut uses = HashMap::with_capacity(self.0.next_idx);
for (a, b, c) in self.0.constraints.iter() {
[a, b, c].iter().for_each(|y| {
y.monomials.keys().for_each(|k| {
uses.get_mut(k)
.map(|i| {
*i += 1;
})
.or_else(|| {
uses.insert(*k, 1);
None
});
let mut vars = HashMap::with_capacity(self.0.r1cs.vars.len());
let values: Option<Vec<_>> = self.1.map(|values| {
let mut evaluator = StagedWitCompEvaluator::new(&self.0.precompute);
let mut ffs = Vec::new();
ffs.extend(evaluator.eval_stage(values.clone()).into_iter().cloned());
ffs.extend(
evaluator
.eval_stage(Default::default())
.into_iter()
.cloned(),
);
ffs
});
for (i, var) in self.0.r1cs.vars.iter().copied().enumerate() {
assert!(
!matches!(var.ty(), VarType::CWit),
"Bellman doesn't support committed witnesses"
);
assert!(
!matches!(var.ty(), VarType::RoundWit | VarType::Chall),
"Bellman doesn't support rounds"
);
let public = matches!(var.ty(), VarType::Inst);
let name_f = || format!("{var:?}");
let val_f = || {
Ok({
let i_val = &values.as_ref().expect("missing values")[i];
let ff_val = int_to_ff(i_val.as_pf().into());
debug!("value : {var:?} -> {ff_val:?} ({i_val})");
ff_val
})
});
};
debug!("var: {:?}, public: {}", var, public);
let v = if public {
cs.alloc_input(name_f, val_f)?
} else {
cs.alloc(name_f, val_f)?
};
vars.insert(var, v);
}
let mut vars = HashMap::with_capacity(self.0.next_idx);
for i in 0..self.0.next_idx {
if let Some(s) = self.0.idxs_signals.get(&i) {
//for (_i, s) in self.0.idxs_signals.get() {
let public = self.0.public_idxs.contains(&i);
if uses.get(&i).is_some() || public {
let name_f = || s.to_string();
let val_f = || {
Ok({
let i_val = self.1.as_ref().expect("missing values").get(s).unwrap();
let ff_val = int_to_ff(i_val.as_pf().into());
debug!("value : {} -> {:?} ({})", s, ff_val, i_val);
ff_val
})
};
debug!("var: {}, public: {}", s, public);
let v = if public {
cs.alloc_input(name_f, val_f)?
} else {
cs.alloc(name_f, val_f)?
};
vars.insert(i, v);
} else {
debug!("drop dead var: {}", s);
}
}
}
for (i, (a, b, c)) in self.0.constraints.iter().enumerate() {
for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() {
cs.enforce(
|| format!("con{i}"),
|z| lc_to_bellman::<F, CS>(&vars, a, z),
@@ -145,7 +140,7 @@ impl<'a, F: PrimeField> Circuit<F> for SynthInput<'a> {
debug!(
"done with synth: {} vars {} cs",
vars.len(),
self.0.constraints.len()
self.0.r1cs.constraints.len()
);
Ok(())
}
@@ -168,12 +163,6 @@ mod serde_pk {
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Serialize)]
pub struct SerPk<'a, E: Engine>(#[serde(with = "self")] pub &'a Parameters<E>);
#[derive(Deserialize)]
pub struct DePk<E: Engine>(#[serde(with = "self")] pub Parameters<E>);
pub fn serialize<S: Serializer, E: Engine>(
p: &Parameters<E>,
ser: S,
@@ -196,12 +185,6 @@ mod serde_vk {
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Serialize)]
pub struct SerVk<'a, E: Engine>(#[serde(with = "self")] pub &'a VerifyingKey<E>);
#[derive(Deserialize)]
pub struct DeVk<E: Engine>(#[serde(with = "self")] pub VerifyingKey<E>);
pub fn serialize<S: Serializer, E: Engine>(
p: &VerifyingKey<E>,
ser: S,
@@ -219,116 +202,80 @@ mod serde_vk {
}
}
/// Given
/// * a proving-key path,
/// * a verifying-key path,
/// * prover data, and
/// * verifier data
/// generate parameters and write them and the data to files at those paths.
pub fn gen_params<E: Engine, P1: AsRef<Path>, P2: AsRef<Path>>(
pk_path: P1,
vk_path: P2,
p_data: &ProverData,
v_data: &VerifierData,
) -> io::Result<()>
mod serde_pf {
use bellman::groth16::Proof;
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer, E: Engine>(p: &Proof<E>, ser: S) -> Result<S::Ok, S::Error> {
let mut bs: Vec<u8> = Vec::new();
p.write(&mut bs).unwrap();
serde_bytes::ByteBuf::from(bs).serialize(ser)
}
pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(de: D) -> Result<Proof<E>, D::Error> {
let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?;
Ok(Proof::read(&**bs).unwrap())
}
}
/// The [::bellman] implementation of Groth16.
pub struct Bellman<E: Engine>(PhantomData<E>);
/// The pk for [Bellman]
#[derive(Serialize, Deserialize)]
pub struct ProvingKey<E: Engine>(
ProverData,
#[serde(with = "serde_pk")] groth16::Parameters<E>,
);
/// The vk for [Bellman]
#[derive(Serialize, Deserialize)]
pub struct VerifyingKey<E: Engine>(
VerifierData,
#[serde(with = "serde_vk")] groth16::VerifyingKey<E>,
);
/// The proof for [Bellman]
#[derive(Serialize, Deserialize)]
pub struct Proof<E: Engine>(#[serde(with = "serde_pf")] groth16::Proof<E>);
impl<E: Engine> proof::ProofSystem for Bellman<E>
where
E: MultiMillerLoop,
E::G1: WnafGroup,
E::G2: WnafGroup,
{
let rng = &mut rand::thread_rng();
let p = generate_random_parameters::<E, _, _>(SynthInput(&p_data.r1cs, &None), rng).unwrap();
write_prover_key_and_data(pk_path, &p, p_data)?;
write_verifier_key_and_data(vk_path, &p.vk, v_data)?;
Ok(())
}
fn write_prover_key_and_data<P: AsRef<Path>, E: Engine>(
path: P,
params: &Parameters<E>,
data: &ProverData,
) -> io::Result<()> {
let mut file = File::create(path)?;
serialize_into(&mut file, &(serde_pk::SerPk(params), &data)).unwrap();
Ok(())
}
fn read_prover_key_and_data<P: AsRef<Path>, E: Engine>(
path: P,
) -> io::Result<(Parameters<E>, ProverData)> {
let mut file = File::open(path)?;
let (serde_pk::DePk(pk), data): (_, ProverData) = deserialize_from(&mut file).unwrap();
Ok((pk, data))
}
fn write_verifier_key_and_data<P: AsRef<Path>, E: Engine>(
path: P,
key: &VerifyingKey<E>,
data: &VerifierData,
) -> io::Result<()> {
let mut file = File::create(path)?;
serialize_into(&mut file, &(serde_vk::SerVk(key), &data)).unwrap();
Ok(())
}
fn read_verifier_key_and_data<P: AsRef<Path>, E: Engine>(
path: P,
) -> io::Result<(VerifyingKey<E>, VerifierData)> {
let mut file = File::open(path)?;
let (serde_vk::DeVk(vk), data): (_, VerifierData) = deserialize_from(&mut file).unwrap();
Ok((vk, data))
}
/// Given
/// * a proving-key path,
/// * a proof path, and
/// * a prover input map
/// generate a random proof and writes it to the path
pub fn prove<E: Engine, P1: AsRef<Path>, P2: AsRef<Path>>(
pk_path: P1,
pf_path: P2,
inputs_map: &FxHashMap<String, Value>,
) -> io::Result<()>
where
E::Fr: PrimeFieldBits,
{
let (pk, prover_data) = read_prover_key_and_data::<_, E>(pk_path)?;
let rng = &mut rand::thread_rng();
for (input, sort) in &prover_data.precompute_inputs {
let value = inputs_map
.get(input)
.unwrap_or_else(|| panic!("No input for {}", input));
let sort2 = value.sort();
assert_eq!(
sort, &sort2,
"Sort mismatch for {input}. Expected\n\t{sort} but got\n\t{sort2}",
);
}
let new_map = prover_data.precompute.eval(inputs_map);
prover_data.r1cs.check_all(&new_map);
let pf = create_random_proof(SynthInput(&prover_data.r1cs, &Some(new_map)), &pk, rng).unwrap();
let mut pf_file = File::create(pf_path)?;
pf.write(&mut pf_file)?;
Ok(())
}
type VerifyingKey = VerifyingKey<E>;
/// Given
/// * a verifying-key path,
/// * a proof path,
/// * and a verifier input map
/// checks the proof at that path
pub fn verify<E: MultiMillerLoop, P1: AsRef<Path>, P2: AsRef<Path>>(
vk_path: P1,
pf_path: P2,
inputs_map: &FxHashMap<String, Value>,
) -> io::Result<()> {
let (vk, verifier_data) = read_verifier_key_and_data::<_, E>(vk_path)?;
let pvk = prepare_verifying_key(&vk);
let inputs = verifier_data.eval(inputs_map);
let inputs_as_ff: Vec<E::Fr> = inputs.into_iter().map(int_to_ff).collect();
let mut pf_file = File::open(pf_path).unwrap();
let pf = Proof::read(&mut pf_file).unwrap();
verify_proof(&pvk, &pf, &inputs_as_ff).unwrap();
Ok(())
type ProvingKey = ProvingKey<E>;
type Proof = Proof<E>;
fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey) {
let rng = &mut rand::thread_rng();
let params =
groth16::generate_random_parameters::<E, _, _>(SynthInput(&p_data, None), rng).unwrap();
let v_params = params.vk.clone();
(ProvingKey(p_data, params), VerifyingKey(v_data, v_params))
}
fn prove(pk: &Self::ProvingKey, witness: &FxHashMap<String, Value>) -> Self::Proof {
let rng = &mut rand::thread_rng();
pk.0.check_all(witness);
Proof(groth16::create_random_proof(SynthInput(&pk.0, Some(witness)), &pk.1, rng).unwrap())
}
fn verify(vk: &Self::VerifyingKey, inst: &FxHashMap<String, Value>, pf: &Self::Proof) -> bool {
let pvk = groth16::prepare_verifying_key(&vk.1);
let r1cs_inst_map = vk.0.eval(inst);
let r1cs_inst: Vec<E::Fr> = r1cs_inst_map
.into_iter()
.map(|i| int_to_ff(i.i()))
.collect();
groth16::verify_proof(&pvk, &pf.0, &r1cs_inst).is_ok()
}
}
#[cfg(test)]

356
src/target/r1cs/mirage.rs Normal file
View File

@@ -0,0 +1,356 @@
//! Exporting our R1CS to bellman
use ::bellman::{
cc::{CcCircuit, CcConstraintSystem},
mirage, SynthesisError,
};
use ff::{PrimeField, PrimeFieldBits};
use fxhash::FxHashMap;
use group::WnafGroup;
use log::debug;
use pairing::{Engine, MultiMillerLoop};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::marker::PhantomData;
use std::path::Path;
use std::str::FromStr;
use rug::Integer;
use super::proof;
use super::{wit_comp::StagedWitCompEvaluator, ProverData, VarType, VerifierData};
use crate::ir::term::Value;
use super::bellman::{get_modulus, int_to_ff, lc_to_bellman};
fn ff_to_int<F: PrimeFieldBits>(f: F) -> Integer {
let mut buffer = vec![];
use std::io::Read;
f.to_le_bits()
.as_bitslice()
.read_to_end(&mut buffer)
.unwrap();
Integer::from_digits(&buffer, rug::integer::Order::Lsf)
}
/// A synthesizable bellman circuit.
///
/// Optionally contains a variable value map. This must be populated to use the
/// bellman prover.
pub struct SynthInput<'a>(&'a ProverData, Option<&'a FxHashMap<String, Value>>);
impl<'a, F: PrimeField + PrimeFieldBits> CcCircuit<F> for SynthInput<'a> {
#[track_caller]
fn synthesize<CS>(self, cs: &mut CS) -> std::result::Result<(), SynthesisError>
where
CS: CcConstraintSystem<F>,
{
let f_mod = get_modulus::<F>();
assert_eq!(
self.0.r1cs.field.modulus(),
&f_mod,
"\nR1CS has modulus \n{},\n but mirage CS expects \n{}",
self.0.r1cs.field,
f_mod
);
let mut vars = HashMap::with_capacity(self.0.r1cs.vars.len());
// (assignment values, evaluator, next evaluator inputs)
let mut wit_comp: Option<(
Vec<Value>,
StagedWitCompEvaluator<'a>,
FxHashMap<String, Value>,
)> = self.1.map(|inputs| {
(
Vec::new(),
StagedWitCompEvaluator::new(&self.0.precompute),
inputs.clone(),
)
});
let mut var_idx = 0;
let num_stages = self.0.precompute.stage_sizes().count();
for (i, num_vars) in self.0.precompute.stage_sizes().enumerate() {
if let Some((ref mut var_values, ref mut evaluator, ref mut inputs)) = wit_comp.as_mut()
{
var_values.extend(
evaluator
.eval_stage(std::mem::take(inputs))
.into_iter()
.cloned(),
);
}
let num_challs = if i + 1 < num_stages {
self.0.precompute.num_stage_inputs(i + 1)
} else {
0
};
for j in 0..(num_vars + num_challs) {
let var = self.0.r1cs.vars[var_idx];
let name_f = || format!("{var:?}");
let val_f = || {
Ok({
let i_val = &wit_comp.as_ref().expect("missing values").0[var_idx];
let ff_val = int_to_ff(i_val.as_pf().into());
debug!("value : {var:?} -> {ff_val:?} ({i_val})");
ff_val
})
};
let v = match var.ty() {
VarType::Inst => cs.alloc_input(name_f, val_f)?,
VarType::RoundWit => cs.alloc(name_f, val_f)?,
VarType::FinalWit => cs.alloc(name_f, val_f)?,
VarType::Chall => {
let (v, val) = cs.alloc_random(name_f)?;
if let Some((ref mut values, _, ref mut inputs)) = wit_comp.as_mut() {
let val =
Value::Field(self.0.r1cs.field.new_v(ff_to_int(val.unwrap())));
values.push(val.clone());
let name = self.0.r1cs.names.get(&var).unwrap();
inputs.insert(name.to_owned(), val);
}
v
}
VarType::CWit => unimplemented!(),
};
vars.insert(var, v);
var_idx += 1;
if j + 1 == num_vars && num_challs > 0 {
cs.end_aux_block(|| format!("block {}", i - 1))?;
}
}
}
for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() {
cs.enforce(
|| format!("con{i}"),
|z| lc_to_bellman::<F, CS>(&vars, a, z),
|z| lc_to_bellman::<F, CS>(&vars, b, z),
|z| lc_to_bellman::<F, CS>(&vars, c, z),
);
}
debug!(
"done with synth: {} vars {} cs",
vars.len(),
self.0.r1cs.constraints.len()
);
Ok(())
}
fn num_aux_blocks(&self) -> usize {
self.0.precompute.stage_sizes().count() - 2
}
}
/// Convert a (rug) integer to a prime field element.
pub fn parse_instance<P: AsRef<Path>, F: PrimeField>(path: P) -> Vec<F> {
let f = BufReader::new(File::open(path).unwrap());
f.lines()
.map(|line| {
let s = line.unwrap();
let i = Integer::from_str(s.trim()).unwrap();
int_to_ff(i)
})
.collect()
}
mod serde_pk {
use bellman::mirage::Parameters;
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer, E: Engine>(
p: &Parameters<E>,
ser: S,
) -> Result<S::Ok, S::Error> {
let mut bs: Vec<u8> = Vec::new();
p.write(&mut bs).unwrap();
serde_bytes::ByteBuf::from(bs).serialize(ser)
}
pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(
de: D,
) -> Result<Parameters<E>, D::Error> {
let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?;
Ok(Parameters::read(&**bs, false).unwrap())
}
}
mod serde_vk {
use bellman::mirage::VerifyingKey;
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer, E: Engine>(
p: &VerifyingKey<E>,
ser: S,
) -> Result<S::Ok, S::Error> {
let mut bs: Vec<u8> = Vec::new();
p.write(&mut bs).unwrap();
serde_bytes::ByteBuf::from(bs).serialize(ser)
}
pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(
de: D,
) -> Result<VerifyingKey<E>, D::Error> {
let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?;
Ok(VerifyingKey::read(&**bs).unwrap())
}
}
mod serde_pf {
use bellman::mirage::Proof;
use pairing::Engine;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer, E: Engine>(p: &Proof<E>, ser: S) -> Result<S::Ok, S::Error> {
let mut bs: Vec<u8> = Vec::new();
p.write(&mut bs).unwrap();
serde_bytes::ByteBuf::from(bs).serialize(ser)
}
pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(de: D) -> Result<Proof<E>, D::Error> {
let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?;
Ok(Proof::read(&**bs).unwrap())
}
}
/// The [::bellman] implementation of Groth16.
pub struct Mirage<E: Engine>(PhantomData<E>);
/// The pk for [mirage]
#[derive(Serialize, Deserialize)]
pub struct ProvingKey<E: Engine>(
ProverData,
#[serde(with = "serde_pk")] mirage::Parameters<E>,
);
/// The vk for [mirage]
#[derive(Serialize, Deserialize)]
pub struct VerifyingKey<E: Engine>(
VerifierData,
#[serde(with = "serde_vk")] mirage::VerifyingKey<E>,
);
/// The proof for [mirage]
#[derive(Serialize, Deserialize)]
pub struct Proof<E: Engine>(#[serde(with = "serde_pf")] mirage::Proof<E>);
impl<E: Engine> proof::ProofSystem for Mirage<E>
where
E: MultiMillerLoop,
E::G1: WnafGroup,
E::G2: WnafGroup,
E::Fr: PrimeFieldBits,
{
type VerifyingKey = VerifyingKey<E>;
type ProvingKey = ProvingKey<E>;
type Proof = Proof<E>;
fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey) {
let rng = &mut rand::thread_rng();
let params =
mirage::generate_random_parameters::<E, _, _>(SynthInput(&p_data, None), rng).unwrap();
let v_params = params.vk.clone();
(ProvingKey(p_data, params), VerifyingKey(v_data, v_params))
}
fn prove(pk: &Self::ProvingKey, witness: &FxHashMap<String, Value>) -> Self::Proof {
let rng = &mut rand::thread_rng();
pk.0.check_all(witness);
Proof(mirage::create_random_proof(SynthInput(&pk.0, Some(witness)), &pk.1, rng).unwrap())
}
fn verify(vk: &Self::VerifyingKey, inst: &FxHashMap<String, Value>, pf: &Self::Proof) -> bool {
let pvk = mirage::prepare_verifying_key(&vk.1);
let r1cs_inst_map = vk.0.eval(inst);
let r1cs_inst: Vec<E::Fr> = r1cs_inst_map
.into_iter()
.map(|i| int_to_ff(i.i()))
.collect();
mirage::verify_proof(&pvk, &pf.0, &r1cs_inst).is_ok()
}
}
#[cfg(test)]
mod test {
use super::*;
use bls12_381::Scalar;
use quickcheck::{Arbitrary, Gen};
use quickcheck_macros::quickcheck;
use std::io::Write;
#[derive(Clone, Debug)]
struct BlsScalar(Integer);
impl Arbitrary for BlsScalar {
fn arbitrary(g: &mut Gen) -> Self {
let mut rug_rng = rug::rand::RandState::new_mersenne_twister();
rug_rng.seed(&Integer::from(u32::arbitrary(g)));
let modulus = Integer::from(
Integer::parse_radix(
"73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001",
16,
)
.unwrap(),
);
let i = Integer::from(modulus.random_below_ref(&mut rug_rng));
BlsScalar(i)
}
}
#[quickcheck]
fn int_to_ff_random(BlsScalar(i): BlsScalar) -> bool {
let by_fn = int_to_ff::<Scalar>(i.clone());
let by_str = Scalar::from_str_vartime(&format!("{i}")).unwrap();
by_fn == by_str
}
#[quickcheck]
fn roundtrip_random(BlsScalar(i): BlsScalar) -> bool {
let ff = int_to_ff::<Scalar>(i.clone());
let i2 = ff_to_int(ff);
i == i2
}
fn convert(i: Integer) {
let by_fn = int_to_ff::<Scalar>(i.clone());
let by_str = Scalar::from_str_vartime(&format!("{i}")).unwrap();
assert_eq!(by_fn, by_str);
}
#[test]
fn neg_one() {
let modulus = Integer::from(
Integer::parse_radix(
"73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001",
16,
)
.unwrap(),
);
convert(modulus - 1);
}
#[test]
fn zero() {
convert(Integer::from(0));
}
#[test]
fn one() {
convert(Integer::from(1));
}
#[test]
fn parse() {
let path = format!("{}/instance", std::env::temp_dir().to_str().unwrap());
{
let mut f = File::create(&path).unwrap();
write!(f, "5\n6").unwrap();
}
let i = parse_instance::<_, Scalar>(&path);
assert_eq!(i[0], Scalar::from(5));
assert_eq!(i[1], Scalar::from(6));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,21 +1,24 @@
//! Optimizations over R1CS
use super::*;
use crate::cfg::CircCfg;
use crate::util::once::OnceQueue;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use log::debug;
struct LinReducer<S: Eq + Hash> {
r1cs: R1cs<S>,
uses: HashMap<usize, HashSet<usize>>,
use std::collections::hash_map::Entry;
use super::*;
use crate::cfg::CircCfg;
use crate::util::once::OnceQueue;
struct LinReducer {
r1cs: R1cs,
uses: HashMap<Var, HashSet<usize>>,
queue: OnceQueue<usize>,
/// The maximum size LC (number of non-constant monomials)
/// that will be used for propagation
lc_size_thresh: usize,
}
impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
fn new(mut r1cs: R1cs<S>, lc_size_thresh: usize) -> Self {
impl LinReducer {
fn new(mut r1cs: R1cs, lc_size_thresh: usize) -> Self {
let uses = LinReducer::gen_uses(&r1cs);
let queue = (0..r1cs.constraints.len()).collect::<OnceQueue<usize>>();
for c in &mut r1cs.constraints {
@@ -30,9 +33,9 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
}
// generate a new uses hash
fn gen_uses(r1cs: &R1cs<S>) -> HashMap<usize, HashSet<usize>> {
let mut uses: HashMap<usize, HashSet<usize>> =
HashMap::with_capacity_and_hasher(r1cs.next_idx, Default::default());
fn gen_uses(r1cs: &R1cs) -> HashMap<Var, HashSet<usize>> {
let mut uses: HashMap<Var, HashSet<usize>> =
HashMap::with_capacity_and_hasher(r1cs.num_vars(), Default::default());
let mut add = |i: usize, y: &Lc| {
for x in y.monomials.keys() {
uses.get_mut(x).map(|m| m.insert(i)).or_else(|| {
@@ -54,7 +57,7 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
/// Substitute `val` for `var` in constraint with id `con_id`.
/// Updates uses conservatively (not precisely)
/// Returns whether a sub happened.
fn sub_in(&mut self, var: usize, val: &Lc, con_id: usize) -> bool {
fn sub_in(&mut self, var: Var, val: &Lc, con_id: usize) -> bool {
let (a, b, c) = &mut self.r1cs.constraints[con_id];
let uses = &mut self.uses;
let mut do_in = |a: &mut Lc| {
@@ -112,15 +115,13 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
self.r1cs.constraints[i].2.clear();
}
fn run(mut self) -> R1cs<S> {
fn run(mut self) -> R1cs {
while let Some(con_id) = self.queue.pop() {
if let Some((var, lc)) =
as_linear_sub(&self.r1cs.constraints[con_id], &self.r1cs.public_idxs)
{
if let Some((var, lc)) = as_linear_sub(&self.r1cs.constraints[con_id], &self.r1cs) {
if lc.monomials.len() < self.lc_size_thresh {
debug!(
"Elim: {} -> {}",
self.r1cs.idxs_signals.get(&var).unwrap(),
self.r1cs.idx_to_sig.get_fwd(&var).unwrap(),
self.r1cs.format_lc(&lc)
);
self.clear_constraint(con_id);
@@ -141,10 +142,10 @@ impl<S: Eq + Hash + Display + Clone> LinReducer<S> {
}
}
fn as_linear_sub((a, b, c): &(Lc, Lc, Lc), public: &HashSet<usize>) -> Option<(usize, Lc)> {
fn as_linear_sub((a, b, c): &(Lc, Lc, Lc), r1cs: &R1cs) -> Option<(Var, Lc)> {
if a.is_zero() || b.is_zero() {
for i in c.monomials.keys() {
if !public.contains(i) {
if r1cs.can_eliminate(*i) {
let mut lc = c.clone();
let v = lc.monomials.remove(i).unwrap();
lc *= v.recip();
@@ -184,7 +185,7 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool {
///
/// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used
/// for propagation. `None` means no size limit.
pub fn reduce_linearities<S: Eq + Hash + Clone + Display>(r1cs: R1cs<S>, cfg: &CircCfg) -> R1cs<S> {
pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs {
LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run()
}
@@ -199,7 +200,7 @@ mod test {
use rand::SeedableRng;
#[derive(Clone, Debug)]
pub struct SatR1cs(R1cs<String>, FxHashMap<String, Value>);
pub struct SatR1cs(R1cs, FxHashMap<String, Value>);
impl Arbitrary for SatR1cs {
fn arbitrary(g: &mut Gen) -> Self {
@@ -208,14 +209,18 @@ mod test {
let n_vars = g.size() + 1;
let vars: Vec<_> = (0..n_vars).map(|i| format!("v{i}")).collect();
let mut values: FxHashMap<String, Value> = Default::default();
let mut r1cs = R1cs::new(field.clone());
let mut var_values: FxHashMap<Var, FieldV> = Default::default();
let mut r1cs = R1cs::new(field.clone(), Default::default());
let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g));
for v in &vars {
values.insert(v.clone(), Value::Field(field.random_v(&mut rng)));
r1cs.add_signal(
let var = r1cs.add_var(
v.clone(),
leaf_term(Op::Var(v.clone(), Sort::Field(field.clone()))),
VarType::FinalWit,
);
let val = field.random_v(&mut rng);
var_values.insert(var, val.clone());
values.insert(v.into(), Value::Field(val));
}
for _ in 0..(2 * g.size()) {
let ac: isize = <isize as Arbitrary>::arbitrary(g) % m;
@@ -236,7 +241,8 @@ mod test {
} else {
r1cs.zero()
} + cc;
let off = r1cs.eval(&a, &values) * r1cs.eval(&b, &values) - r1cs.eval(&c, &values);
let off = r1cs.eval(&a, &var_values) * r1cs.eval(&b, &var_values)
- r1cs.eval(&c, &var_values);
c += &off;
r1cs.constraint(a, b, c);
}

331
src/target/r1cs/proof.rs Normal file
View File

@@ -0,0 +1,331 @@
//! A trait for CirC-compatible proofs
use std::fs::File;
use std::path::Path;
use bincode::{deserialize_from, serialize_into};
use fxhash::FxHashMap as HashMap;
use serde::{Deserialize, Serialize};
use super::{ProverData, VerifierData};
use crate::ir::term::text::parse_value_map;
use crate::ir::term::Value;
/// A trait for CirC-compatible proofs
pub trait ProofSystem {
/// A verifying key
type VerifyingKey: Serialize + for<'a> Deserialize<'a>;
/// A proving key
type ProvingKey: Serialize + for<'a> Deserialize<'a>;
/// A proof
type Proof: Serialize + for<'a> Deserialize<'a>;
/// Setup
fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey);
/// Proving
fn prove(pk: &Self::ProvingKey, witness: &HashMap<String, Value>) -> Self::Proof;
/// Verification
fn verify(vk: &Self::VerifyingKey, inst: &HashMap<String, Value>, pf: &Self::Proof) -> bool;
/// Setup to files
fn setup_fs<P1, P2>(
p_data: ProverData,
v_data: VerifierData,
pk_path: P1,
vk_path: P2,
) -> std::io::Result<()>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
{
let (pk, vk) = Self::setup(p_data, v_data);
let mut pk_file = File::create(pk_path)?;
let mut vk_file = File::create(vk_path)?;
serialize_into(&mut pk_file, &pk).unwrap();
serialize_into(&mut vk_file, &vk).unwrap();
Ok(())
}
/// Prove to/from files
fn prove_fs<P1, P2>(pk_path: P1, witness_path: P2, pf_path: P2) -> std::io::Result<()>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
{
let pk_file = File::open(pk_path)?;
let mut pf_file = File::create(pf_path)?;
let witness_bytes = std::fs::read(witness_path)?;
let witness = parse_value_map(&witness_bytes);
let pk: Self::ProvingKey = deserialize_from(pk_file).unwrap();
let pf = Self::prove(&pk, &witness);
serialize_into(&mut pf_file, &pf).unwrap();
Ok(())
}
/// Verify from files
fn verify_fs<P1, P2>(vk_path: P1, instance_path: P2, pf_path: P2) -> std::io::Result<bool>
where
P1: AsRef<Path>,
P2: AsRef<Path>,
{
let vk_file = File::open(vk_path)?;
let pf_file = File::open(pf_path)?;
let instance_bytes = std::fs::read(instance_path)?;
let instance = parse_value_map(&instance_bytes);
let vk: Self::VerifyingKey = deserialize_from(vk_file).unwrap();
let pf: Self::Proof = deserialize_from(pf_file).unwrap();
Ok(Self::verify(&vk, &instance, &pf))
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::cfg::CircCfg;
use crate::ir::term::*;
use crate::target::r1cs;
#[allow(dead_code)]
fn test_setup_prove_verify<PS: ProofSystem>(
cs: Computation,
p_input: HashMap<String, Value>,
v_input: HashMap<String, Value>,
) {
let cfg = CircCfg::default();
let r1cs = r1cs::trans::to_r1cs(&cs, &cfg);
let (p_data, v_data) = r1cs.finalize(&cs);
let (pk, vk) = PS::setup(p_data, v_data);
let pf = PS::prove(&pk, &p_input);
assert!(PS::verify(&vk, &v_input, &pf));
}
#[cfg(feature = "bellman")]
mod mirage {
use super::super::super::mirage::Mirage;
use super::*;
#[test]
fn bool_np() {
let c = text::parse_computation(
b"
(computation
(metadata
(parties P)
(inputs (a bool (party 0)) (b bool (party 0)) (return bool))
)
(precompute
((a bool) (b bool))
((return bool))
(tuple (and a b))
)
(= (and a b) return)
)",
);
let p_input = text::parse_value_map(
b"
(let (
(a true)
(b true)
) false; ignored
)",
);
let v_input = text::parse_value_map(
b"
(let (
(return true)
) false; ignored
)",
);
test_setup_prove_verify::<Mirage<bls12_381::Bls12>>(c, p_input, v_input);
}
#[test]
fn rand_perm() {
let c = text::parse_computation(
b"
(computation
(metadata
(parties P)
(inputs
(a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
)
)
(precompute () () (tuple))
(=
(* (+ a0 c) (+ a1 c) (+ a2 c))
(* (+ b0 c) (+ b1 c) (+ b2 c))
)
)",
);
let p_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(a0 #f1)
(a1 #f-1)
(a2 #f4)
(b0 #f-1)
(b1 #f1)
(b2 #f4)
) false))");
let v_input = text::parse_value_map(
b"
(let (
) false; ignored
)",
);
test_setup_prove_verify::<Mirage<bls12_381::Bls12>>(c, p_input, v_input);
}
#[test]
fn rand_double_perm() {
let c = text::parse_computation(
b"
(computation
(metadata
(parties P)
(inputs
(a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
(d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
)
)
(precompute () () (tuple))
(and
(=
(* (+ a0 c) (+ a1 c) (+ a2 c))
(* (+ b0 c) (+ b1 c) (+ b2 c))
)
(=
(* (+ a0 d) (+ a1 d) (+ a2 d))
(* (+ b0 d) (+ b1 d) (+ b2 d))
)
)
)",
);
let p_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(a0 #f1)
(a1 #f-1)
(a2 #f4)
(b0 #f-1)
(b1 #f1)
(b2 #f4)
) false))");
let v_input = text::parse_value_map(
b"
(let (
) false; ignored
)",
);
test_setup_prove_verify::<Mirage<bls12_381::Bls12>>(c, p_input, v_input);
}
#[test]
fn rand_double_perm_inst() {
let c = text::parse_computation(
b"
(computation
(metadata
(parties P)
(inputs
(a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513))
(a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513))
(a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513))
(b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
(d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
)
)
(precompute () () (tuple))
(and
(=
(* (+ a0 c) (+ a1 c) (+ a2 c))
(* (+ b0 c) (+ b1 c) (+ b2 c))
)
(=
(* (+ a0 d) (+ a1 d) (+ a2 d))
(* (+ b0 d) (+ b1 d) (+ b2 d))
)
)
)",
);
let p_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(a0 #f1)
(a1 #f-1)
(a2 #f4)
(b0 #f-1)
(b1 #f1)
(b2 #f4)
) false))");
let v_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(a0 #f1)
(a1 #f-1)
(a2 #f4)
) false; ignored
))",
);
test_setup_prove_verify::<Mirage<bls12_381::Bls12>>(c, p_input, v_input);
}
#[test]
fn precomp_with_chall() {
let c = text::parse_computation(
b"
(computation
(metadata
(parties P)
(inputs
(a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0))
(ha (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0) (round 1))
(d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random))
)
)
(precompute (
(a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513))
) (
(ha (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513))
) (tuple
(* a0 d)
))
(=
ha
(* a0 d)
)
)",
);
let p_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(a0 #f1)
) false))");
let v_input = text::parse_value_map(
b"
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
) false; ignored
))",
);
test_setup_prove_verify::<Mirage<bls12_381::Bls12>>(c, p_input, v_input);
}
}
}

View File

@@ -1,4 +1,5 @@
//! Export circ R1cs to Spartan
use crate::target::r1cs::wit_comp::StagedWitCompEvaluator;
use crate::target::r1cs::*;
use bincode::{deserialize_from, serialize_into};
use curve25519_dalek::scalar::Scalar;
@@ -55,7 +56,7 @@ pub fn verify<P: AsRef<Path>>(
let mut inp = Vec::new();
for v in &values {
let scalar = int_to_scalar(v);
let scalar = int_to_scalar(&v.i());
inp.push(scalar.to_bytes());
}
let inputs = InputsAssignment::new(&inp).unwrap();
@@ -78,11 +79,11 @@ pub fn r1cs_to_spartan(
// spartan format mapper: CirC -> Spartan
let mut wit = Vec::new();
let mut inp = Vec::new();
let mut trans: HashMap<usize, usize> = HashMap::default(); // Circ -> spartan ids
let mut itrans: HashMap<usize, usize> = HashMap::default(); // Circ -> spartan ids
let mut trans: HashMap<Var, usize> = HashMap::default(); // Circ -> spartan ids
let mut itrans: HashMap<usize, Var> = HashMap::default(); // spartan ids -> Circ
// check modulus
let f_mod = prover_data.r1cs.modulus.modulus();
let f_mod = prover_data.r1cs.field.modulus();
let s_mod = Integer::from_str_radix(
"7237005577332262213973186563042994240857116359379907606001950938285454250989",
10,
@@ -95,37 +96,22 @@ pub fn r1cs_to_spartan(
let values = eval_inputs(inputs_map, prover_data);
let pf_input_order: Vec<usize> = (0..prover_data.r1cs.next_idx)
.filter(|i| prover_data.r1cs.public_idxs.contains(i))
.collect();
assert_eq!(values.len(), prover_data.r1cs.vars.len());
for idx in &pf_input_order {
let sig = prover_data.r1cs.idxs_signals.get(idx).cloned().unwrap();
let scalar = match values.get(&sig.to_string()) {
Some(v) => val_to_scalar(v),
None => panic!("Input/witness variable does not have matching evaluation"),
};
for (var, val) in prover_data.r1cs.vars.iter().zip(&values) {
let scalar = val_to_scalar(val);
// input
itrans.insert(*idx, inp.len());
inp.push(scalar.to_bytes());
}
for (sig, idx) in &prover_data.r1cs.signal_idxs {
let scalar = match values.get(&sig.to_string()) {
Some(v) => val_to_scalar(v),
None => panic!("Input/witness variable does not have matching evaluation"),
};
if !prover_data.r1cs.public_idxs.contains(idx) {
// witness
trans.insert(*idx, wit.len());
itrans.insert(inp.len(), *var);
trans.insert(*var, inp.len());
if let VarType::Inst = var.ty() {
inp.push(scalar.to_bytes());
} else {
wit.push(scalar.to_bytes());
}
}
assert_eq!(wit.len() + inp.len(), prover_data.r1cs.next_idx);
assert_eq!(wit.len() + inp.len(), prover_data.r1cs.vars.len());
let num_vars = wit.len();
let const_id = wit.len();
@@ -135,10 +121,6 @@ pub fn r1cs_to_spartan(
let num_inputs = inp.len();
let assn_inputs = InputsAssignment::new(&inp).unwrap();
for (cid, sid) in itrans {
trans.insert(cid, sid + const_id + 1);
}
// circuit
let mut m_a: Vec<(usize, usize, [u8; 32])> = Vec::new();
let mut m_b: Vec<(usize, usize, [u8; 32])> = Vec::new();
@@ -183,24 +165,22 @@ pub fn r1cs_to_spartan(
)
}
fn eval_inputs(
inputs_map: &HashMap<String, Value>,
prover_data: &ProverData,
) -> HashMap<String, Value> {
for (input, sort) in &prover_data.precompute_inputs {
let value = inputs_map
.get(input)
.unwrap_or_else(|| panic!("No input for {}", input));
let sort2 = value.sort();
assert_eq!(
sort, &sort2,
"Sort mismatch for {input}. Expected\n\t{sort} but got\n\t{sort2}",
);
}
let new_map = prover_data.precompute.eval(inputs_map);
prover_data.r1cs.check_all(&new_map);
new_map
fn eval_inputs(inputs_map: &HashMap<String, Value>, prover_data: &ProverData) -> Vec<Value> {
let mut evaluator = StagedWitCompEvaluator::new(&prover_data.precompute);
let mut ffs = Vec::new();
ffs.extend(
evaluator
.eval_stage(inputs_map.clone())
.into_iter()
.cloned(),
);
ffs.extend(
evaluator
.eval_stage(Default::default())
.into_iter()
.cloned(),
);
ffs
}
fn val_to_scalar(v: &Value) -> Scalar {
@@ -228,7 +208,7 @@ fn int_to_scalar(i: &Integer) -> Scalar {
}
// circ Lc (const, monomials <Integer>) -> Vec<Variable>
fn lc_to_v(lc: &Lc, const_id: usize, trans: &HashMap<usize, usize>) -> Vec<Variable> {
fn lc_to_v(lc: &Lc, const_id: usize, trans: &HashMap<Var, usize>) -> Vec<Variable> {
let mut v: Vec<Variable> = Vec::new();
for (k, m) in &lc.monomials {

View File

@@ -4,14 +4,12 @@
//! thesis](https://github.com/circify/circ/tree/master/doc/resources/braun-bs-thesis.pdf)
//! is a good intro to how this process works.
use crate::cfg::CircCfg;
use crate::ir::term::precomp::PreComp;
use crate::ir::term::*;
use crate::target::bitsize;
use crate::target::r1cs::*;
use circ_fields::FieldT;
use circ_opt::FieldDivByZero;
use fxhash::FxHashSet;
use log::debug;
use rug::ops::Pow;
use rug::Integer;
@@ -39,29 +37,27 @@ enum EmbeddedTerm {
}
struct ToR1cs<'cfg> {
r1cs: R1cs<String>,
r1cs: R1cs,
cache: TermMap<EmbeddedTerm>,
wit_ext: PreComp,
public_inputs: FxHashSet<String>,
next_idx: usize,
zero: TermLc,
one: TermLc,
cfg: &'cfg CircCfg,
field: FieldT,
used_vars: HashSet<String>,
}
impl<'cfg> ToR1cs<'cfg> {
fn new(cfg: &'cfg CircCfg, public_inputs: FxHashSet<String>) -> Self {
fn new(cfg: &'cfg CircCfg, precompute: precomp::PreComp, used_vars: HashSet<String>) -> Self {
let field = cfg.field().clone();
debug!("Starting R1CS back-end, field: {}", field);
let r1cs = R1cs::new(field.clone());
let r1cs = R1cs::new(field.clone(), precompute);
let zero = TermLc(pf_lit(field.new_v(0u8)), r1cs.zero());
let one = zero.clone() + 1;
Self {
r1cs,
cache: TermMap::default(),
wit_ext: precomp::PreComp::new(),
public_inputs,
used_vars,
next_idx: 0,
zero,
one,
@@ -74,19 +70,24 @@ impl<'cfg> ToR1cs<'cfg> {
/// If values are being recorded, `value` must be provided.
///
/// `comp` is a term that computes the value.
fn fresh_var<D: Display + ?Sized>(&mut self, ctx: &D, comp: Term, public: bool) -> TermLc {
let n = format!("{}_n{}", ctx, self.next_idx);
fn fresh_var<D: Display + ?Sized>(&mut self, ctx: &D, comp: Term, ty: VarType) -> TermLc {
let n = if matches!(ty, VarType::Chall) {
format!("{ctx}")
} else {
format!("{ctx}_n{}", self.next_idx)
};
self.next_idx += 1;
debug_assert!(matches!(check(&comp), Sort::Field(_)));
self.r1cs.add_signal(n.clone(), comp.clone());
self.wit_ext.add_output(n.clone(), comp.clone());
if public {
self.r1cs.publicize(&n);
}
debug!("fresh: {}", n);
self.r1cs.add_var(n.clone(), comp.clone(), ty);
debug!("fresh: {n}");
TermLc(comp, self.r1cs.signal_lc(&n))
}
/// Get a new witness. See [ToR1cs::fresh_var].
fn fresh_wit<D: Display + ?Sized>(&mut self, ctx: &D, comp: Term) -> TermLc {
self.fresh_var(ctx, comp, VarType::FinalWit)
}
/// Enforce `x` to be bit-valued
fn enforce_bit(&mut self, b: TermLc) {
self.r1cs
@@ -98,7 +99,7 @@ impl<'cfg> ToR1cs<'cfg> {
fn fresh_bit<D: Display + ?Sized>(&mut self, ctx: &D, comp: Term) -> TermLc {
debug_assert!(matches!(check(&comp), Sort::Bool));
let comp = term![Op::Ite; comp, self.one.0.clone(), self.zero.0.clone()];
let v = self.fresh_var(ctx, comp, false);
let v = self.fresh_var(ctx, comp, VarType::FinalWit);
//debug!("Fresh bit: {}", self.r1cs.format_lc(&v));
self.enforce_bit(v.clone());
v
@@ -110,15 +111,13 @@ impl<'cfg> ToR1cs<'cfg> {
let eqz = term![Op::Eq; x.0.clone(), self.zero.0.clone()];
// m * x - 1 + is_zero == 0
// is_zero * x == 0
let m = self.fresh_var(
let m = self.fresh_wit(
"is_zero_inv",
term![Op::Ite; eqz.clone(), self.zero.0.clone(), term![PF_RECIP; x.0.clone()]],
false,
);
let is_zero = self.fresh_var(
let is_zero = self.fresh_wit(
"is_zero",
term![Op::Ite; eqz, self.one.0.clone(), self.zero.0.clone()],
false,
);
self.r1cs
.constraint(m.1, x.1.clone(), -is_zero.1.clone() + 1);
@@ -208,7 +207,7 @@ impl<'cfg> ToR1cs<'cfg> {
/// Return the product of `a` and `b`.
fn mul(&mut self, a: TermLc, b: TermLc) -> TermLc {
let mul_val = term![PF_MUL; a.0, b.0];
let c = self.fresh_var("mul", mul_val, false);
let c = self.fresh_wit("mul", mul_val);
self.r1cs.constraint(a.1, b.1, c.1.clone());
c
}
@@ -255,6 +254,48 @@ impl<'cfg> ToR1cs<'cfg> {
self.mul(c, t - f) + f
}
/// Embed this variable as
fn embed_var(&mut self, var: &Term, ty: VarType) {
assert!(
!self.cache.contains_key(var),
"already have var {}",
var.op()
);
assert!(!matches!(ty, VarType::CWit), "Unimplemented");
if !self.used_vars.contains(var.as_var_name()) {
return;
}
let public = matches!(ty, VarType::Inst);
match var.op() {
Op::Var(name, Sort::Bool) => {
let comp = term![Op::Ite; var.clone(), self.one.0.clone(), self.zero.0.clone()];
let lc = self.fresh_var(name, comp, ty);
if !public {
self.enforce_bit(lc.clone());
}
self.cache.insert(var.clone(), EmbeddedTerm::Bool(lc));
}
Op::Var(name, Sort::BitVector(n_bits)) => {
let public = matches!(ty, VarType::Inst);
let lc = self.fresh_var(
name,
term![Op::UbvToPf(self.field.clone()); var.clone()],
ty,
);
self.set_bv_uint(var.clone(), lc, *n_bits);
if !public {
self.get_bv_bits(var);
}
}
Op::Var(name, Sort::Field(f)) => {
assert_eq!(f, &self.field);
let lc = self.fresh_var(name, var.clone(), ty);
self.cache.insert(var.clone(), EmbeddedTerm::Field(lc));
}
o => unreachable!("Unhandled variable operator {}", o),
}
}
fn embed(&mut self, t: Term) {
debug!("Embed: {}", t);
for c in PostOrderIter::new(t) {
@@ -352,15 +393,7 @@ impl<'cfg> ToR1cs<'cfg> {
// TODO: skip if already embedded
if !self.cache.contains_key(&c) {
let lc = match &c.op() {
Op::Var(name, Sort::Bool) => {
let public = self.public_inputs.contains(name);
let comp = term![Op::Ite; c.clone(), self.one.0.clone(), self.zero.0.clone()];
let v = self.fresh_var(name, comp, public);
if !public {
self.enforce_bit(v.clone());
}
v
}
Op::Var(..) => panic!("call embed_var instead"),
Op::Const(Value::Bool(b)) => self.zero.clone() + *b as isize,
Op::Eq => self.embed_eq(&c.cs()[0], &c.cs()[1]),
Op::Ite => {
@@ -563,18 +596,7 @@ impl<'cfg> ToR1cs<'cfg> {
if let Sort::BitVector(n) = check(&bv) {
if !self.cache.contains_key(&bv) {
match &bv.op() {
Op::Var(name, Sort::BitVector(_)) => {
let public = self.public_inputs.contains(name);
let var = self.fresh_var(
name,
term![Op::UbvToPf(self.field.clone()); bv.clone()],
public,
);
self.set_bv_uint(bv.clone(), var, n);
if !public {
self.get_bv_bits(&bv);
}
}
Op::Var(..) => panic!("call embed_var instead"),
Op::Const(Value::BitVector(b)) => {
let bit_lcs = (0..b.width())
.map(|i| self.zero.clone() + b.uint().get_bit(i as u32) as isize)
@@ -708,8 +730,8 @@ impl<'cfg> ToR1cs<'cfg> {
let b_bv_term = term![Op::PfToBv(n); b.0.clone()];
let q_term = term![Op::UbvToPf(self.field.clone()); term![BV_UDIV; a_bv_term.clone(), b_bv_term.clone()]];
let r_term = term![Op::UbvToPf(self.field.clone()); term![BV_UREM; a_bv_term, b_bv_term]];
let q = self.fresh_var("div_q", q_term, false);
let r = self.fresh_var("div_r", r_term, false);
let q = self.fresh_wit("div_q", q_term);
let r = self.fresh_wit("div_r", r_term);
let qb = self.bitify("div_q", &q, n, false);
let rb = self.bitify("div_r", &r, n, false);
self.r1cs.constraint(q.1.clone(), b.1.clone(), (a - &r).1);
@@ -880,10 +902,7 @@ impl<'cfg> ToR1cs<'cfg> {
// TODO: skip if already embedded
if !self.cache.contains_key(&c) {
let lc = match &c.op() {
Op::Var(name, Sort::Field(_)) => {
let public = self.public_inputs.contains(name);
self.fresh_var(name, c.clone(), public)
}
Op::Var(..) => panic!("call embed_var instead"),
Op::Const(Value::Field(r)) => TermLc(
c.clone(),
self.r1cs.constant(r.as_ty_ref(&self.r1cs.modulus)),
@@ -914,7 +933,7 @@ impl<'cfg> ToR1cs<'cfg> {
FieldDivByZero::Incomplete => {
// ix = 1
let x = self.get_pf(&c.cs()[0]).clone();
let inv_x = self.fresh_var("recip", term![PF_RECIP; x.0.clone()], false);
let inv_x = self.fresh_wit("recip", term![PF_RECIP; x.0.clone()]);
self.r1cs
.constraint(x.1, inv_x.1.clone(), self.r1cs.zero() + 1);
inv_x
@@ -923,7 +942,7 @@ impl<'cfg> ToR1cs<'cfg> {
// ixx = x
let x = self.get_pf(&c.cs()[0]).clone();
let x2 = self.mul(x.clone(), x.clone());
let inv_x = self.fresh_var("recip", term![PF_RECIP; x.0.clone()], false);
let inv_x = self.fresh_wit("recip", term![PF_RECIP; x.0.clone()]);
self.r1cs.constraint(x2.1, inv_x.1.clone(), x.1);
inv_x
}
@@ -933,15 +952,13 @@ impl<'cfg> ToR1cs<'cfg> {
// zi = 0
let x = self.get_pf(&c.cs()[0]).clone();
let eqz = term![Op::Eq; x.0.clone(), self.zero.0.clone()];
let i = self.fresh_var(
let i = self.fresh_wit(
"is_zero_inv",
term![Op::Ite; eqz.clone(), self.zero.0.clone(), term![PF_RECIP; x.0.clone()]],
false,
);
let z = self.fresh_var(
let z = self.fresh_wit(
"is_zero",
term![Op::Ite; eqz, self.one.0.clone(), self.zero.0.clone()],
false,
);
self.r1cs
.constraint(i.1.clone(), x.1.clone(), -z.1.clone() + 1);
@@ -974,34 +991,40 @@ impl<'cfg> ToR1cs<'cfg> {
///
/// * Prover data (including the R1CS instance)
/// * Verifier data
pub fn to_r1cs(mut cs: Computation, cfg: &CircCfg) -> (ProverData, VerifierData) {
let assertions = cs.outputs.clone();
let metadata = cs.metadata.clone();
let public_inputs = metadata.public_input_names_set();
pub fn to_r1cs(cs: &Computation, cfg: &CircCfg) -> R1cs {
let public_inputs = cs.metadata.public_input_names_set();
debug!("public inputs: {:?}", public_inputs);
let mut converter = ToR1cs::new(cfg, public_inputs);
let used_vars = extras::free_variables(term(Op::Tuple, cs.outputs.clone()));
let mut converter = ToR1cs::new(cfg, cs.precomputes.clone(), used_vars);
debug!(
"Term count: {}",
assertions
cs.outputs
.iter()
.map(|c| PostOrderIter::new(c.clone()).count())
.sum::<usize>()
);
debug!("declaring inputs");
for i in metadata.ordered_public_inputs() {
debug!("input {}", i);
converter.embed(i);
let vars = cs.metadata.interactive_vars();
for i in &vars.instances {
converter.embed_var(i, VarType::Inst);
}
for round in &vars.rounds {
for w in &round.witnesses {
converter.embed_var(w, VarType::RoundWit);
}
for c in &round.challenges {
converter.embed_var(c, VarType::Chall);
}
converter.r1cs.end_round()
}
for w in &vars.final_witnesses {
converter.embed_var(w, VarType::FinalWit);
}
debug!("Printing assertions");
for c in assertions {
converter.assert(c);
for c in &cs.outputs {
converter.assert(c.clone());
}
debug!("r1cs public inputs: {:?}", converter.r1cs.public_idxs,);
cs.precomputes = cs.precomputes.sequential_compose(&converter.wit_ext);
let r1cs = converter.r1cs;
let verifier_data = r1cs.verifier_data(&cs);
let prover_data = r1cs.prover_data(&cs);
(prover_data, verifier_data)
converter.r1cs
}
#[cfg(test)]
@@ -1015,14 +1038,14 @@ pub mod test {
use fxhash::FxHashMap;
use quickcheck_macros::quickcheck;
fn to_r1cs_dflt(cs: Computation) -> (ProverData, VerifierData) {
to_r1cs(cs, &CircCfg::default())
fn to_r1cs_dflt(cs: Computation) -> R1cs {
to_r1cs(&cs, &CircCfg::default())
}
fn to_r1cs_mod17(cs: Computation) -> (ProverData, VerifierData) {
fn to_r1cs_mod17(cs: Computation) -> R1cs {
let mut opt = crate::cfg::CircOpt::default();
opt.field.custom_modulus = "17".into();
to_r1cs(cs, &CircCfg::from(opt))
to_r1cs(&cs, &CircCfg::from(opt))
}
fn init() {
@@ -1048,10 +1071,8 @@ pub mod test {
leaf_term(Op::Var("b".to_owned(), Sort::Bool)),
],
);
let (pd, _) = to_r1cs_mod17(cs);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs_mod17(cs);
r1cs.check_all(&values);
}
#[quickcheck]
@@ -1062,10 +1083,8 @@ pub mod test {
term![Op::Not; t]
};
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new());
let (pd, _) = to_r1cs_dflt(cs);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs_dflt(cs);
r1cs.check_all(&values);
}
#[quickcheck]
@@ -1076,10 +1095,8 @@ pub mod test {
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let cfg = CircCfg::default();
let (pd, _) = to_r1cs(cs, &cfg);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs(&cs, &cfg);
r1cs.check_all(&values);
}
#[quickcheck]
@@ -1088,12 +1105,10 @@ pub mod test {
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs = Computation::from_constraint_system_parts(vec![t], Vec::new());
let cfg = CircCfg::default();
let (pd, _) = to_r1cs(cs, &cfg);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs2 = reduce_linearities(pd.r1cs, &cfg);
r1cs2.check_all(&extended_values);
let r1cs = to_r1cs(&cs, &cfg);
r1cs.check_all(&values);
let r1cs2 = reduce_linearities(r1cs, &cfg);
r1cs2.check_all(&values);
}
#[quickcheck]
@@ -1104,12 +1119,10 @@ pub mod test {
crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let cfg = CircCfg::default();
let (pd, _) = to_r1cs(cs, &cfg);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs2 = reduce_linearities(pd.r1cs, &cfg);
r1cs2.check_all(&extended_values);
let r1cs = to_r1cs(&cs, &cfg);
r1cs.check_all(&values);
let r1cs2 = reduce_linearities(r1cs, &cfg);
r1cs2.check_all(&values);
}
#[test]
@@ -1126,10 +1139,8 @@ pub mod test {
term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]]],
vec![leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))],
);
let (pd, _) = to_r1cs_dflt(cs);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs_dflt(cs);
r1cs.check_all(&values);
}
#[test]
@@ -1143,12 +1154,10 @@ pub mod test {
let t = term![Op::Eq; t, leaf_term(Op::Const(v))];
let cs = Computation::from_constraint_system_parts(vec![t], vec![]);
let cfg = CircCfg::default();
let (pd, _) = to_r1cs(cs, &cfg);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs2 = reduce_linearities(pd.r1cs, &cfg);
r1cs2.check_all(&extended_values);
let r1cs = to_r1cs(&cs, &cfg);
r1cs.check_all(&values);
let r1cs2 = reduce_linearities(r1cs, &cfg);
r1cs2.check_all(&values);
}
fn pf_dflt(i: isize) -> Term {
@@ -1158,10 +1167,8 @@ pub mod test {
fn const_test(term: Term) {
let mut cs = Computation::new();
cs.assert(term);
let (pd, _) = to_r1cs_dflt(cs);
let precomp = pd.precompute;
let extended_values = precomp.eval(&Default::default());
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs_dflt(cs);
r1cs.check_all(&Default::default());
}
#[test]
@@ -1246,6 +1253,7 @@ pub mod test {
#[test]
fn pf2bv_lit() {
init();
const_test(term![
Op::Eq;
term![Op::PfToBv(4); pf_dflt(8)],
@@ -1282,9 +1290,7 @@ pub mod test {
],
);
crate::ir::opt::tuple::eliminate_tuples(&mut cs);
let (pd, _) = to_r1cs_mod17(cs);
let precomp = pd.precompute;
let extended_values = precomp.eval(&values);
pd.r1cs.check_all(&extended_values);
let r1cs = to_r1cs_mod17(cs);
r1cs.check_all(&values);
}
}

326
src/target/r1cs/wit_comp.rs Normal file
View File

@@ -0,0 +1,326 @@
//! A multi-stage R1CS witness evaluator.
use crate::ir::term::*;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use serde::{Deserialize, Serialize};
use log::trace;
/// A witness computation that proceeds in stages.
///
/// In each stage:
/// * it takes a partial assignment
/// * it returns a vector of field values
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct StagedWitComp {
vars: HashSet<String>,
stages: Vec<Stage>,
steps: Vec<(Op, usize)>,
step_args: Vec<usize>,
ouput_steps: Vec<usize>,
// we don't serialize the cache; it's just used during construction, and terms are expensive to
// serialize.
#[serde(skip)]
term_to_step: TermMap<usize>,
}
/// Specifies a stage.
#[derive(Debug, Serialize, Deserialize)]
pub struct Stage {
inputs: HashMap<String, Sort>,
num_outputs: usize,
}
/// Builder interface
impl StagedWitComp {
/// Add a new stage.
#[allow(clippy::uninlined_format_args)]
pub fn add_stage(&mut self, inputs: HashMap<String, Sort>, output_values: Vec<Term>) {
let stage = Stage {
inputs,
num_outputs: output_values.len(),
};
for input in stage.inputs.keys() {
debug_assert!(!self.vars.contains(input), "Duplicate input {}", input);
}
self.vars.extend(stage.inputs.keys().cloned());
self.stages.push(stage);
let already_have: TermSet = self.term_to_step.keys().cloned().collect();
for t in PostOrderIter::from_roots_and_skips(output_values.clone(), already_have) {
self.add_step(t);
}
for t in output_values {
self.ouput_steps.push(*self.term_to_step.get(&t).unwrap());
}
}
fn add_step(&mut self, term: Term) {
debug_assert!(!self.term_to_step.contains_key(&term));
let step_idx = self.steps.len();
if let Op::Var(name, _) = term.op() {
debug_assert!(self.vars.contains(name));
}
for child in term.cs() {
let child_step = self.term_to_step.get(child).unwrap();
self.step_args.push(*child_step);
}
self.steps.push((term.op().clone(), self.step_args.len()));
self.term_to_step.insert(term, step_idx);
}
/// How many stages are there?
pub fn stage_sizes(&self) -> impl Iterator<Item = usize> + '_ {
self.stages.iter().map(|s| s.num_outputs)
}
/// How many inputs are there for this stage?
pub fn num_stage_inputs(&self, n: usize) -> usize {
self.stages[n].inputs.len()
}
}
/// Evaluator interface
impl StagedWitComp {
fn step_args(&self, step_idx: usize) -> impl Iterator<Item = usize> + '_ {
assert!(step_idx < self.steps.len());
let args_end = self.steps[step_idx].1;
let args_start = if step_idx == 0 {
0
} else {
self.steps[step_idx - 1].1
};
(args_start..args_end).map(move |step_arg_idx| self.step_args[step_arg_idx])
}
}
/// Evaluates a staged witness computation.
#[derive(Debug)]
pub struct StagedWitCompEvaluator<'a> {
comp: &'a StagedWitComp,
variable_values: HashMap<String, Value>,
step_values: Vec<Value>,
stages_evaluated: usize,
outputs_evaluted: usize,
}
impl<'a> StagedWitCompEvaluator<'a> {
/// Create an empty witness computation.
pub fn new(comp: &'a StagedWitComp) -> Self {
Self {
comp,
variable_values: Default::default(),
step_values: Default::default(),
stages_evaluated: Default::default(),
outputs_evaluted: 0,
}
}
/// Have all stages been evaluated?
pub fn is_done(&self) -> bool {
self.stages_evaluated == self.comp.stages.len()
}
fn eval_step(&mut self) {
let next_step_idx = self.step_values.len();
assert!(next_step_idx < self.comp.steps.len());
let op = &self.comp.steps[next_step_idx].0;
let args: Vec<&Value> = self
.comp
.step_args(next_step_idx)
.map(|i| &self.step_values[i])
.collect();
let value = eval_op(op, &args, &self.variable_values);
trace!(
"Eval step {}: {} on {:?} -> {}",
next_step_idx,
op,
args,
value
);
self.step_values.push(value);
}
/// Evaluate one stage.
pub fn eval_stage(&mut self, inputs: HashMap<String, Value>) -> Vec<&Value> {
trace!(
"Beginning stage {}/{}",
self.stages_evaluated,
self.comp.stages.len()
);
debug_assert!(self.stages_evaluated < self.comp.stages.len());
let stage = &self.comp.stages[self.stages_evaluated];
let num_outputs = stage.num_outputs;
self.variable_values.extend(inputs);
if num_outputs > 0 {
let max_step = (0..num_outputs)
.map(|i| {
let new_output_i = i + self.outputs_evaluted;
self.comp.ouput_steps[new_output_i]
})
.max()
.unwrap();
while self.step_values.len() <= max_step {
self.eval_step();
}
}
self.outputs_evaluted += num_outputs;
self.stages_evaluated += 1;
let mut out = Vec::new();
for output_step in
&self.comp.ouput_steps[self.outputs_evaluted - num_outputs..self.outputs_evaluted]
{
out.push(&self.step_values[*output_step]);
}
out
}
}
#[cfg(test)]
mod test {
use rug::Integer;
use super::*;
use circ_fields::FieldT;
fn mk_inputs(v: Vec<(String, Sort)>) -> HashMap<String, Sort> {
v.into_iter().collect()
}
#[test]
fn one_const() {
let mut comp = StagedWitComp::default();
let field = FieldT::from(Integer::from(7));
comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]);
let mut evaluator = StagedWitCompEvaluator::new(&comp);
let output = evaluator.eval_stage(Default::default());
let ex_output: &[usize] = &[0];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
assert!(evaluator.is_done());
}
#[test]
fn many_const() {
let mut comp = StagedWitComp::default();
let field = FieldT::from(Integer::from(7));
comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]);
comp.add_stage(
mk_inputs(vec![]),
vec![pf_lit(field.new_v(1)), pf_lit(field.new_v(4))],
);
comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(6))]);
comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]);
let mut evaluator = StagedWitCompEvaluator::new(&comp);
let output = evaluator.eval_stage(Default::default());
let ex_output: &[usize] = &[0];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
let output = evaluator.eval_stage(Default::default());
let ex_output: &[usize] = &[1, 4];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
let output = evaluator.eval_stage(Default::default());
let ex_output: &[usize] = &[6];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
let output = evaluator.eval_stage(Default::default());
let ex_output: &[usize] = &[0];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
assert!(evaluator.is_done());
}
#[test]
fn vars_one_stage() {
let mut comp = StagedWitComp::default();
let field = FieldT::from(Integer::from(7));
comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]),
vec![
leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))),
term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))],
]);
let mut evaluator = StagedWitCompEvaluator::new(&comp);
let output = evaluator.eval_stage(
vec![
("a".into(), Value::Bool(true)),
("b".into(), Value::Field(field.new_v(5))),
]
.into_iter()
.collect(),
);
let ex_output: &[usize] = &[5, 1];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
assert!(evaluator.is_done());
}
#[test]
fn vars_many_stages() {
let mut comp = StagedWitComp::default();
let field = FieldT::from(Integer::from(7));
comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]),
vec![
leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))),
term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))],
]);
comp.add_stage(mk_inputs(vec![("c".into(), Sort::Field(field.clone()))]),
vec![
term![PF_ADD;
leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))),
leaf_term(Op::Var("c".into(), Sort::Field(field.clone())))],
term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))],
term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(0)), pf_lit(field.new_v(1))],
]);
let mut evaluator = StagedWitCompEvaluator::new(&comp);
let output = evaluator.eval_stage(
vec![
("a".into(), Value::Bool(true)),
("b".into(), Value::Field(field.new_v(5))),
]
.into_iter()
.collect(),
);
let ex_output: &[usize] = &[5, 1];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
let output = evaluator.eval_stage(
vec![("c".into(), Value::Field(field.new_v(3)))]
.into_iter()
.collect(),
);
let ex_output: &[usize] = &[1, 1, 0];
assert_eq!(output.len(), ex_output.len());
for i in 0..ex_output.len() {
assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}");
}
assert!(evaluator.is_done());
}
}

View File

@@ -4,8 +4,8 @@ from os import path
# Gloable variables
feature_path = ".features.txt"
mode_path = ".mode.txt"
cargo_features = {"aby", "c", "lp", "r1cs",
"smt", "zok", "datalog", "bellman", "spartan", "kahip", "kahypar"}
cargo_features = {"aby", "c", "lp", "r1cs", "kahip", "kahypar",
"smt", "zok", "datalog", "bellman", "spartan"}
# Environment variables
ABY_SOURCE = "./../ABY"