diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6670d82b..400718c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,43 +1,47 @@ name: Build & Test on: - push: - branches: [ master, ci ] - pull_request: - branches: [ master, ci ] + push: + branches: [master, ci] + pull_request: + branches: [master, ci] env: - CARGO_TERM_COLOR: always - ABY_SOURCE: "./../ABY" + CARGO_TERM_COLOR: always + ABY_SOURCE: "./../ABY" + KAHIP_SOURCE: "./../KaHIP" + KAHYPAR_SOURCE: "./../kahypar" jobs: - build: + build: + runs-on: ubuntu-latest - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - if: runner.os == 'Linux' - run: sudo apt-get update; sudo apt-get install zsh cvc4 libboost-all-dev libssl-dev coinor-cbc coinor-libcbc-dev - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - - uses: Swatinem/rust-cache@v1 - - name: Set all features on - run: python3 driver.py --all_features - - name: Install third_party libraries - run: python3 driver.py --install - - name: Cache third_party build - uses: actions/cache@v2 - with: - path: ${ABY_SOURCE}/build - key: ${{ runner.os }} - - name: Check - run: python3 driver.py --check - - name: Format - run: cargo fmt -- --check - - name: Lint - run: python3 driver.py --lint - - name: Build, then Test - run: python3 driver.py --test + steps: + - uses: actions/checkout@v3 + - name: Install dependencies + if: runner.os == 'Linux' + run: sudo apt-get update; sudo apt-get install zsh cvc4 libboost-all-dev libssl-dev coinor-cbc coinor-libcbc-dev + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - uses: Swatinem/rust-cache@v2 + - name: Set all features on + run: python3 driver.py --all_features + - name: Cache third_party libraries + uses: actions/cache@v3 + with: + path: | + /home/runner/work/circ/ABY + /home/runner/work/circ/KaHIP + /home/runner/work/circ/kahypar + key: ${{ runner.os }}-third_party + - name: Install third_party libraries + run: python3 driver.py --install + - name: Check + run: python3 driver.py --check + - name: Format + run: cargo fmt -- --check + - name: Lint + run: python3 driver.py --lint + - name: Build, then Test + run: python3 driver.py --test diff --git a/Cargo.lock b/Cargo.lock index 843b979a..e0bc8ba9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -278,6 +278,7 @@ dependencies = [ "quickcheck", "quickcheck_macros", "rand 0.8.5", + "rand_chacha 0.3.1", "rsmt2", "rug", "serde", diff --git a/Cargo.toml b/Cargo.toml index bdda3cd0..afe6e336 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ rug = { version = "1.11", features = ["serde"] } gmp-mpfr-sys = { version = "1.4", optional = true } lazy_static = { version = "1.4", optional = true } rand = "0.8" +rand_chacha = "0.3" rsmt2 = { version = "0.14", optional = true } ieee754 = { version = "0.2", optional = true} zokrates_parser = { path = "third_party/ZoKrates/zokrates_parser", optional = true } @@ -65,8 +66,8 @@ datalog = ["pest", "pest-ast", "pest_derive", "from-pest", "lazy_static"] smt = ["rsmt2", "ieee754"] lp = ["good_lp", "lp-solvers"] aby = ["lp"] -kahip = ["lp"] -kahypar = ["lp"] +kahip = [] +kahypar = [] r1cs = [] spartan = ["r1cs", "dep:spartan", "merlin", "curve25519-dalek", "bincode", "gmp-mpfr-sys"] bellman = ["r1cs", "dep:bellman", "ff", "group", "pairing", "serde_bytes", "bincode", "gmp-mpfr-sys"] diff --git a/circ_fields/src/lib.rs b/circ_fields/src/lib.rs index e13a0b3f..cf9c6e0f 100644 --- a/circ_fields/src/lib.rs +++ b/circ_fields/src/lib.rs @@ -20,14 +20,14 @@ use ff::Field; use paste::paste; use rug::Integer; use serde::{Deserialize, Serialize}; -use std::fmt::{self, Display, Formatter}; +use std::fmt::{self, Debug, Display, Formatter}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::sync::Arc; // TODO: rework this using macros? /// Field element type -#[derive(PartialEq, Eq, Clone, Debug, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[derive(PartialEq, Eq, Clone, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum FieldT { /// BLS12-381 scalar field as `ff` FBls12381, @@ -47,6 +47,12 @@ impl Display for FieldT { } } +impl Debug for FieldT { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self) + } +} + impl From> for FieldT { fn from(m: Arc) -> Self { Self::as_ff_opt(m.as_ref()).unwrap_or(Self::IntField(m)) @@ -128,7 +134,7 @@ impl FieldT { } /// Field element value -#[derive(PartialEq, Eq, Clone, Debug, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[derive(PartialEq, Eq, Clone, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum FieldV { /// BLS12-381 scalar field element as `ff` FBls12381(FBls12381), @@ -272,15 +278,37 @@ impl FieldV { ty.new_v(self.i()) } } + + /// Get this field element as an SMT-LIB string. + #[inline] + pub fn as_smtlib(&self) -> String { + match self { + Self::FBls12381(pf) => format!("#f{}m{}", Integer::from(pf), &*F_BLS12381_FMOD), + Self::FBn254(pf) => format!("#f{}m{}", Integer::from(pf), &*F_BN254_FMOD), + Self::IntField(i) => format!("{}", i), + } + } + + /// Get this field element as a signed integer. No particular cut-off is guaranteed. + #[inline] + pub fn signed_int(&self) -> Integer { + let mut i = self.i(); + if i.significant_bits() >= self.ty().modulus().significant_bits() - 1 { + i -= self.ty().modulus(); + } + i + } } impl Display for FieldV { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - Self::FBls12381(pf) => write!(f, "#f{}m{}", Integer::from(pf), &*F_BLS12381_FMOD), - Self::FBn254(pf) => write!(f, "#f{}m{}", Integer::from(pf), &*F_BN254_FMOD), - Self::IntField(i) => i.fmt(f), - } + write!(f, "{}", self.signed_int()) + } +} + +impl Debug for FieldV { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{} (in {})", self.signed_int(), self.ty()) } } diff --git a/circ_hc/src/hashconsing/example_u8.rs b/circ_hc/src/hashconsing/example_u8.rs index a30fc2fa..75e1ff37 100644 --- a/circ_hc/src/hashconsing/example_u8.rs +++ b/circ_hc/src/hashconsing/example_u8.rs @@ -49,6 +49,10 @@ impl crate::Table for Table { "hashconsing" } + fn for_each(f: impl FnMut(&u8, &[Self::Node])) { + panic!() + } + fn reserve(num_nodes: usize) { FACTORY.reserve(num_nodes); } diff --git a/circ_hc/src/hashconsing/macro_.rs b/circ_hc/src/hashconsing/macro_.rs index def52138..1ad72535 100644 --- a/circ_hc/src/hashconsing/macro_.rs +++ b/circ_hc/src/hashconsing/macro_.rs @@ -52,6 +52,10 @@ macro_rules! generate_hashcons_hashconsing { "hashconsing" } + fn for_each(f: impl FnMut(&$Op, &[Self::Node])) { + panic!() + } + fn reserve(num_nodes: usize) { FACTORY.reserve(num_nodes); } diff --git a/circ_hc/src/hashconsing/template.rs b/circ_hc/src/hashconsing/template.rs index 88127a54..0023eea3 100644 --- a/circ_hc/src/hashconsing/template.rs +++ b/circ_hc/src/hashconsing/template.rs @@ -50,6 +50,10 @@ impl crate::Table for Table { "hashconsing" } + fn for_each(f: impl FnMut(&TemplateOp, &[Self::Node])) { + panic!() + } + fn reserve(num_nodes: usize) { FACTORY.reserve(num_nodes); } diff --git a/circ_hc/src/lib.rs b/circ_hc/src/lib.rs index 0697a273..b6cec58a 100644 --- a/circ_hc/src/lib.rs +++ b/circ_hc/src/lib.rs @@ -39,6 +39,9 @@ pub trait Table { /// The name of the implementation fn name() -> &'static str; + /// Fun a function on every node + fn for_each(f: impl FnMut(&Op, &[Self::Node])); + /// When the table garbage-collects a node with ID `id`, it will call `f(id)`. `f(id)` might /// clear caches, etc., but instead of droping `Node`s, it should return them. /// diff --git a/circ_hc/src/raw/example_u8.rs b/circ_hc/src/raw/example_u8.rs index 6ffd0773..d71c3d0d 100644 --- a/circ_hc/src/raw/example_u8.rs +++ b/circ_hc/src/raw/example_u8.rs @@ -45,6 +45,10 @@ impl crate::Table for Table { "raw" } + fn for_each(mut f: impl FnMut(&u8, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_hc/src/raw/macro_.rs b/circ_hc/src/raw/macro_.rs index eb431603..c11f1898 100644 --- a/circ_hc/src/raw/macro_.rs +++ b/circ_hc/src/raw/macro_.rs @@ -48,6 +48,10 @@ macro_rules! generate_hashcons_raw { "raw" } + fn for_each(mut f: impl FnMut(&$Op, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_hc/src/raw/template.rs b/circ_hc/src/raw/template.rs index 1f24cea0..2287d400 100644 --- a/circ_hc/src/raw/template.rs +++ b/circ_hc/src/raw/template.rs @@ -45,6 +45,10 @@ impl crate::Table for Table { "raw" } + fn for_each(mut f: impl FnMut(&TemplateOp, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_hc/src/rc/example_u8.rs b/circ_hc/src/rc/example_u8.rs index 37576cb6..9ad9c00a 100644 --- a/circ_hc/src/rc/example_u8.rs +++ b/circ_hc/src/rc/example_u8.rs @@ -69,6 +69,10 @@ impl crate::Table for Table { "rc" } + fn for_each(mut f: impl FnMut(&u8, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_hc/src/rc/macro_.rs b/circ_hc/src/rc/macro_.rs index a4b8b9c7..ec846a9b 100644 --- a/circ_hc/src/rc/macro_.rs +++ b/circ_hc/src/rc/macro_.rs @@ -72,6 +72,10 @@ macro_rules! generate_hashcons_rc { "rc" } + fn for_each(mut f: impl FnMut(&$Op, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_hc/src/rc/template.rs b/circ_hc/src/rc/template.rs index b82a1712..924865bd 100644 --- a/circ_hc/src/rc/template.rs +++ b/circ_hc/src/rc/template.rs @@ -69,6 +69,10 @@ impl crate::Table for Table { "rc" } + fn for_each(mut f: impl FnMut(&TemplateOp, &[Self::Node])) { + MANAGER.with(|man| man.table.borrow().keys().for_each(|n| f(&n.op, &n.cs))); + } + fn reserve(num_nodes: usize) { MANAGER.with(|man| man.table.borrow_mut().reserve(num_nodes)) } diff --git a/circ_opt/README.md b/circ_opt/README.md index 1f423a0d..457b1b00 100644 --- a/circ_opt/README.md +++ b/circ_opt/README.md @@ -86,6 +86,13 @@ Options: [default: true] [possible values: true, false] + --fmt-hide-field + Always hide the field + + [env: FMT_HIDE_FIELD=] + [default: false] + [possible values: true, false] + --zsharp-isolate-asserts In Z#, "isolate" assertions. That is, assertions in if/then/else expressions only take effect if that branch is active. @@ -144,6 +151,8 @@ Options: Which field to use [env: IR_FIELD_TO_BV=] [default: wrap] [possible values: wrap, panic] --fmt-use-default-field Which field to use [env: FMT_USE_DEFAULT_FIELD=] [default: true] [possible values: true, false] + --fmt-hide-field + Always hide the field [env: FMT_HIDE_FIELD=] [default: false] [possible values: true, false] --zsharp-isolate-asserts In Z#, "isolate" assertions. That is, assertions in if/then/else expressions only take effect if that branch is active [env: ZSHARP_ISOLATE_ASSERTS=] [default: false] [possible values: true, false] --datalog-rec-limit @@ -179,6 +188,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -216,6 +226,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -251,6 +262,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -286,6 +298,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -321,6 +334,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -356,6 +370,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -391,6 +406,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -426,6 +442,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -464,6 +481,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -500,6 +518,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -538,6 +557,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: true, @@ -574,6 +594,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: true, @@ -612,6 +633,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, @@ -648,6 +670,7 @@ BinaryOpt { }, fmt: FmtOpt { use_default_field: true, + hide_field: false, }, zsharp: ZsharpOpt { isolate_asserts: false, diff --git a/circ_opt/src/lib.rs b/circ_opt/src/lib.rs index 0702890e..8fc794be 100644 --- a/circ_opt/src/lib.rs +++ b/circ_opt/src/lib.rs @@ -198,12 +198,20 @@ pub struct FmtOpt { action = ArgAction::Set, default_value = "true")] pub use_default_field: bool, + /// Always hide the field + #[arg( + long = "fmt-hide-field", + env = "FMT_HIDE_FIELD", + action = ArgAction::Set, + default_value = "false")] + pub hide_field: bool, } impl Default for FmtOpt { fn default() -> Self { Self { use_default_field: true, + hide_field: false, } } } diff --git a/driver.py b/driver.py index 0ffc3df3..28a7c68f 100755 --- a/driver.py +++ b/driver.py @@ -39,17 +39,15 @@ def install(features): subprocess.run(["./scripts/build_aby.zsh"]) if f == "kahip": if verify_path_empty(KAHIP_SOURCE): - # TODO: we can pull from their main repository instead of fork also long as we - # remove their parallel (ParHIP) cmake dependency subprocess.run( - ["git", "clone", "https://github.com/edwjchen/KaHIP.git", KAHIP_SOURCE] + ["git", "clone", "https://github.com/KaHIP/KaHIP.git", KAHIP_SOURCE] ) subprocess.run(["./scripts/build_kahip.zsh"]) if f == "kahypar": if verify_path_empty(KAHYPAR_SOURCE): subprocess.run( ["git", "clone", "--depth=1", "--recursive", - "git@github.com:SebastianSchlag/kahypar.git", KAHYPAR_SOURCE] + "https://github.com/SebastianSchlag/kahypar.git", KAHYPAR_SOURCE] ) subprocess.run(["./scripts/build_kahypar.zsh"]) @@ -107,7 +105,7 @@ def build(features): cmd += ["--examples"] if features: - cmd = cmd + ["--features"] + features + cmd = cmd + ["--features"] + [",".join(features)] if "ristretto255" in features: cmd = cmd + ["--no-default-features"] @@ -139,8 +137,8 @@ def test(features, extra_args): test_cmd = ["cargo", "test"] test_cmd_release = ["cargo", "test", "--release"] if features: - test_cmd += ["--features"] + features - test_cmd_release += ["--features"] + features + test_cmd += ["--features"] + [",".join(features)] + test_cmd_release += ["--features"] + [",".join(features)] if "ristretto255" in features: test_cmd += ["--no-default-features"] test_cmd_release += ["--no-default-features"] @@ -152,7 +150,7 @@ def test(features, extra_args): if load_mode() == "release": log_run_check(test_cmd_release) - if "r1cs" in features and "smt" in features: + if "r1cs" in features and "smt" in features and "datalog" in features: log_run_check(["./scripts/test_datalog.zsh"]) if "zok" in features and "smt" in features: @@ -191,7 +189,7 @@ def benchmark(features): cmd += ["--examples"] if features: - cmd = cmd + ["--features"] + features + cmd = cmd + ["--features"] + [",".join(features)] if "ristretto255" in features: cmd = cmd + ["--no-default-features"] log_run_check(cmd) @@ -215,7 +213,7 @@ def lint(): cmd = ["cargo", "clippy", "--tests", "--examples", "--benches", "--bins"] if features: - cmd = cmd + ["--features"] + features + cmd = cmd + ["--features"] + [",".join(features)] if "ristretto255" in features: cmd = cmd + ["--no-default-features"] log_run_check(cmd) @@ -224,7 +222,7 @@ def lint(): def flamegraph(features, extra): cmd = ["cargo", "flamegraph"] if features: - cmd = cmd + ["--features"] + features + cmd = cmd + ["--features"] + [",".join(features)] if "ristretto255" in features: cmd = cmd + ["--no-default-features"] cmd += extra diff --git a/examples/circ.rs b/examples/circ.rs index 6aab95f2..80a3af5c 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -34,10 +34,10 @@ use circ::ir::{ use circ::target::aby::trans::to_aby; #[cfg(feature = "lp")] use circ::target::ilp::{assignment_to_values, trans::to_ilp}; -#[cfg(feature = "bellman")] -use circ::target::r1cs::bellman::gen_params; #[cfg(feature = "spartan")] use circ::target::r1cs::spartan::write_data; +#[cfg(feature = "bellman")] +use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem}; #[cfg(feature = "r1cs")] use circ::target::r1cs::{opt::reduce_linearities, trans::to_r1cs}; #[cfg(feature = "smt")] @@ -96,6 +96,8 @@ enum Backend { lc_elimination_thresh: usize, #[arg(long, default_value = "count")] action: ProofAction, + #[arg(long, default_value = "groth16")] + proof_impl: ProofImpl, }, Smt {}, Ilp {}, @@ -135,6 +137,12 @@ enum ProofAction { SpartanSetup, } +#[derive(PartialEq, Eq, Debug, Clone, ValueEnum)] +enum ProofImpl { + Groth16, + Mirage, +} + fn determine_language(l: &Language, input_path: &Path) -> DeterminedLanguage { match *l { Language::Datalog => DeterminedLanguage::Datalog, @@ -275,25 +283,24 @@ fn main() { .. } => { println!("Converting to r1cs"); - let (mut prover_data, verifier_data) = to_r1cs(cs.get("main").clone(), cfg()); + let cs = cs.get("main"); + let mut r1cs = to_r1cs(cs, cfg()); - println!( - "Pre-opt R1cs size: {}", - prover_data.r1cs.constraints().len() - ); - prover_data.r1cs = reduce_linearities(prover_data.r1cs, cfg()); + println!("Pre-opt R1cs size: {}", r1cs.constraints().len()); + r1cs = reduce_linearities(r1cs, cfg()); - println!("Final R1cs size: {}", prover_data.r1cs.constraints().len()); + println!("Final R1cs size: {}", r1cs.constraints().len()); + let (prover_data, verifier_data) = r1cs.finalize(cs); match action { ProofAction::Count => (), #[cfg(feature = "bellman")] ProofAction::Setup => { println!("Generating Parameters"); - gen_params::( + Bellman::::setup_fs( + prover_data, + verifier_data, prover_key, verifier_key, - &prover_data, - &verifier_data, ) .unwrap(); } diff --git a/examples/zk.rs b/examples/zk.rs index 95423cd0..304d0e66 100644 --- a/examples/zk.rs +++ b/examples/zk.rs @@ -1,16 +1,18 @@ -use circ::cfg::clap::{self, Parser, ValueEnum}; +use circ::cfg::{ + clap::{self, Parser, ValueEnum}, + CircOpt, +}; use std::path::PathBuf; #[cfg(feature = "bellman")] use bls12_381::Bls12; #[cfg(feature = "bellman")] -use circ::target::r1cs::bellman; +use circ::target::r1cs::{bellman::Bellman, proof::ProofSystem}; #[cfg(feature = "spartan")] -use circ::target::r1cs::spartan; - -#[cfg(any(feature = "bellman", feature = "spartan"))] use circ::ir::term::text::parse_value_map; +#[cfg(feature = "spartan")] +use circ::target::r1cs::spartan; #[derive(Debug, Parser)] #[command(name = "zk", about = "The CirC ZKP runner")] @@ -27,8 +29,12 @@ struct Options { pin: PathBuf, #[arg(long, default_value = "vin")] vin: PathBuf, + #[arg(long, default_value = "groth16")] + proof_impl: ProofImpl, #[arg(long)] action: ProofAction, + #[command(flatten)] + circ: CircOpt, } #[derive(PartialEq, Debug, Clone, ValueEnum)] @@ -40,24 +46,33 @@ enum ProofAction { Spartan, } +#[derive(PartialEq, Debug, Clone, ValueEnum)] +/// Whether to use Groth16 or Mirage +enum ProofImpl { + Groth16, + Mirage, +} + fn main() { env_logger::Builder::from_default_env() .format_level(false) .format_timestamp(None) .init(); let opts = Options::parse(); + circ::cfg::set(&opts.circ); match opts.action { #[cfg(feature = "bellman")] ProofAction::Prove => { - let input_map = parse_value_map(&std::fs::read(opts.inputs).unwrap()); println!("Proving"); - bellman::prove::(opts.prover_key, opts.proof, &input_map).unwrap(); + Bellman::::prove_fs(opts.prover_key, opts.inputs, opts.proof).unwrap(); } #[cfg(feature = "bellman")] ProofAction::Verify => { - let input_map = parse_value_map(&std::fs::read(opts.inputs).unwrap()); println!("Verifying"); - bellman::verify::(opts.verifier_key, opts.proof, &input_map).unwrap(); + assert!( + Bellman::::verify_fs(opts.verifier_key, opts.inputs, opts.proof).unwrap(), + "invalid proof" + ); } #[cfg(feature = "spartan")] ProofAction::Spartan => { diff --git a/examples/zxc.rs b/examples/zxc.rs index a5b7c61b..cd68f0fe 100644 --- a/examples/zxc.rs +++ b/examples/zxc.rs @@ -129,16 +129,16 @@ fn main() { */ println!("Converting to r1cs"); - let (prover_data, _) = to_r1cs(cs.get("main").clone(), cfg()); + let r1cs = to_r1cs(cs.get("main"), cfg()); let r1cs = if options.skip_linred { println!("Skipping linearity reduction, as requested."); - prover_data.r1cs + r1cs } else { println!( "R1cs size before linearity reduction: {}", - prover_data.r1cs.constraints().len() + r1cs.constraints().len() ); - reduce_linearities(prover_data.r1cs, cfg()) + reduce_linearities(r1cs, cfg()) }; println!("Final R1cs size: {}", r1cs.constraints().len()); match action { diff --git a/scripts/aby_tests/util.py b/scripts/aby_tests/util.py index 470ccca5..3386519c 100644 --- a/scripts/aby_tests/util.py +++ b/scripts/aby_tests/util.py @@ -1,6 +1,5 @@ import os from subprocess import Popen, PIPE -import sys from typing import List from tqdm import tqdm diff --git a/scripts/build_kahip.zsh b/scripts/build_kahip.zsh index eeee8c68..bd7a72b6 100755 --- a/scripts/build_kahip.zsh +++ b/scripts/build_kahip.zsh @@ -2,7 +2,7 @@ if [[ ! -z ${KAHIP_SOURCE} ]]; then cd ${KAHIP_SOURCE} - ./compile_withcmake.sh + ./compile_withcmake.sh -DCMAKE_BUILD_TYPE=Release -DPARHIP=off else echo "Missing KAHIP_SOURCE environment variable." fi \ No newline at end of file diff --git a/scripts/zokrates_test.zsh b/scripts/zokrates_test.zsh index 4e8e6dff..7e9197fd 100755 --- a/scripts/zokrates_test.zsh +++ b/scripts/zokrates_test.zsh @@ -35,20 +35,26 @@ function r1cs_test_count { # Test prove workflow, given an example name function pf_test { - ex_name=$1 - $BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup - $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove - $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify - rm -rf P V pi + for proof_impl in groth16 mirage + do + ex_name=$1 + $BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action setup --proof-impl $proof_impl + $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove --proof-impl $proof_impl + $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify --proof-impl $proof_impl + rm -rf P V pi + done } # Test prove workflow with --z-isolate-asserts, given an example name function pf_test_isolate { - ex_name=$1 - $BIN --zsharp-isolate-asserts true examples/ZoKrates/pf/$ex_name.zok r1cs --action setup - $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove - $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify - rm -rf P V pi + for proof_impl in groth16 mirage + do + ex_name=$1 + $BIN --zsharp-isolate-asserts true examples/ZoKrates/pf/$ex_name.zok r1cs --action setup --proof-impl $proof_impl + $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.pin --action prove --proof-impl $proof_impl + $ZK_BIN --inputs examples/ZoKrates/pf/$ex_name.zok.vin --action verify --proof-impl $proof_impl + rm -rf P V pi + done } r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120 diff --git a/src/ir/opt/mem/ram.rs b/src/ir/opt/mem/ram.rs index a5770eec..6bec5ba4 100644 --- a/src/ir/opt/mem/ram.rs +++ b/src/ir/opt/mem/ram.rs @@ -385,7 +385,6 @@ mod test { assert_eq!(bv_lit(1, 3), rams[0].accesses[0].val); assert_eq!(bv_lit(2, 3), rams[0].accesses[1].val); assert!(rams[0].accesses[2].val.is_var()); - dbg!(cs2); } #[test] diff --git a/src/ir/opt/scalarize_vars.rs b/src/ir/opt/scalarize_vars.rs index 14644a20..6c4e28c1 100644 --- a/src/ir/opt/scalarize_vars.rs +++ b/src/ir/opt/scalarize_vars.rs @@ -105,7 +105,7 @@ pub fn assert_all_vars_are_scalars(cs: &Computation) { /// Check that every variables is a scalar. fn remove_non_scalar_vars_from_main_computation(cs: &mut Computation) { - for input in cs.metadata.ordered_public_inputs() { + for input in cs.metadata.ordered_inputs() { if !check(&input).is_scalar() { cs.metadata.remove_var(input.as_var_name()); } diff --git a/src/ir/term/extras.rs b/src/ir/term/extras.rs index 5fba4ab2..7f754125 100644 --- a/src/ir/term/extras.rs +++ b/src/ir/term/extras.rs @@ -161,3 +161,35 @@ pub fn array_elements(t: &Term) -> Vec { pub fn array_to_tuple(t: &Term) -> Term { term(Op::Tuple, array_elements(t)) } + +/// Print operator stats +pub fn dump_op_stats() { + use std::mem::size_of; + println!("Op size: {}bytes", size_of::()); + let mut counts: FxHashMap = FxHashMap::default(); + let mut children: FxHashMap = FxHashMap::default(); + let add = |map: &mut FxHashMap, key: &Op, value: usize| { + if let Some(present_value) = map.get_mut(key) { + *present_value += value; + } else { + map.insert(key.clone(), value); + } + }; + hc::Table::for_each(|op, cs| { + add(&mut counts, op, 1); + add(&mut children, op, cs.len()); + }); + let mut vector = Vec::new(); + for k in counts.keys() { + let ct = *counts.get(k).unwrap(); + let cs_ct = *children.get(k).unwrap(); + let ave = cs_ct as f64 / ct as f64; + vector.push((k.clone(), ct, cs_ct, ave)); + } + vector.sort_by_key(|t| t.1); + for (k, ct, cs_ct, ave) in vector { + let mem = size_of::() * ct + size_of::>() * ct + size_of::() * cs_ct; + let s: String = format!("{k}"); + println!("Op {s:>20}, Count {ct:>8}, Children {cs_ct:>8}, Ave {ave:>8.2}, Mem {mem:>20}"); + } +} diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index 62be3107..b1997ccd 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -23,6 +23,8 @@ struct IrFormatter<'a, 'b> { pub struct IrCfg { /// Whether to introduce a default modulus. pub use_default_field: bool, + /// Whether to show any moduli + pub hide_field: bool, } impl IrCfg { @@ -33,6 +35,7 @@ impl IrCfg { if is_cfg_set() { Self { use_default_field: cfg().fmt.use_default_field, + hide_field: cfg().fmt.hide_field, } } else { Self::parseable() @@ -42,6 +45,7 @@ impl IrCfg { pub fn parseable() -> Self { Self { use_default_field: true, + hide_field: false, } } } @@ -181,6 +185,7 @@ impl DisplayIr for Op { Op::IntNaryOp(a) => write!(f, "{a}"), Op::IntBinPred(a) => write!(f, "{a}"), Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()), + Op::PfChallenge(n, m) => write!(f, "(challenge {} {})", n, m.modulus()), Op::Select => write!(f, "select"), Op::Store => write!(f, "store"), Op::Tuple => write!(f, "tuple"), @@ -244,10 +249,10 @@ impl DisplayIr for Array { impl DisplayIr for FieldV { fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult { - let omit_field = f - .default_field - .as_ref() - .map_or(false, |field| field == &self.ty()); + let omit_field = f.cfg.hide_field + || f.default_field + .as_ref() + .map_or(false, |field| field == &self.ty()); let mut i = self.i(); let mod_bits = self.modulus().significant_bits(); if i.significant_bits() + 1 >= mod_bits { @@ -294,6 +299,12 @@ impl DisplayIr for VariableMetadata { if self.random { write!(f, " (random)")?; } + if 0 != self.round { + write!(f, " (round {})", self.round)?; + } + if self.random { + write!(f, " (random)")?; + } write!(f, ")") } } @@ -333,7 +344,7 @@ fn fmt_term_with_bindings(t: &Term, f: &mut IrFormatter) -> FmtResult { } }) .collect(); - if fields.len() == 1 { + if fields.len() == 1 && !f.cfg.hide_field { f.default_field = fields.into_iter().next(); let i = f.default_field.clone().unwrap(); writeln!(f, "(set-default-modulus {}", i.modulus())?; @@ -408,26 +419,20 @@ impl Display for Term { } } -impl Display for Sort { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg())) - } +macro_rules! fmt_impl { + ($trait:ty, $ty:ty) => { + impl $trait for $ty { + fn fmt(&self, f: &mut Formatter) -> FmtResult { + self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg())) + } + } + }; } -impl Display for Value { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg())) - } -} - -impl Display for Op { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg())) - } -} - -impl Display for ComputationMetadata { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - self.ir_fmt(&mut IrFormatter::new(f, &IrCfg::from_circ_cfg())) - } -} +fmt_impl!(Display, Sort); +fmt_impl!(Display, Value); +fmt_impl!(Display, Op); +fmt_impl!(Display, ComputationMetadata); +fmt_impl!(Debug, Value); +fmt_impl!(Debug, Sort); +fmt_impl!(Debug, Op); diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index e358ae40..8f09515f 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -48,7 +48,7 @@ pub mod ty; pub use bv::BitVector; pub use ty::{check, check_rec, TypeError, TypeErrorReason}; -#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] /// An operator pub enum Op { /// a variable @@ -125,6 +125,13 @@ pub enum Op { /// /// Takes the modulus. UbvToPf(FieldT), + /// A random value, sampled uniformly and independently of its arguments. + /// + /// Takes a name (if deterministically sampled, challenges of different names are sampled + /// differentely) and a field to sample from. + /// + /// In IR evaluation, we sample deterministically based on a hash of the name. + PfChallenge(String, FieldT), /// Integer n-ary operator IntNaryOp(IntNaryOp), @@ -272,6 +279,7 @@ impl Op { Op::FpToFp(_) => Some(1), Op::PfUnOp(_) => Some(1), Op::PfNaryOp(_) => None, + Op::PfChallenge(_, _) => None, Op::IntNaryOp(_) => None, Op::IntBinPred(_) => Some(2), Op::UbvToPf(_) => Some(1), @@ -635,7 +643,7 @@ impl<'de> Deserialize<'de> for Term { } } -#[derive(Clone, PartialEq, Debug, PartialOrd, Serialize, Deserialize)] +#[derive(Clone, PartialEq, PartialOrd, Serialize, Deserialize)] /// An IR value (aka literal) pub enum Value { /// Bit-vector @@ -776,7 +784,7 @@ impl std::hash::Hash for Value { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] /// The "type" of an IR term pub enum Sort { /// bit-vectors of this width @@ -1211,54 +1219,49 @@ pub fn eval(t: &Term, h: &FxHashMap) -> Value { } /// Helper function for eval function. Handles a single term -fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> Value { - let v = match &c.op() { - Op::Var(n, _) => h +fn eval_value(vs: &mut TermMap, h: &FxHashMap, t: Term) -> Value { + let args: Vec<&Value> = t.cs().iter().map(|c| vs.get(c).unwrap()).collect(); + let v = eval_op(t.op(), &args, h); + debug!("Eval {}\nAs {}", t, v); + vs.insert(t, v.clone()); + v +} + +/// Helper function for eval function. Handles a single op +#[allow(clippy::uninlined_format_args)] +pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> Value { + match op { + Op::Var(n, _) => var_vals .get(n) - .unwrap_or_else(|| panic!("Missing var: {} in {:?}", n, h)) + .unwrap_or_else(|| panic!("Missing var: {} in {:?}", n, var_vals)) .clone(), - Op::Eq => Value::Bool(vs.get(&c.cs()[0]).unwrap() == vs.get(&c.cs()[1]).unwrap()), - Op::Not => Value::Bool(!vs.get(&c.cs()[0]).unwrap().as_bool()), - Op::Implies => Value::Bool( - !vs.get(&c.cs()[0]).unwrap().as_bool() || vs.get(&c.cs()[1]).unwrap().as_bool(), - ), - Op::BoolNaryOp(BoolNaryOp::Or) => { - Value::Bool(c.cs().iter().any(|c| vs.get(c).unwrap().as_bool())) - } - Op::BoolNaryOp(BoolNaryOp::And) => { - Value::Bool(c.cs().iter().all(|c| vs.get(c).unwrap().as_bool())) - } + Op::Eq => Value::Bool(args[0] == args[1]), + Op::Not => Value::Bool(!args[0].as_bool()), + Op::Implies => Value::Bool(!args[0].as_bool() || args[1].as_bool()), + Op::BoolNaryOp(BoolNaryOp::Or) => Value::Bool(args.iter().any(|a| a.as_bool())), + Op::BoolNaryOp(BoolNaryOp::And) => Value::Bool(args.iter().all(|a| a.as_bool())), Op::BoolNaryOp(BoolNaryOp::Xor) => Value::Bool( - c.cs() - .iter() - .map(|c| vs.get(c).unwrap().as_bool()) + args.iter() + .map(|a| a.as_bool()) .fold(false, std::ops::BitXor::bitxor), ), - Op::BvBit(i) => Value::Bool( - vs.get(&c.cs()[0]) - .unwrap() - .as_bv() - .uint() - .get_bit(*i as u32), - ), + Op::BvBit(i) => Value::Bool(args[0].as_bv().uint().get_bit(*i as u32)), Op::BoolMaj => { - let c0 = vs.get(&c.cs()[0]).unwrap().as_bool() as u8; - let c1 = vs.get(&c.cs()[1]).unwrap().as_bool() as u8; - let c2 = vs.get(&c.cs()[2]).unwrap().as_bool() as u8; + let c0 = args[0].as_bool() as u8; + let c1 = args[1].as_bool() as u8; + let c2 = args[2].as_bool() as u8; Value::Bool(c0 + c1 + c2 > 1) } Op::BvConcat => Value::BitVector({ - let mut it = c.cs().iter().map(|c| vs.get(c).unwrap().as_bv().clone()); + let mut it = args.iter().map(|a| a.as_bv().clone()); let f = it.next().unwrap(); it.fold(f, BitVector::concat) }), - Op::BvExtract(h, l) => { - Value::BitVector(vs.get(&c.cs()[0]).unwrap().as_bv().clone().extract(*h, *l)) - } + Op::BvExtract(h, l) => Value::BitVector(args[0].as_bv().clone().extract(*h, *l)), Op::Const(v) => v.clone(), Op::BvBinOp(o) => Value::BitVector({ - let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone(); - let b = vs.get(&c.cs()[1]).unwrap().as_bv().clone(); + let a = args[0].as_bv().clone(); + let b = args[1].as_bv().clone(); match o { BvBinOp::Udiv => a / &b, BvBinOp::Urem => a % &b, @@ -1269,14 +1272,14 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> } }), Op::BvUnOp(o) => Value::BitVector({ - let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone(); + let a = args[0].as_bv().clone(); match o { BvUnOp::Not => !a, BvUnOp::Neg => -a, } }), Op::BvNaryOp(o) => Value::BitVector({ - let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_bv().clone()); + let mut xs = args.iter().map(|a| a.as_bv().clone()); let f = xs.next().unwrap(); xs.fold( f, @@ -1290,13 +1293,13 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> ) }), Op::BvSext(w) => Value::BitVector({ - let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone(); + let a = args[0].as_bv().clone(); let mask = ((Integer::from(1) << *w as u32) - 1) * Integer::from(a.uint().get_bit(a.width() as u32 - 1)); BitVector::new(a.uint() | (mask << a.width() as u32), a.width() + w) }), Op::PfToBv(w) => Value::BitVector({ - let i = vs.get(&c.cs()[0]).unwrap().as_pf().i(); + let i = args[0].as_pf().i(); if let FieldToBv::Panic = cfg().ir.field_to_bv { assert!( (i.significant_bits() as usize) <= *w, @@ -1307,19 +1310,13 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> BitVector::new(i % (Integer::from(1) << *w), *w) }), Op::BvUext(w) => Value::BitVector({ - let a = vs.get(&c.cs()[0]).unwrap().as_bv().clone(); + let a = args[0].as_bv().clone(); BitVector::new(a.uint().clone(), a.width() + w) }), - Op::Ite => if vs.get(&c.cs()[0]).unwrap().as_bool() { - vs.get(&c.cs()[1]) - } else { - vs.get(&c.cs()[2]) - } - .unwrap() - .clone(), + Op::Ite => args[if args[0].as_bool() { 1 } else { 2 }].clone(), Op::BvBinPred(o) => Value::Bool({ - let a = vs.get(&c.cs()[0]).unwrap().as_bv(); - let b = vs.get(&c.cs()[1]).unwrap().as_bv(); + let a = args[0].as_bv(); + let b = args[1].as_bv(); match o { BvBinPred::Sge => a.as_sint() >= b.as_sint(), BvBinPred::Sgt => a.as_sint() > b.as_sint(), @@ -1331,12 +1328,9 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> BvBinPred::Ult => a.uint() < b.uint(), } }), - Op::BoolToBv => Value::BitVector(BitVector::new( - Integer::from(vs.get(&c.cs()[0]).unwrap().as_bool()), - 1, - )), + Op::BoolToBv => Value::BitVector(BitVector::new(Integer::from(args[0].as_bool()), 1)), Op::PfUnOp(o) => Value::Field({ - let a = vs.get(&c.cs()[0]).unwrap().as_pf().clone(); + let a = args[0].as_pf().clone(); match o { PfUnOp::Recip => { if a.is_zero() { @@ -1349,7 +1343,7 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> } }), Op::PfNaryOp(o) => Value::Field({ - let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_pf().clone()); + let mut xs = args.iter().map(|a| a.as_pf().clone()); let f = xs.next().unwrap(); xs.fold( f, @@ -1360,8 +1354,8 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> ) }), Op::IntBinPred(o) => Value::Bool({ - let a = vs.get(&c.cs()[0]).unwrap().as_int(); - let b = vs.get(&c.cs()[1]).unwrap().as_int(); + let a = args[0].as_int(); + let b = args[1].as_int(); match o { IntBinPred::Ge => a >= b, IntBinPred::Gt => a > b, @@ -1370,7 +1364,7 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> } }), Op::IntNaryOp(o) => Value::Int({ - let mut xs = c.cs().iter().map(|c| vs.get(c).unwrap().as_int().clone()); + let mut xs = args.iter().map(|a| a.as_int().clone()); let f = xs.next().unwrap(); xs.fold( f, @@ -1380,77 +1374,94 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> }, ) }), - Op::UbvToPf(fty) => Value::Field(fty.new_v(vs.get(&c.cs()[0]).unwrap().as_bv().uint())), + Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), + Op::PfChallenge(name, field) => { + use rand::SeedableRng; + use rand_chacha::ChaChaRng; + use std::hash::{Hash, Hasher}; + // hash the string + let mut hasher = fxhash::FxHasher::default(); + name.hash(&mut hasher); + let hash: u64 = hasher.finish(); + // seed ChaCha with the hash + let mut seed = [0u8; 32]; + seed[0..8].copy_from_slice(&hash.to_le_bytes()); + let mut rng = ChaChaRng::from_seed(seed); + // sample from ChaCha + Value::Field(field.random_v(&mut rng)) + } // tuple - Op::Tuple => Value::Tuple(c.cs().iter().map(|c| vs.get(c).unwrap().clone()).collect()), + Op::Tuple => Value::Tuple(args.iter().map(|a| (*a).clone()).collect()), Op::Field(i) => { - let t = vs.get(&c.cs()[0]).unwrap().as_tuple(); - assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs()[0]); + let t = args[0].as_tuple(); + assert!(i < &t.len(), "{} out of bounds for {} on {:?}", i, op, args); t[*i].clone() } Op::Update(i) => { - let mut t = Vec::from(vs.get(&c.cs()[0]).unwrap().as_tuple()).into_boxed_slice(); - assert!(i < &t.len(), "{} out of bounds for {}", i, c.cs()[0]); - let e = vs.get(&c.cs()[1]).unwrap().clone(); + let mut t = Vec::from(args[0].as_tuple()).into_boxed_slice(); + assert!(i < &t.len(), "{} out of bounds for {} on {:?}", i, op, args); + let e = args[1].clone(); assert_eq!(t[*i].sort(), e.sort()); t[*i] = e; Value::Tuple(t) } // array Op::Store => { - let a = vs.get(&c.cs()[0]).unwrap().as_array().clone(); - let i = vs.get(&c.cs()[1]).unwrap().clone(); - let v = vs.get(&c.cs()[2]).unwrap().clone(); + let a = args[0].as_array().clone(); + let i = args[1].clone(); + let v = args[2].clone(); Value::Array(a.store(i, v)) } Op::Select => { - let a = vs.get(&c.cs()[0]).unwrap().as_array().clone(); - let i = vs.get(&c.cs()[1]).unwrap(); + let a = args[0].as_array().clone(); + let i = args[1]; a.select(i) } - Op::Map(op) => { - let arg_cnt = c.cs().len(); - + Op::Map(inner_op) => { // term_vecs[i] will store a vector of all the i-th index entries of the array arguments - let mut term_vecs = vec![Vec::new(); vs.get(&c.cs()[0]).unwrap().as_array().size]; + let mut arg_vecs: Vec> = vec![Vec::new(); args[0].as_array().size]; - for i in 0..arg_cnt { - let arr = vs.get(&c.cs()[i]).unwrap().as_array().clone(); - let iter = match check(&c.cs()[i]) { + for arg in args { + let arr = arg.as_array().clone(); + let iter = match arg.sort() { Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(), _ => panic!("Input type should be Array"), }; for (j, jval) in iter { - term_vecs[j].push(leaf_term(Op::Const(arr.clone().select(&jval)))) + arg_vecs[j].push(arr.select(&jval)) } } - - let mut res = match check(&c) { - Sort::Array(k, v, n) => Array::default((*k).clone(), &v, n), + let term = term( + op.clone(), + args.iter() + .map(|a| leaf_term(Op::Const((*a).clone()))) + .collect(), + ); + let (mut res, iter) = match check(&term) { + Sort::Array(k, v, n) => ( + Array::default((*k).clone(), &v, n), + (*k).clone().elems_iter_values().take(n).enumerate(), + ), _ => panic!("Output type of map should be array"), }; - let iter = match check(&c) { - Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(), - _ => panic!("Input type should be Array"), - }; for (i, idxval) in iter { - let t = term((**op).clone(), term_vecs[i].clone()); - let val = eval_value(vs, h, t); + let args: Vec<&Value> = arg_vecs[i].iter().collect(); + let val = eval_op(inner_op, &args, var_vals); res.map.insert(idxval, val); } Value::Array(res) } Op::Rot(i) => { - let a = vs.get(&c.cs()[0]).unwrap().as_array().clone(); - let iter = match check(&c.cs()[0]) { - Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(), + let a = args[0].as_array().clone(); + let (mut res, iter, len) = match args[0].sort() { + Sort::Array(k, v, n) => ( + Array::default((*k).clone(), &v, n), + (*k).clone().elems_iter_values().take(n).enumerate(), + n, + ), _ => panic!("Input type should be Array"), }; - let (mut res, len) = match check(&c.cs()[0]) { - Sort::Array(k, v, n) => (Array::default((*k).clone(), &v, n), n), - _ => panic!("Output type of rot should be Array"), - }; // calculate new rotation amount let rot = *i % len; @@ -1464,10 +1475,7 @@ fn eval_value(vs: &mut TermMap, h: &FxHashMap, c: Term) -> } o => unimplemented!("eval: {:?}", o), - }; - vs.insert(c.clone(), v.clone()); - debug!("Eval {}\nAs {}", c, v); - v + } } /// Make an array from a sequence of terms. @@ -1562,6 +1570,17 @@ impl PostOrderIter { visited: TermSet::default(), } } + /// Make an iterator over the descendents of `roots`, stopping at `skips`. + pub fn from_roots_and_skips(roots: impl IntoIterator, skips: TermSet) -> Self { + Self { + stack: roots + .into_iter() + .filter(|t| !skips.contains(t)) + .map(|t| (false, t)) + .collect(), + visited: skips, + } + } } impl std::iter::Iterator for PostOrderIter { @@ -1709,6 +1728,47 @@ impl ComputationMetadata { out } + /// Get the interactive structure of the variables. See [InteractiveVars]. + pub fn interactive_vars(&self) -> InteractiveVars { + let final_round = self.vars.values().map(|m| m.round).max().unwrap_or(0); + let mut instances = Vec::new(); + let mut rounds = vec![RoundVars::default(); final_round as usize + 1]; + for meta in self.vars.values() { + if meta.random { + // is this a challenge?, if so it must be public + assert!(meta.vis.is_none()); + rounds[meta.round as usize].challenges.push(meta.term()); + } else if meta.vis.is_none() { + // is it a public non-challenge? if so, it must be round 0 + assert!(meta.round == 0); + instances.push(meta.term()); + } else { + // this is a witness + rounds[meta.round as usize].witnesses.push(meta.term()); + } + } + // If there no final challenges, distinguish the last round of witnesses + let final_witnesses = if rounds.last().unwrap().challenges.is_empty() { + rounds.pop().unwrap().witnesses + } else { + Vec::new() + }; + let mut ret = InteractiveVars { + instances, + rounds, + final_witnesses, + }; + // sort! + let cmp_name = |a: &Term, b: &Term| a.as_var_name().cmp(b.as_var_name()); + ret.instances.sort_by(cmp_name); + for round in &mut ret.rounds { + round.witnesses.sort_by(cmp_name); + round.challenges.sort_by(cmp_name); + } + ret.final_witnesses.sort_by(cmp_name); + ret + } + /// Get a round after the rounds of these variables pub fn future_round<'a, Q: Borrow + 'a>( &self, @@ -1757,6 +1817,29 @@ impl ComputationMetadata { } } +/// A structured collection of variables that indicates the round structure: e.g., orderings, +/// challenges. +/// +/// It represents the variables themselves as terms. +#[derive(Default)] +pub struct InteractiveVars { + /// Instance vars + pub instances: Vec, + /// Rounds + pub rounds: Vec, + /// Final witnesses + pub final_witnesses: Vec, +} + +/// Witnesses, followed by a challenge. +#[derive(Default, Clone)] +pub struct RoundVars { + /// witnesses + pub witnesses: Vec, + /// followed by challenges + pub challenges: Vec, +} + #[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] /// An IR computation. pub struct Computation { diff --git a/src/ir/term/precomp.rs b/src/ir/term/precomp.rs index 948208fa..c79e0e8c 100644 --- a/src/ir/term/precomp.rs +++ b/src/ir/term/precomp.rs @@ -123,6 +123,19 @@ impl PreComp { self.recompute_inputs(); self } + + /// Reduce the precomputation to a single, step-less map. + pub fn flatten(self) -> FxHashMap { + let mut out: FxHashMap = Default::default(); + let mut cache: TermMap = Default::default(); + for (name, sort) in &self.sequence { + let term = extras::substitute_cache(self.outputs.get(name).unwrap(), &mut cache); + let var_term = leaf_term(Op::Var(name.clone(), sort.clone())); + out.insert(name.into(), term.clone()); + cache.insert(var_term, term); + } + out + } } #[cfg(test)] diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index c1fd2f40..fe1e9eaf 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -288,6 +288,10 @@ impl<'src> IrInterp<'src> { [Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))), [Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))), [Leaf(Ident, b"fp2fp"), a] => Ok(Op::FpToFp(self.usize(a))), + [Leaf(Ident, b"challenge"), name, field] => Ok(Op::PfChallenge( + self.ident_string(name), + FieldT::from(self.int(field)), + )), [Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))), [Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))), [Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))), @@ -688,6 +692,10 @@ impl<'src> IrInterp<'src> { let inputs = self.var_decl_list(&tts[1]); let outputs = self.var_decl_list(&tts[2]); let tuple_term = self.term(&tts[3]); + assert!( + matches!(check(&tuple_term), Sort::Tuple(..)), + "precompute output term must be a tuple" + ); assert!( outputs.len() == tuple_term.cs().len(), "output list has {} items, tuple has {}", @@ -958,4 +966,12 @@ mod test { let c2 = parse_precompute(s.as_bytes()); assert_eq!(c, c2); } + + #[test] + fn challenge_roundtrip() { + let t = parse_term(b"(declare ((a bool) (b bool)) ((challenge hithere 17) a b))"); + let s = serialize_term(&t); + let t2 = parse_term(s.as_bytes()); + assert_eq!(t, t2); + } } diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index d9dd3a59..5ee68102 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -59,6 +59,7 @@ fn check_dependencies(t: &Term) -> Vec { Op::IntNaryOp(_) => Vec::new(), Op::IntBinPred(_) => Vec::new(), Op::UbvToPf(_) => Vec::new(), + Op::PfChallenge(_, _) => Vec::new(), Op::Select => vec![t.cs()[0].clone()], Op::Store => vec![t.cs()[0].clone()], Op::Tuple => t.cs().to_vec(), @@ -128,6 +129,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { Op::IntNaryOp(_) => Ok(Sort::Int), Op::IntBinPred(_) => Ok(Sort::Bool), Op::UbvToPf(m) => Ok(Sort::Field(m.clone())), + Op::PfChallenge(_, m) => Ok(Sort::Field(m.clone())), Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v)| v.clone()), Op::Store => Ok(get_ty(&t.cs()[0]).clone()), Op::Tuple => Ok(Sort::Tuple(t.cs().iter().map(get_ty).cloned().collect())), @@ -332,6 +334,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())), + (Op::PfChallenge(_, m), _) => Ok(Sort::Field(m.clone())), (Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()), (Op::IntNaryOp(_), a) => { let ctx = "int nary op"; diff --git a/src/target/r1cs/bellman.rs b/src/target/r1cs/bellman.rs index dbbf99ae..3a707f55 100644 --- a/src/target/r1cs/bellman.rs +++ b/src/target/r1cs/bellman.rs @@ -1,31 +1,28 @@ //! Exporting our R1CS to bellman -use ::bellman::{ - groth16::{ - create_random_proof, generate_random_parameters, prepare_verifying_key, verify_proof, - Parameters, Proof, VerifyingKey, - }, - Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable, -}; -use bincode::{deserialize_from, serialize_into}; +use ::bellman::{groth16, Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable}; use ff::{Field, PrimeField, PrimeFieldBits}; use fxhash::FxHashMap; use gmp_mpfr_sys::gmp::limb_t; use group::WnafGroup; use log::debug; use pairing::{Engine, MultiMillerLoop}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs::File; -use std::io::{self, BufRead, BufReader}; +use std::io::{BufRead, BufReader}; +use std::marker::PhantomData; use std::path::Path; use std::str::FromStr; use rug::integer::{IsPrime, Order}; use rug::Integer; -use super::*; +use super::proof; +use super::{wit_comp::StagedWitCompEvaluator, Lc, ProverData, Var, VarType, VerifierData}; +use crate::ir::term::Value; /// Convert a (rug) integer to a prime field element. -fn int_to_ff(i: Integer) -> F { +pub(super) fn int_to_ff(i: Integer) -> F { let mut accumulator = F::from(0); let limb_bits = (std::mem::size_of::() as u64) << 3; let limb_base = F::from(2).pow_vartime([limb_bits]); @@ -39,8 +36,8 @@ fn int_to_ff(i: Integer) -> F { /// Convert one our our linear combinations to a bellman linear combination. /// Takes a zero linear combination. We could build it locally, but bellman provides one, so... -fn lc_to_bellman>( - vars: &HashMap, +pub(super) fn lc_to_bellman>( + vars: &HashMap, lc: &Lc, zero_lc: LinearCombination, ) -> LinearCombination { @@ -59,7 +56,7 @@ fn lc_to_bellman>( } // hmmm... this should work essentially all the time, I think -fn get_modulus() -> Integer { +pub(super) fn get_modulus() -> Integer { let neg_1_f = -F::one(); let p_lsf: Integer = Integer::from_digits(neg_1_f.to_repr().as_ref(), Order::Lsf) + 1; let p_msf: Integer = Integer::from_digits(neg_1_f.to_repr().as_ref(), Order::Msf) + 1; @@ -76,7 +73,7 @@ fn get_modulus() -> Integer { /// /// Optionally contains a variable value map. This must be populated to use the /// bellman prover. -pub struct SynthInput<'a>(&'a R1cs, &'a Option>); +pub struct SynthInput<'a>(&'a ProverData, Option<&'a FxHashMap>); impl<'a, F: PrimeField> Circuit for SynthInput<'a> { #[track_caller] @@ -86,55 +83,53 @@ impl<'a, F: PrimeField> Circuit for SynthInput<'a> { { let f_mod = get_modulus::(); assert_eq!( - self.0.modulus.modulus(), + self.0.r1cs.field.modulus(), &f_mod, "\nR1CS has modulus \n{},\n but Bellman CS expects \n{}", - self.0.modulus, + self.0.r1cs.field, f_mod ); - let mut uses = HashMap::with_capacity(self.0.next_idx); - for (a, b, c) in self.0.constraints.iter() { - [a, b, c].iter().for_each(|y| { - y.monomials.keys().for_each(|k| { - uses.get_mut(k) - .map(|i| { - *i += 1; - }) - .or_else(|| { - uses.insert(*k, 1); - None - }); + let mut vars = HashMap::with_capacity(self.0.r1cs.vars.len()); + let values: Option> = self.1.map(|values| { + let mut evaluator = StagedWitCompEvaluator::new(&self.0.precompute); + let mut ffs = Vec::new(); + ffs.extend(evaluator.eval_stage(values.clone()).into_iter().cloned()); + ffs.extend( + evaluator + .eval_stage(Default::default()) + .into_iter() + .cloned(), + ); + ffs + }); + for (i, var) in self.0.r1cs.vars.iter().copied().enumerate() { + assert!( + !matches!(var.ty(), VarType::CWit), + "Bellman doesn't support committed witnesses" + ); + assert!( + !matches!(var.ty(), VarType::RoundWit | VarType::Chall), + "Bellman doesn't support rounds" + ); + let public = matches!(var.ty(), VarType::Inst); + let name_f = || format!("{var:?}"); + let val_f = || { + Ok({ + let i_val = &values.as_ref().expect("missing values")[i]; + let ff_val = int_to_ff(i_val.as_pf().into()); + debug!("value : {var:?} -> {ff_val:?} ({i_val})"); + ff_val }) - }); + }; + debug!("var: {:?}, public: {}", var, public); + let v = if public { + cs.alloc_input(name_f, val_f)? + } else { + cs.alloc(name_f, val_f)? + }; + vars.insert(var, v); } - let mut vars = HashMap::with_capacity(self.0.next_idx); - for i in 0..self.0.next_idx { - if let Some(s) = self.0.idxs_signals.get(&i) { - //for (_i, s) in self.0.idxs_signals.get() { - let public = self.0.public_idxs.contains(&i); - if uses.get(&i).is_some() || public { - let name_f = || s.to_string(); - let val_f = || { - Ok({ - let i_val = self.1.as_ref().expect("missing values").get(s).unwrap(); - let ff_val = int_to_ff(i_val.as_pf().into()); - debug!("value : {} -> {:?} ({})", s, ff_val, i_val); - ff_val - }) - }; - debug!("var: {}, public: {}", s, public); - let v = if public { - cs.alloc_input(name_f, val_f)? - } else { - cs.alloc(name_f, val_f)? - }; - vars.insert(i, v); - } else { - debug!("drop dead var: {}", s); - } - } - } - for (i, (a, b, c)) in self.0.constraints.iter().enumerate() { + for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() { cs.enforce( || format!("con{i}"), |z| lc_to_bellman::(&vars, a, z), @@ -145,7 +140,7 @@ impl<'a, F: PrimeField> Circuit for SynthInput<'a> { debug!( "done with synth: {} vars {} cs", vars.len(), - self.0.constraints.len() + self.0.r1cs.constraints.len() ); Ok(()) } @@ -168,12 +163,6 @@ mod serde_pk { use pairing::Engine; use serde::{Deserialize, Deserializer, Serialize, Serializer}; - #[derive(Serialize)] - pub struct SerPk<'a, E: Engine>(#[serde(with = "self")] pub &'a Parameters); - - #[derive(Deserialize)] - pub struct DePk(#[serde(with = "self")] pub Parameters); - pub fn serialize( p: &Parameters, ser: S, @@ -196,12 +185,6 @@ mod serde_vk { use pairing::Engine; use serde::{Deserialize, Deserializer, Serialize, Serializer}; - #[derive(Serialize)] - pub struct SerVk<'a, E: Engine>(#[serde(with = "self")] pub &'a VerifyingKey); - - #[derive(Deserialize)] - pub struct DeVk(#[serde(with = "self")] pub VerifyingKey); - pub fn serialize( p: &VerifyingKey, ser: S, @@ -219,116 +202,80 @@ mod serde_vk { } } -/// Given -/// * a proving-key path, -/// * a verifying-key path, -/// * prover data, and -/// * verifier data -/// generate parameters and write them and the data to files at those paths. -pub fn gen_params, P2: AsRef>( - pk_path: P1, - vk_path: P2, - p_data: &ProverData, - v_data: &VerifierData, -) -> io::Result<()> +mod serde_pf { + use bellman::groth16::Proof; + use pairing::Engine; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(p: &Proof, ser: S) -> Result { + let mut bs: Vec = Vec::new(); + p.write(&mut bs).unwrap(); + serde_bytes::ByteBuf::from(bs).serialize(ser) + } + + pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(de: D) -> Result, D::Error> { + let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?; + Ok(Proof::read(&**bs).unwrap()) + } +} + +/// The [::bellman] implementation of Groth16. +pub struct Bellman(PhantomData); + +/// The pk for [Bellman] +#[derive(Serialize, Deserialize)] +pub struct ProvingKey( + ProverData, + #[serde(with = "serde_pk")] groth16::Parameters, +); + +/// The vk for [Bellman] +#[derive(Serialize, Deserialize)] +pub struct VerifyingKey( + VerifierData, + #[serde(with = "serde_vk")] groth16::VerifyingKey, +); + +/// The proof for [Bellman] +#[derive(Serialize, Deserialize)] +pub struct Proof(#[serde(with = "serde_pf")] groth16::Proof); + +impl proof::ProofSystem for Bellman where + E: MultiMillerLoop, E::G1: WnafGroup, E::G2: WnafGroup, -{ - let rng = &mut rand::thread_rng(); - let p = generate_random_parameters::(SynthInput(&p_data.r1cs, &None), rng).unwrap(); - write_prover_key_and_data(pk_path, &p, p_data)?; - write_verifier_key_and_data(vk_path, &p.vk, v_data)?; - Ok(()) -} - -fn write_prover_key_and_data, E: Engine>( - path: P, - params: &Parameters, - data: &ProverData, -) -> io::Result<()> { - let mut file = File::create(path)?; - serialize_into(&mut file, &(serde_pk::SerPk(params), &data)).unwrap(); - Ok(()) -} - -fn read_prover_key_and_data, E: Engine>( - path: P, -) -> io::Result<(Parameters, ProverData)> { - let mut file = File::open(path)?; - let (serde_pk::DePk(pk), data): (_, ProverData) = deserialize_from(&mut file).unwrap(); - Ok((pk, data)) -} - -fn write_verifier_key_and_data, E: Engine>( - path: P, - key: &VerifyingKey, - data: &VerifierData, -) -> io::Result<()> { - let mut file = File::create(path)?; - serialize_into(&mut file, &(serde_vk::SerVk(key), &data)).unwrap(); - Ok(()) -} - -fn read_verifier_key_and_data, E: Engine>( - path: P, -) -> io::Result<(VerifyingKey, VerifierData)> { - let mut file = File::open(path)?; - let (serde_vk::DeVk(vk), data): (_, VerifierData) = deserialize_from(&mut file).unwrap(); - Ok((vk, data)) -} - -/// Given -/// * a proving-key path, -/// * a proof path, and -/// * a prover input map -/// generate a random proof and writes it to the path -pub fn prove, P2: AsRef>( - pk_path: P1, - pf_path: P2, - inputs_map: &FxHashMap, -) -> io::Result<()> -where E::Fr: PrimeFieldBits, { - let (pk, prover_data) = read_prover_key_and_data::<_, E>(pk_path)?; - let rng = &mut rand::thread_rng(); - for (input, sort) in &prover_data.precompute_inputs { - let value = inputs_map - .get(input) - .unwrap_or_else(|| panic!("No input for {}", input)); - let sort2 = value.sort(); - assert_eq!( - sort, &sort2, - "Sort mismatch for {input}. Expected\n\t{sort} but got\n\t{sort2}", - ); - } - let new_map = prover_data.precompute.eval(inputs_map); - prover_data.r1cs.check_all(&new_map); - let pf = create_random_proof(SynthInput(&prover_data.r1cs, &Some(new_map)), &pk, rng).unwrap(); - let mut pf_file = File::create(pf_path)?; - pf.write(&mut pf_file)?; - Ok(()) -} + type VerifyingKey = VerifyingKey; -/// Given -/// * a verifying-key path, -/// * a proof path, -/// * and a verifier input map -/// checks the proof at that path -pub fn verify, P2: AsRef>( - vk_path: P1, - pf_path: P2, - inputs_map: &FxHashMap, -) -> io::Result<()> { - let (vk, verifier_data) = read_verifier_key_and_data::<_, E>(vk_path)?; - let pvk = prepare_verifying_key(&vk); - let inputs = verifier_data.eval(inputs_map); - let inputs_as_ff: Vec = inputs.into_iter().map(int_to_ff).collect(); - let mut pf_file = File::open(pf_path).unwrap(); - let pf = Proof::read(&mut pf_file).unwrap(); - verify_proof(&pvk, &pf, &inputs_as_ff).unwrap(); - Ok(()) + type ProvingKey = ProvingKey; + + type Proof = Proof; + + fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey) { + let rng = &mut rand::thread_rng(); + let params = + groth16::generate_random_parameters::(SynthInput(&p_data, None), rng).unwrap(); + let v_params = params.vk.clone(); + (ProvingKey(p_data, params), VerifyingKey(v_data, v_params)) + } + + fn prove(pk: &Self::ProvingKey, witness: &FxHashMap) -> Self::Proof { + let rng = &mut rand::thread_rng(); + pk.0.check_all(witness); + Proof(groth16::create_random_proof(SynthInput(&pk.0, Some(witness)), &pk.1, rng).unwrap()) + } + + fn verify(vk: &Self::VerifyingKey, inst: &FxHashMap, pf: &Self::Proof) -> bool { + let pvk = groth16::prepare_verifying_key(&vk.1); + let r1cs_inst_map = vk.0.eval(inst); + let r1cs_inst: Vec = r1cs_inst_map + .into_iter() + .map(|i| int_to_ff(i.i())) + .collect(); + groth16::verify_proof(&pvk, &pf.0, &r1cs_inst).is_ok() + } } #[cfg(test)] diff --git a/src/target/r1cs/mirage.rs b/src/target/r1cs/mirage.rs new file mode 100644 index 00000000..4f1f5281 --- /dev/null +++ b/src/target/r1cs/mirage.rs @@ -0,0 +1,356 @@ +//! Exporting our R1CS to bellman +use ::bellman::{ + cc::{CcCircuit, CcConstraintSystem}, + mirage, SynthesisError, +}; +use ff::{PrimeField, PrimeFieldBits}; +use fxhash::FxHashMap; +use group::WnafGroup; +use log::debug; +use pairing::{Engine, MultiMillerLoop}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::marker::PhantomData; +use std::path::Path; +use std::str::FromStr; + +use rug::Integer; + +use super::proof; +use super::{wit_comp::StagedWitCompEvaluator, ProverData, VarType, VerifierData}; +use crate::ir::term::Value; + +use super::bellman::{get_modulus, int_to_ff, lc_to_bellman}; + +fn ff_to_int(f: F) -> Integer { + let mut buffer = vec![]; + use std::io::Read; + f.to_le_bits() + .as_bitslice() + .read_to_end(&mut buffer) + .unwrap(); + Integer::from_digits(&buffer, rug::integer::Order::Lsf) +} + +/// A synthesizable bellman circuit. +/// +/// Optionally contains a variable value map. This must be populated to use the +/// bellman prover. +pub struct SynthInput<'a>(&'a ProverData, Option<&'a FxHashMap>); + +impl<'a, F: PrimeField + PrimeFieldBits> CcCircuit for SynthInput<'a> { + #[track_caller] + fn synthesize(self, cs: &mut CS) -> std::result::Result<(), SynthesisError> + where + CS: CcConstraintSystem, + { + let f_mod = get_modulus::(); + assert_eq!( + self.0.r1cs.field.modulus(), + &f_mod, + "\nR1CS has modulus \n{},\n but mirage CS expects \n{}", + self.0.r1cs.field, + f_mod + ); + let mut vars = HashMap::with_capacity(self.0.r1cs.vars.len()); + // (assignment values, evaluator, next evaluator inputs) + let mut wit_comp: Option<( + Vec, + StagedWitCompEvaluator<'a>, + FxHashMap, + )> = self.1.map(|inputs| { + ( + Vec::new(), + StagedWitCompEvaluator::new(&self.0.precompute), + inputs.clone(), + ) + }); + let mut var_idx = 0; + let num_stages = self.0.precompute.stage_sizes().count(); + for (i, num_vars) in self.0.precompute.stage_sizes().enumerate() { + if let Some((ref mut var_values, ref mut evaluator, ref mut inputs)) = wit_comp.as_mut() + { + var_values.extend( + evaluator + .eval_stage(std::mem::take(inputs)) + .into_iter() + .cloned(), + ); + } + let num_challs = if i + 1 < num_stages { + self.0.precompute.num_stage_inputs(i + 1) + } else { + 0 + }; + for j in 0..(num_vars + num_challs) { + let var = self.0.r1cs.vars[var_idx]; + let name_f = || format!("{var:?}"); + let val_f = || { + Ok({ + let i_val = &wit_comp.as_ref().expect("missing values").0[var_idx]; + let ff_val = int_to_ff(i_val.as_pf().into()); + debug!("value : {var:?} -> {ff_val:?} ({i_val})"); + ff_val + }) + }; + let v = match var.ty() { + VarType::Inst => cs.alloc_input(name_f, val_f)?, + VarType::RoundWit => cs.alloc(name_f, val_f)?, + VarType::FinalWit => cs.alloc(name_f, val_f)?, + VarType::Chall => { + let (v, val) = cs.alloc_random(name_f)?; + if let Some((ref mut values, _, ref mut inputs)) = wit_comp.as_mut() { + let val = + Value::Field(self.0.r1cs.field.new_v(ff_to_int(val.unwrap()))); + values.push(val.clone()); + let name = self.0.r1cs.names.get(&var).unwrap(); + inputs.insert(name.to_owned(), val); + } + v + } + VarType::CWit => unimplemented!(), + }; + vars.insert(var, v); + var_idx += 1; + if j + 1 == num_vars && num_challs > 0 { + cs.end_aux_block(|| format!("block {}", i - 1))?; + } + } + } + + for (i, (a, b, c)) in self.0.r1cs.constraints.iter().enumerate() { + cs.enforce( + || format!("con{i}"), + |z| lc_to_bellman::(&vars, a, z), + |z| lc_to_bellman::(&vars, b, z), + |z| lc_to_bellman::(&vars, c, z), + ); + } + debug!( + "done with synth: {} vars {} cs", + vars.len(), + self.0.r1cs.constraints.len() + ); + Ok(()) + } + + fn num_aux_blocks(&self) -> usize { + self.0.precompute.stage_sizes().count() - 2 + } +} + +/// Convert a (rug) integer to a prime field element. +pub fn parse_instance, F: PrimeField>(path: P) -> Vec { + let f = BufReader::new(File::open(path).unwrap()); + f.lines() + .map(|line| { + let s = line.unwrap(); + let i = Integer::from_str(s.trim()).unwrap(); + int_to_ff(i) + }) + .collect() +} + +mod serde_pk { + use bellman::mirage::Parameters; + use pairing::Engine; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize( + p: &Parameters, + ser: S, + ) -> Result { + let mut bs: Vec = Vec::new(); + p.write(&mut bs).unwrap(); + serde_bytes::ByteBuf::from(bs).serialize(ser) + } + + pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>( + de: D, + ) -> Result, D::Error> { + let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?; + Ok(Parameters::read(&**bs, false).unwrap()) + } +} + +mod serde_vk { + use bellman::mirage::VerifyingKey; + use pairing::Engine; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize( + p: &VerifyingKey, + ser: S, + ) -> Result { + let mut bs: Vec = Vec::new(); + p.write(&mut bs).unwrap(); + serde_bytes::ByteBuf::from(bs).serialize(ser) + } + + pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>( + de: D, + ) -> Result, D::Error> { + let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?; + Ok(VerifyingKey::read(&**bs).unwrap()) + } +} + +mod serde_pf { + use bellman::mirage::Proof; + use pairing::Engine; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(p: &Proof, ser: S) -> Result { + let mut bs: Vec = Vec::new(); + p.write(&mut bs).unwrap(); + serde_bytes::ByteBuf::from(bs).serialize(ser) + } + + pub fn deserialize<'de, D: Deserializer<'de>, E: Engine>(de: D) -> Result, D::Error> { + let bs: serde_bytes::ByteBuf = Deserialize::deserialize(de)?; + Ok(Proof::read(&**bs).unwrap()) + } +} + +/// The [::bellman] implementation of Groth16. +pub struct Mirage(PhantomData); + +/// The pk for [mirage] +#[derive(Serialize, Deserialize)] +pub struct ProvingKey( + ProverData, + #[serde(with = "serde_pk")] mirage::Parameters, +); + +/// The vk for [mirage] +#[derive(Serialize, Deserialize)] +pub struct VerifyingKey( + VerifierData, + #[serde(with = "serde_vk")] mirage::VerifyingKey, +); + +/// The proof for [mirage] +#[derive(Serialize, Deserialize)] +pub struct Proof(#[serde(with = "serde_pf")] mirage::Proof); + +impl proof::ProofSystem for Mirage +where + E: MultiMillerLoop, + E::G1: WnafGroup, + E::G2: WnafGroup, + E::Fr: PrimeFieldBits, +{ + type VerifyingKey = VerifyingKey; + + type ProvingKey = ProvingKey; + + type Proof = Proof; + + fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey) { + let rng = &mut rand::thread_rng(); + let params = + mirage::generate_random_parameters::(SynthInput(&p_data, None), rng).unwrap(); + let v_params = params.vk.clone(); + (ProvingKey(p_data, params), VerifyingKey(v_data, v_params)) + } + + fn prove(pk: &Self::ProvingKey, witness: &FxHashMap) -> Self::Proof { + let rng = &mut rand::thread_rng(); + pk.0.check_all(witness); + Proof(mirage::create_random_proof(SynthInput(&pk.0, Some(witness)), &pk.1, rng).unwrap()) + } + + fn verify(vk: &Self::VerifyingKey, inst: &FxHashMap, pf: &Self::Proof) -> bool { + let pvk = mirage::prepare_verifying_key(&vk.1); + let r1cs_inst_map = vk.0.eval(inst); + let r1cs_inst: Vec = r1cs_inst_map + .into_iter() + .map(|i| int_to_ff(i.i())) + .collect(); + mirage::verify_proof(&pvk, &pf.0, &r1cs_inst).is_ok() + } +} + +#[cfg(test)] +mod test { + use super::*; + use bls12_381::Scalar; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + use std::io::Write; + + #[derive(Clone, Debug)] + struct BlsScalar(Integer); + + impl Arbitrary for BlsScalar { + fn arbitrary(g: &mut Gen) -> Self { + let mut rug_rng = rug::rand::RandState::new_mersenne_twister(); + rug_rng.seed(&Integer::from(u32::arbitrary(g))); + let modulus = Integer::from( + Integer::parse_radix( + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", + 16, + ) + .unwrap(), + ); + let i = Integer::from(modulus.random_below_ref(&mut rug_rng)); + BlsScalar(i) + } + } + + #[quickcheck] + fn int_to_ff_random(BlsScalar(i): BlsScalar) -> bool { + let by_fn = int_to_ff::(i.clone()); + let by_str = Scalar::from_str_vartime(&format!("{i}")).unwrap(); + by_fn == by_str + } + + #[quickcheck] + fn roundtrip_random(BlsScalar(i): BlsScalar) -> bool { + let ff = int_to_ff::(i.clone()); + let i2 = ff_to_int(ff); + i == i2 + } + + fn convert(i: Integer) { + let by_fn = int_to_ff::(i.clone()); + let by_str = Scalar::from_str_vartime(&format!("{i}")).unwrap(); + assert_eq!(by_fn, by_str); + } + + #[test] + fn neg_one() { + let modulus = Integer::from( + Integer::parse_radix( + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", + 16, + ) + .unwrap(), + ); + convert(modulus - 1); + } + + #[test] + fn zero() { + convert(Integer::from(0)); + } + + #[test] + fn one() { + convert(Integer::from(1)); + } + + #[test] + fn parse() { + let path = format!("{}/instance", std::env::temp_dir().to_str().unwrap()); + { + let mut f = File::create(&path).unwrap(); + write!(f, "5\n6").unwrap(); + } + let i = parse_instance::<_, Scalar>(&path); + assert_eq!(i[0], Scalar::from(5)); + assert_eq!(i[1], Scalar::from(6)); + } +} diff --git a/src/target/r1cs/mod.rs b/src/target/r1cs/mod.rs index c9c15ec4..d575fa04 100644 --- a/src/target/r1cs/mod.rs +++ b/src/target/r1cs/mod.rs @@ -2,34 +2,561 @@ use circ_fields::{FieldT, FieldV}; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use log::debug; +use log::{debug, trace}; use paste::paste; use rug::Integer; use serde::{Deserialize, Serialize}; -use std::collections::hash_map::Entry; -use std::fmt::Display; +use std::fmt::Debug; use std::hash::Hash; use crate::ir::term::*; #[cfg(feature = "bellman")] pub mod bellman; +#[cfg(feature = "bellman")] +pub mod mirage; pub mod opt; +pub mod proof; #[cfg(feature = "spartan")] pub mod spartan; pub mod trans; +pub mod wit_comp; #[derive(Debug, Clone, Serialize, Deserialize)] /// A Rank 1 Constraint System. -pub struct R1cs { +/// +/// Extended to comprehend witness commitments and verifier challenges. +/// +/// We view the R1CS relation as R(x, cw_0 .. cw_C, w_0, r_0, w_1, r_1, .. w), where all +/// variables are vectors of field elements and +/// * x is the instance +/// * cw_i is a committed witness +/// * i.e., the commitment is part of the instance, but the data is part of the witness +/// * i from 0 to R is a "round number": +/// * w_i is a witness set by the prover in round i +/// * r_i is a random challenge, sampled as round i ends and round i+1 begins +/// * w is the final round of witnesses +/// +/// ## Operations +/// +/// To interface with a proof system, it must be able to: (mapping to MIRAGE impl) +/// * get all instance variables (create inputs) +/// * get all committed witness vectors (create witnesses, end blocks) +/// * for each round +/// * get the witness variables (create witnesses, end block) +/// * followed by the challenge variables (create challenges) +/// * get all constraints, and create them +/// +/// To interface with a compiler, its must be able to: (mapping to Computation interface) +/// * describe all instance variables in a fixed order (get public variables, fixed order) +/// * describe all committed witness vectors in a fixed order (get witness arrays, fixed order) +/// * for each round +/// * describe the witness variables in that round +/// * (tricky? +/// * since we have deterministic semantics, it suffices to declare the [Computation] +/// witness variables of that round (intermediates are not needed) +/// * ) +/// * describe the challenge variables after that round (immediate) +/// * then, we embed the intermediates in w +/// +/// To interface with an optimizer, it must be able to +/// * build a variable use-site cache +/// * change constraints/remove them +/// * test whether a variable can be eliminated +/// * x cannot +/// * cw_i cannot +/// * r_i cannot +/// * w_i cannot +/// * w can +/// * since only w variable can be eliminated, there is room for optimizating the contents of w_i +/// * For now, we'll assume that putting the computation witness inputs is sufficient +/// +/// Design conclusions: +/// * Since contraints are defined uniformly w.r.t. different kinds of variables, it makes sense +/// for variables to have uniform identifiers. We'll use a [usize]. +/// * The compiler seems capable of meeting a very restricted, stateful builder interface. +/// * The optimizer will be happy as long as +/// * there is a uniform variable representation and +/// * it can test that representation for eliminatability +/// +/// So, our ultimate data structure is: +/// * a next var counter +/// * a (bi) mapping between variable numbers and names +/// * the builder round we're in +/// * indices defining the blocks: +/// * end of x +/// * for each cw_i: end of i +/// * for each round: +/// * end of w_i +/// * end of r_i +/// * no entry for w +/// * constraints! +/// * terms +/// * variables include: +/// * verifier inputs +/// * prover inputs +/// * challenges +/// +/// I'll skip the build interface: it'll map directly to the above. +/// +/// The optimizer won't have an interface. It *will* be allowed to remove variables, leaving unused +/// variable numbers. +/// +/// The proof system interface: +/// * Setup: +/// * get x: names and numbers (numbers needed to interpret LCs) +/// * for i: get cw_i: " +/// * for i: get w_i and r_i: " +/// * get w +/// * Proving: +/// * Details TBD. +/// * Probably: build an evaluator +/// * evaluator: +/// * submit values (inputs, challenges) +/// * get values +pub struct R1cs { modulus: FieldT, - signal_idxs: HashMap, - idxs_signals: HashMap, - next_idx: usize, - public_idxs: HashSet, + idx_to_sig: BiMap, + num_insts: usize, + num_cwits: Vec, + next_cwit: usize, + round_wit_ends: Vec, + next_round_wit: usize, + round_chall_ends: Vec, + next_round_chall: usize, + num_final_wits: usize, + + challenge_names: Vec, + + /// The contraints themselves constraints: Vec<(Lc, Lc, Lc)>, - #[serde(with = "crate::ir::term::serde_mods::vec")] - terms: Vec, + + /// Terms for computing them. + #[serde(with = "crate::ir::term::serde_mods::map")] + terms: HashMap, + precompute: precomp::PreComp, +} + +/// An assembled R1CS relation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct R1csFinal { + field: FieldT, + vars: Vec, + constraints: Vec<(Lc, Lc, Lc)>, + names: HashMap, + // chall_precompute_names: HashMap, +} + +/// A variable +#[derive(Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +#[repr(transparent)] +pub struct Var(usize); + +impl Var { + const NUMBER_BITS: u32 = usize::BITS - 3; + const NUMBER_MASK: usize = !(0b111 << Self::NUMBER_BITS); + fn new(ty: VarType, number: usize) -> Self { + assert!(!Self::NUMBER_MASK & number == 0); + let ty_repr = match ty { + VarType::Inst => 0b000, + VarType::CWit => 0b001, + VarType::RoundWit => 0b010, + VarType::Chall => 0b011, + VarType::FinalWit => 0b100, + }; + Var(ty_repr << Self::NUMBER_BITS | number) + } + fn ty(&self) -> VarType { + match self.0 >> Self::NUMBER_BITS { + 0b000 => VarType::Inst, + 0b001 => VarType::CWit, + 0b010 => VarType::RoundWit, + 0b011 => VarType::Chall, + 0b100 => VarType::FinalWit, + c => panic!("Bad type code {}", c), + } + } + fn number(&self) -> usize { + self.0 & Self::NUMBER_MASK + } +} + +impl std::fmt::Debug for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}({})", self.ty(), self.number()) + } +} + +#[derive(Debug)] +/// A variable type +pub enum VarType { + /// x + Inst, + /// cw_i + CWit, + /// w_i + RoundWit, + /// r_i + Chall, + /// w + FinalWit, +} + +/// Builder interface +impl R1cs { + /// Make an empty constraint system, mod `modulus`. + /// If `values`, then this constraint system will track & expect concrete values. + pub fn new(modulus: FieldT, precompute: precomp::PreComp) -> Self { + R1cs { + modulus, + idx_to_sig: BiMap::new(), + num_insts: Default::default(), + num_cwits: Default::default(), + next_cwit: Default::default(), + round_wit_ends: Default::default(), + next_round_wit: Default::default(), + round_chall_ends: Default::default(), + next_round_chall: Default::default(), + num_final_wits: Default::default(), + challenge_names: Default::default(), + constraints: Vec::new(), + terms: Default::default(), + precompute, + } + } + + fn var(&mut self, s: String, t: Term, ty: VarType) -> Var { + let id = match ty { + VarType::Inst => { + self.num_insts += 1; + self.num_insts - 1 + } + VarType::CWit => { + self.next_cwit += 1; + self.next_cwit - 1 + } + VarType::RoundWit => { + self.next_round_wit += 1; + self.next_round_wit - 1 + } + VarType::Chall => { + self.next_round_chall += 1; + self.next_round_chall - 1 + } + VarType::FinalWit => { + self.num_final_wits += 1; + self.num_final_wits - 1 + } + }; + if let VarType::Chall = ty { + self.challenge_names.push(s.clone()); + } + let var = Var::new(ty, id); + // could check `t` dependents + self.idx_to_sig.insert(var, s); + self.terms.insert(var, t); + var + } + + /// End a round of witnesses and challenges. The challenges will be set after the witnesses. + pub fn end_round(&mut self) { + self.round_wit_ends.push(self.next_round_wit); + self.round_chall_ends.push(self.next_round_chall); + } + + /// Add a (uncommitted) witness variable. + #[track_caller] + pub fn add_var(&mut self, s: String, t: Term, ty: VarType) -> Var { + assert!(!matches!(ty, VarType::CWit)); + self.var(s, t, ty) + } + + /// The total number of variables + pub fn num_vars(&self) -> usize { + self.num_insts + + self.next_cwit + + self.next_round_wit + + self.next_round_chall + + self.num_final_wits + } + + /// Add a vector of committed witness variables + pub fn add_committed_witness(&mut self, names_and_terms: Vec<(String, Term)>) { + let n = names_and_terms.len(); + for (name, value) in names_and_terms { + self.var(name, value, VarType::CWit); + } + self.num_cwits.push(n); + } + + /// Get the zero combination for this system. + pub fn zero(&self) -> Lc { + Lc { + modulus: self.modulus.clone(), + constant: self.modulus.zero(), + monomials: HashMap::default(), + } + } + /// Get a constant constraint for this system. + #[track_caller] + pub fn constant(&self, c: FieldV) -> Lc { + assert_eq!(c.ty(), self.modulus); + Lc { + modulus: self.modulus.clone(), + constant: c, + monomials: HashMap::default(), + } + } + /// Get combination which is just the wire `s`. + pub fn signal_lc(&self, s: &str) -> Lc { + let idx = self + .idx_to_sig + .get_rev(s) + .expect("Missing signal in signal_lc"); + let mut lc = self.zero(); + lc.monomials.insert(*idx, self.modulus.new_v(1)); + lc + } + /// Make `a * b = c` a constraint. + pub fn constraint(&mut self, a: Lc, b: Lc, c: Lc) { + assert_eq!(&self.modulus, &a.modulus); + assert_eq!(&self.modulus, &b.modulus); + assert_eq!(&self.modulus, &c.modulus); + debug!( + "Constraint:\n {}\n * {}\n = {}", + self.format_lc(&a), + self.format_lc(&b), + self.format_lc(&c) + ); + self.constraints.push((a, b, c)); + } + + /// Get a nice string represenation of the combination `a`. + pub fn format_lc(&self, a: &Lc) -> String { + let mut s = String::new(); + + let half_m: Integer = self.modulus().clone() / 2; + let abs = |i: Integer| { + if i <= half_m { + i + } else { + self.modulus() - i + } + }; + let sign = |i: &Integer| if i < &half_m { "+" } else { "-" }; + let format_i = |i: &FieldV| { + let ii: Integer = i.into(); + format!("{}{}", sign(&ii), abs(ii)) + }; + + s.push_str(&format_i(&a.constant)); + for (idx, coeff) in &a.monomials { + s.extend( + format!( + " {} {}", + self.idx_to_sig.get_fwd(idx).unwrap(), + format_i(coeff), + ) + .chars(), + ); + } + s + } + + /// Can this variable be eliminated? + pub fn can_eliminate(&self, var: Var) -> bool { + matches!(var.ty(), VarType::RoundWit) + } + + /// Get a nice string represenation of the tuple. + pub fn format_qeq(&self, (a, b, c): &(Lc, Lc, Lc)) -> String { + format!( + "({})({}) = {}", + self.format_lc(a), + self.format_lc(b), + self.format_lc(c) + ) + } + + fn modulus(&self) -> &Integer { + self.modulus.modulus() + } + + /// Access the raw constraints. + pub fn constraints(&self) -> &Vec<(Lc, Lc, Lc)> { + &self.constraints + } +} + +impl R1csFinal { + /// Check `a * b = c` in this constraint system. + pub fn check(&self, a: &Lc, b: &Lc, c: &Lc, values: &HashMap) { + let av = self.eval(a, values); + let bv = self.eval(b, values); + let cv = self.eval(c, values); + if (av.clone() * &bv) != cv { + panic!( + "Error! Bad constraint:\n {} (value {})\n * {} (value {})\n = {} (value {})", + self.format_lc(a), + av, + self.format_lc(b), + bv, + self.format_lc(c), + cv + ) + } + } + + /// Get a nice string represenation of the combination `a`. + fn format_lc(&self, a: &Lc) -> String { + let mut s = String::new(); + + let half_m: Integer = self.field.modulus().clone() / 2; + let abs = |i: Integer| { + if i <= half_m { + i + } else { + self.field.modulus() - i + } + }; + let sign = |i: &Integer| if i < &half_m { "+" } else { "-" }; + let format_i = |i: &FieldV| { + let ii: Integer = i.into(); + format!("{}{}", sign(&ii), abs(ii)) + }; + + s.push_str(&format_i(&a.constant)); + for (idx, coeff) in &a.monomials { + s.extend(format!(" {} {}", self.names.get(idx).unwrap(), format_i(coeff),).chars()); + } + s + } + + fn eval(&self, lc: &Lc, values: &HashMap) -> FieldV { + let mut acc = lc.constant.clone(); + for (var, coeff) in &lc.monomials { + let val = values + .get(var) + .unwrap_or_else(|| panic!("Missing value in R1cs::eval for variable {:?}", var)) + .clone(); + acc += val * coeff; + } + acc + } + + /// Check all assertions + fn check_all(&self, values: &HashMap) { + for (a, b, c) in &self.constraints { + self.check(a, b, c, values) + } + } +} + +impl ProverData { + /// Check all assertions. Puts in 1 for challenges. + pub fn check_all(&self, values: &HashMap) { + // we need to evaluate all R1CS variables + let mut var_values: HashMap = Default::default(); + let mut eval = wit_comp::StagedWitCompEvaluator::new(&self.precompute); + // this will hold inputs to the multi-round evaluator. + let mut inputs = values.clone(); + while var_values.len() < self.r1cs.vars.len() { + trace!( + "Have {}/{} values, doing another round", + var_values.len(), + self.r1cs.vars.len() + ); + // do a round of evaluation + let value_vec = eval.eval_stage(std::mem::take(&mut inputs)); + for value in value_vec { + var_values.insert(self.r1cs.vars[var_values.len()], value.as_pf().clone()); + } + // fill the challenges with 1s + if var_values.len() < self.r1cs.vars.len() { + for next_var_i in var_values.len()..self.r1cs.vars.len() { + if !matches!(self.r1cs.vars[next_var_i].ty(), VarType::Chall) { + break; + } + let var = self.r1cs.vars[next_var_i]; + let name = self.r1cs.names.get(&var).unwrap().clone(); + let val = self.r1cs.field.new_v(1); + var_values.insert(var, val.clone()); + inputs.insert(name, Value::Field(val)); + } + } + } + self.r1cs.check_all(&var_values); + } +} + +/// A bidirectional map. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct BiMap { + fwd: HashMap, + rev: HashMap, +} + +#[allow(dead_code)] +impl BiMap { + fn new() -> Self { + Self { + fwd: Default::default(), + rev: Default::default(), + } + } + fn len(&self) -> usize { + debug_assert_eq!(self.fwd.len(), self.rev.len()); + self.fwd.len() + } + #[allow(clippy::uninlined_format_args)] + fn insert(&mut self, s: S, t: T) { + assert!( + self.fwd.insert(s.clone(), t.clone()).is_none(), + "Duplicate key {:?}", + s + ); + assert!( + self.rev.insert(t.clone(), s).is_none(), + "Duplicate value {:?}", + t + ); + } + fn contains_key(&self, s: &Q) -> bool + where + S: std::borrow::Borrow, + Q: Hash + Eq + ?Sized, + { + self.fwd.contains_key(s) + } + fn get_fwd(&self, s: &Q) -> Option<&T> + where + S: std::borrow::Borrow, + Q: Hash + Eq + ?Sized, + { + self.fwd.get(s) + } + fn get_rev(&self, t: &Q) -> Option<&S> + where + T: std::borrow::Borrow, + Q: Hash + Eq + ?Sized, + { + self.rev.get(t) + } + fn remove_fwd>(&mut self, s: &Q) { + let t = self.fwd.remove(s.borrow()).unwrap(); + self.rev.remove(&t).unwrap(); + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +/// The type of a signal +pub enum SigTy { + /// Known by all parties, initially + Instance, + /// Known by the prover + Witness, + /// Randomly sampled + Challenge, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -37,7 +564,7 @@ pub struct R1cs { pub struct Lc { modulus: FieldT, constant: FieldV, - monomials: HashMap, + monomials: HashMap, } impl Lc { @@ -87,13 +614,13 @@ macro_rules! arith_impl { } for (i, v) in &other.monomials { match self.monomials.entry(*i) { - Entry::Occupied(mut e) => { + std::collections::hash_map::Entry::Occupied(mut e) => { e.get_mut().[<$fn _assign>](v); if e.get().is_zero() { e.remove_entry(); } } - Entry::Vacant(e) => { + std::collections::hash_map::Entry::Vacant(e) => { let mut m = self.modulus.zero(); m.[<$fn _assign>](v); e.insert(m); @@ -197,135 +724,9 @@ impl MulAssign for Lc { } } -impl R1cs { - /// Make an empty constraint system, mod `modulus`. - /// If `values`, then this constraint system will track & expect concrete values. - pub fn new(modulus: FieldT) -> Self { - R1cs { - modulus, - signal_idxs: HashMap::default(), - idxs_signals: HashMap::default(), - next_idx: 0, - public_idxs: HashSet::default(), - constraints: Vec::new(), - terms: Vec::new(), - } - } - /// Get the zero combination for this system. - pub fn zero(&self) -> Lc { - Lc { - modulus: self.modulus.clone(), - constant: self.modulus.zero(), - monomials: HashMap::default(), - } - } - /// Get a constant constraint for this system. - #[track_caller] - pub fn constant(&self, c: FieldV) -> Lc { - assert_eq!(c.ty(), self.modulus); - Lc { - modulus: self.modulus.clone(), - constant: c, - monomials: HashMap::default(), - } - } - /// Get combination which is just the wire `s`. - pub fn signal_lc(&self, s: &S) -> Lc { - let idx = self - .signal_idxs - .get(s) - .expect("Missing signal in signal_lc"); - let mut lc = self.zero(); - lc.monomials.insert(*idx, self.modulus.new_v(1)); - lc - } - /// Create a new wire, `s`. If this system is tracking concrete values, you must provide the - /// value, `v`. - /// - /// You must also provide `term`, that computes the signal value from *some* inputs. - pub fn add_signal(&mut self, s: S, term: Term) { - let n = self.next_idx; - self.next_idx += 1; - self.signal_idxs.insert(s.clone(), n); - self.idxs_signals.insert(n, s); - assert_eq!(n, self.terms.len()); - self.terms.push(term); - } - /// Make `s` a public wire in the system - pub fn publicize(&mut self, s: &S) { - self.signal_idxs - .get(s) - .cloned() - .map(|i| self.public_idxs.insert(i)); - } - /// Make `a * b = c` a constraint. - pub fn constraint(&mut self, a: Lc, b: Lc, c: Lc) { - assert_eq!(&self.modulus, &a.modulus); - assert_eq!(&self.modulus, &b.modulus); - assert_eq!(&self.modulus, &c.modulus); - debug!( - "Constraint:\n {}\n * {}\n = {}", - self.format_lc(&a), - self.format_lc(&b), - self.format_lc(&c) - ); - self.constraints.push((a, b, c)); - } - /// Get a nice string represenation of the combination `a`. - pub fn format_lc(&self, a: &Lc) -> String { - let mut s = String::new(); - - let half_m: Integer = self.modulus().clone() / 2; - let abs = |i: Integer| { - if i <= half_m { - i - } else { - self.modulus() - i - } - }; - let sign = |i: &Integer| if i < &half_m { "+" } else { "-" }; - let format_i = |i: &FieldV| { - let ii: Integer = i.into(); - format!("{}{}", sign(&ii), abs(ii)) - }; - - s.push_str(&format_i(&a.constant)); - for (idx, coeff) in &a.monomials { - s.extend( - format!( - " {} {}", - self.idxs_signals.get(idx).unwrap(), - format_i(coeff), - ) - .chars(), - ); - } - s - } - - /// Get a nice string represenation of the tuple. - pub fn format_qeq(&self, (a, b, c): &(Lc, Lc, Lc)) -> String { - format!( - "({})({}) = {}", - self.format_lc(a), - self.format_lc(b), - self.format_lc(c) - ) - } - - fn modulus(&self) -> &Integer { - self.modulus.modulus() - } - - /// Access the raw constraints. - pub fn constraints(&self) -> &Vec<(Lc, Lc, Lc)> { - &self.constraints - } -} - -impl R1cs { +impl R1cs { /// Check `a * b = c` in this constraint system. - pub fn check(&self, a: &Lc, b: &Lc, c: &Lc, values: &HashMap) { + pub fn check(&self, a: &Lc, b: &Lc, c: &Lc, values: &HashMap) { let av = self.eval(a, values); let bv = self.eval(b, values); let cv = self.eval(c, values); @@ -342,105 +743,222 @@ impl R1cs { } } - fn eval(&self, lc: &Lc, values: &HashMap) -> FieldV { + fn eval(&self, lc: &Lc, values: &HashMap) -> FieldV { let mut acc = lc.constant.clone(); for (var, coeff) in &lc.monomials { - let name = self.idxs_signals.get(var).unwrap(); let val = values - .get(name) - .unwrap_or_else(|| panic!("Missing value in R1cs::eval for variable {}", name)) - .as_pf() + .get(var) + .unwrap_or_else(|| panic!("Missing value in R1cs::eval for variable {:?}", var)) .clone(); acc += val * coeff; } acc } - /// Check all assertions, if values are being tracked. - pub fn check_all(&self, values: &HashMap) { - for (a, b, c) in &self.constraints { - self.check(a, b, c, values) - } - } - - /// Add the signals of this R1CS instance to the precomputation. - fn extend_precomputation(&self, precompute: &mut precomp::PreComp, public_signals_only: bool) { - for i in 0..self.next_idx { - let sig_name = self.idxs_signals.get(&i).unwrap(); - if (!public_signals_only || self.public_idxs.contains(&i)) - && !precompute.outputs().contains_key(sig_name) - { - let term = self.terms[i].clone(); - precompute.add_output(sig_name.clone(), term); - } - } - } - - /// Compute the verifier data for this R1CS relation, given a precomputation - /// that computes the variables that are relation inputs - pub fn verifier_data(&self, cs: &Computation) -> VerifierData { - let mut precompute = cs.precomputes.clone(); - self.extend_precomputation(&mut precompute, true); - let public_inputs = cs.metadata.get_inputs_for_party(None); - precompute.restrict_to_inputs(public_inputs); - let pf_input_order: Vec = (0..self.next_idx) - .filter(|i| self.public_idxs.contains(i)) - .map(|i| self.idxs_signals.get(&i).cloned().unwrap()) - .collect(); - let mut precompute_inputs = HashMap::default(); - for input in &pf_input_order { - if let Some(output_term) = precompute.outputs().get(input) { - for (v, s) in extras::free_variables_with_sorts(output_term.clone()) { - precompute_inputs.insert(v, s); + fn eval_all_vars(&self, inputs: &HashMap) -> HashMap { + let after_precompute = self.precompute.eval(inputs); + let mut cache = Default::default(); + self.terms + .iter() + .map(|(var, term)| { + let val = eval_cached(term, &after_precompute, &mut cache); + if let Value::Field(f) = val { + (*var, f.clone()) + } else { + panic!("Non-field"); } - } else { - precompute_inputs.insert(input.clone(), Sort::Field(self.modulus.clone())); - } - } - VerifierData { - precompute_inputs, - precompute, - pf_input_order, + }) + .collect() + } + + /// Check all assertions, if values are being tracked. + pub fn check_all(&self, inputs: &HashMap) { + let var_values = self.eval_all_vars(inputs); + for (a, b, c) in &self.constraints { + self.check(a, b, c, &var_values) } } - /// Compute the verifier data for this R1CS relation, given a precomputation - /// that computes the variables that are relation inputs - pub fn prover_data(self, cs: &Computation) -> ProverData { + fn insts_iter(&self) -> impl Iterator + '_ { + (0..self.num_insts) + .map(|i| Var::new(VarType::Inst, i)) + .filter(move |v| self.idx_to_sig.contains_key(v)) + } + + fn final_wits_iter(&self) -> impl Iterator + '_ { + (0..self.num_final_wits) + .map(|i| Var::new(VarType::FinalWit, i)) + .filter(move |v| self.idx_to_sig.contains_key(v)) + } + + fn cwits_iter(&self) -> impl Iterator + '_ { + (0..self.next_cwit) + .map(|i| Var::new(VarType::CWit, i)) + .filter(move |v| self.idx_to_sig.contains_key(v)) + } + + fn challs_iter(&self, round: usize) -> impl Iterator + '_ { + let start = if round == 0 { + 0 + } else { + self.round_chall_ends[round - 1] + }; + let end = self.round_chall_ends[round]; + (start..end) + .map(|i| Var::new(VarType::Chall, i)) + .filter(move |v| self.idx_to_sig.contains_key(v)) + } + + fn round_wits_iter(&self, round: usize) -> impl Iterator + '_ { + let start = if round == 0 { + 0 + } else { + self.round_wit_ends[round - 1] + }; + let end = self.round_wit_ends[round]; + (start..end) + .map(|i| Var::new(VarType::RoundWit, i)) + .filter(move |v| self.idx_to_sig.contains_key(v)) + } + + /// Returns a list of (signal list, challenge list) pairs. + /// The prove computes the values of signals. + /// The proof system computes the values of challenges. + /// All signals are computed from (a) prover inputs and (b) challenge values. + fn stage_vars(&self) -> Vec<(Vec, Vec)> { + let mut out = Vec::new(); + out.push(( + self.insts_iter().chain(self.cwits_iter()).collect(), + Vec::new(), + )); + for round_idx in 0..self.round_chall_ends.len() { + out.push(( + self.round_wits_iter(round_idx).collect(), + self.challs_iter(round_idx).collect(), + )); + } + out.push((self.final_wits_iter().collect(), Vec::new())); + out + } + + /// Prover Data + fn prover_data(self, cs: &Computation) -> ProverData { let mut precompute = cs.precomputes.clone(); self.extend_precomputation(&mut precompute, false); // we still need to remove the non-r1cs variables use crate::ir::proof::PROVER_ID; let all_inputs = cs.metadata.get_inputs_for_party(Some(PROVER_ID)); precompute.restrict_to_inputs(all_inputs); - let pf_input_order: Vec = (0..self.next_idx) - .filter(|i| self.public_idxs.contains(i)) - .map(|i| self.idxs_signals.get(&i).cloned().unwrap()) - .collect(); - let mut precompute_inputs = HashMap::default(); - for input in &pf_input_order { - if let Some(output_term) = precompute.outputs().get(input) { - for (v, s) in extras::free_variables_with_sorts(output_term.clone()) { - precompute_inputs.insert(v, s); - } - } else { - precompute_inputs.insert(input.clone(), Sort::Field(self.modulus.clone())); - } + let mut vars: HashMap = { + PostOrderIter::new(precompute.tuple()) + .filter_map(|t| { + if let Op::Var(n, s) = t.op() { + Some((n.clone(), s.clone())) + } else { + None + } + }) + .collect() + }; + for c in &self.challenge_names { + vars.remove(c); } - for o in precompute.outputs().keys() { - precompute_inputs.remove(o); + let mut precompute_map = precompute.flatten(); + let mut comp = wit_comp::StagedWitComp::default(); + let mut var_sequence = Vec::new(); + for (computed_in_stage, challs) in self.stage_vars() { + let terms = computed_in_stage + .iter() + .map(|v| { + let name = self.idx_to_sig.get_fwd(v).unwrap(); + precompute_map.remove(name).unwrap() + }) + .collect(); + comp.add_stage(std::mem::take(&mut vars), terms); + vars = challs + .iter() + .map(|cvar| { + ( + self.idx_to_sig.get_fwd(cvar).unwrap().clone(), + Sort::Field(self.modulus.clone()), + ) + }) + .collect(); + var_sequence.extend(computed_in_stage); + var_sequence.extend(challs); } ProverData { - precompute_inputs, - precompute, - r1cs: self, + r1cs: R1csFinal { + field: self.modulus.clone(), + names: var_sequence + .iter() + .map(|v| (*v, self.idx_to_sig.get_fwd(v).unwrap().clone())) + .collect(), + vars: var_sequence, + constraints: self.constraints, + }, + precompute: comp, } } + /// Prover Data + fn verifier_data(&self, cs: &Computation) -> VerifierData { + let mut precompute = cs.precomputes.clone(); + self.extend_precomputation(&mut precompute, true); + let public_inputs = cs.metadata.get_inputs_for_party(None); + precompute.restrict_to_inputs(public_inputs); + let vars: HashMap = { + PostOrderIter::new(precompute.tuple()) + .filter_map(|t| { + if let Op::Var(n, s) = t.op() { + Some((n.clone(), s.clone())) + } else { + None + } + }) + .collect() + }; + for c in &self.challenge_names { + assert!(!vars.contains_key(c)); + } + let mut precompute_map = precompute.flatten(); + let terms = self + .insts_iter() + .map(|v| { + let name = self.idx_to_sig.get_fwd(&v).unwrap(); + precompute_map.remove(name).unwrap() + }) + .collect(); + let mut comp = wit_comp::StagedWitComp::default(); + comp.add_stage(vars, terms); + VerifierData { precompute: comp } + } + + /// Add the signals of this R1CS instance to the precomputation. + fn extend_precomputation(&self, precompute: &mut precomp::PreComp, public_signals_only: bool) { + for (var, term) in &self.terms { + if !matches!(var.ty(), VarType::Chall) + && (!public_signals_only || matches!(var.ty(), VarType::Inst)) + { + let sig_name = self.idx_to_sig.get_fwd(var).unwrap(); + if !precompute.outputs().contains_key(sig_name) { + precompute.add_output(sig_name.clone(), term.clone()); + } + } + } + } + + /// Split this R1CS into prover (Proving, Setup) and verifier (Verifying) information. + pub fn finalize(self, cs: &Computation) -> (ProverData, VerifierData) { + let vd = self.verifier_data(cs); + let pd = self.prover_data(cs); + (pd, vd) + } + /// Get an IR term that represents this system. pub fn lc_ir_term(&self, lc: &Lc) -> Term { term(PF_ADD, - std::iter::once(pf_lit(lc.constant.clone())).chain(lc.monomials.iter().map(|(i, coeff)| term![PF_MUL; pf_lit(coeff.clone()), leaf_term(Op::Var(self.idxs_signals.get(i).unwrap().into(), Sort::Field(self.modulus.clone())))])).collect()) + std::iter::once(pf_lit(lc.constant.clone())).chain(lc.monomials.iter().map(|(i, coeff)| term![PF_MUL; pf_lit(coeff.clone()), leaf_term(Op::Var(self.idx_to_sig.get_fwd(i).unwrap().into(), Sort::Field(self.modulus.clone())))])).collect()) } /// Get an IR term that represents this system. @@ -451,53 +969,31 @@ impl R1cs { } } -/// Relation-related data that a verifier needs to check a proof. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct VerifierData { - /// Inputs that the verifier must have - pub precompute_inputs: HashMap, - /// A precomputation to perform on those inputs - pub precompute: precomp::PreComp, - /// The order in which the outputs must be fed into the proof system - pub pf_input_order: Vec, -} - impl VerifierData { - /// Given verifier inputs, compute a vector of integers to feed to the proof system. - pub fn eval(&self, value_map: &HashMap) -> Vec { - for (input, sort) in &self.precompute_inputs { - let value = value_map - .get(input) - .unwrap_or_else(|| panic!("No input for {}", input)); - let sort2 = value.sort(); - assert_eq!( - sort, &sort2, - "Sort mismatch for {input}. Expected\n\t{sort} but got\n\t{sort2}", - ); - } - let new_map = self.precompute.eval(value_map); - self.pf_input_order - .iter() - .map(|input| { - new_map - .get(input) - .unwrap_or_else(|| panic!("Missing input {}", input)) - .as_pf() - .i() - }) + /// Given verifier inputs, compute a vector of field values to feed to the proof system. + pub fn eval(&self, value_map: &HashMap) -> Vec { + let mut eval = wit_comp::StagedWitCompEvaluator::new(&self.precompute); + eval.eval_stage(value_map.clone()) + .into_iter() + .map(|v| v.as_pf().clone()) .collect() } } -/// Relation-related data that a prover needs to check a proof. -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Relation-related data that a prover needs to make a proof. +#[derive(Debug, Serialize, Deserialize)] pub struct ProverData { - /// The R1CS instance. - pub r1cs: R1cs, - /// Inputs that the verifier must have - pub precompute_inputs: HashMap, - /// A precomputation to perform on those inputs - pub precompute: precomp::PreComp, + /// R1cs + pub r1cs: R1csFinal, + /// Witness computation + pub precompute: wit_comp::StagedWitComp, +} + +/// Relation-related data that a verifier needs to check a proof. +#[derive(Debug, Serialize, Deserialize)] +pub struct VerifierData { + /// Instance computation + pub precompute: wit_comp::StagedWitComp, } #[derive(Clone, Debug)] diff --git a/src/target/r1cs/opt.rs b/src/target/r1cs/opt.rs index ca5069ac..d053298d 100644 --- a/src/target/r1cs/opt.rs +++ b/src/target/r1cs/opt.rs @@ -1,21 +1,24 @@ //! Optimizations over R1CS -use super::*; -use crate::cfg::CircCfg; -use crate::util::once::OnceQueue; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use log::debug; -struct LinReducer { - r1cs: R1cs, - uses: HashMap>, +use std::collections::hash_map::Entry; + +use super::*; +use crate::cfg::CircCfg; +use crate::util::once::OnceQueue; + +struct LinReducer { + r1cs: R1cs, + uses: HashMap>, queue: OnceQueue, /// The maximum size LC (number of non-constant monomials) /// that will be used for propagation lc_size_thresh: usize, } -impl LinReducer { - fn new(mut r1cs: R1cs, lc_size_thresh: usize) -> Self { +impl LinReducer { + fn new(mut r1cs: R1cs, lc_size_thresh: usize) -> Self { let uses = LinReducer::gen_uses(&r1cs); let queue = (0..r1cs.constraints.len()).collect::>(); for c in &mut r1cs.constraints { @@ -30,9 +33,9 @@ impl LinReducer { } // generate a new uses hash - fn gen_uses(r1cs: &R1cs) -> HashMap> { - let mut uses: HashMap> = - HashMap::with_capacity_and_hasher(r1cs.next_idx, Default::default()); + fn gen_uses(r1cs: &R1cs) -> HashMap> { + let mut uses: HashMap> = + HashMap::with_capacity_and_hasher(r1cs.num_vars(), Default::default()); let mut add = |i: usize, y: &Lc| { for x in y.monomials.keys() { uses.get_mut(x).map(|m| m.insert(i)).or_else(|| { @@ -54,7 +57,7 @@ impl LinReducer { /// Substitute `val` for `var` in constraint with id `con_id`. /// Updates uses conservatively (not precisely) /// Returns whether a sub happened. - fn sub_in(&mut self, var: usize, val: &Lc, con_id: usize) -> bool { + fn sub_in(&mut self, var: Var, val: &Lc, con_id: usize) -> bool { let (a, b, c) = &mut self.r1cs.constraints[con_id]; let uses = &mut self.uses; let mut do_in = |a: &mut Lc| { @@ -112,15 +115,13 @@ impl LinReducer { self.r1cs.constraints[i].2.clear(); } - fn run(mut self) -> R1cs { + fn run(mut self) -> R1cs { while let Some(con_id) = self.queue.pop() { - if let Some((var, lc)) = - as_linear_sub(&self.r1cs.constraints[con_id], &self.r1cs.public_idxs) - { + if let Some((var, lc)) = as_linear_sub(&self.r1cs.constraints[con_id], &self.r1cs) { if lc.monomials.len() < self.lc_size_thresh { debug!( "Elim: {} -> {}", - self.r1cs.idxs_signals.get(&var).unwrap(), + self.r1cs.idx_to_sig.get_fwd(&var).unwrap(), self.r1cs.format_lc(&lc) ); self.clear_constraint(con_id); @@ -141,10 +142,10 @@ impl LinReducer { } } -fn as_linear_sub((a, b, c): &(Lc, Lc, Lc), public: &HashSet) -> Option<(usize, Lc)> { +fn as_linear_sub((a, b, c): &(Lc, Lc, Lc), r1cs: &R1cs) -> Option<(Var, Lc)> { if a.is_zero() || b.is_zero() { for i in c.monomials.keys() { - if !public.contains(i) { + if r1cs.can_eliminate(*i) { let mut lc = c.clone(); let v = lc.monomials.remove(i).unwrap(); lc *= v.recip(); @@ -184,7 +185,7 @@ fn constantly_true((a, b, c): &(Lc, Lc, Lc)) -> bool { /// /// * `lc_size_thresh`: the maximum size LC (number of non-constant monomials) that will be used /// for propagation. `None` means no size limit. -pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs { +pub fn reduce_linearities(r1cs: R1cs, cfg: &CircCfg) -> R1cs { LinReducer::new(r1cs, cfg.r1cs.lc_elim_thresh).run() } @@ -199,7 +200,7 @@ mod test { use rand::SeedableRng; #[derive(Clone, Debug)] - pub struct SatR1cs(R1cs, FxHashMap); + pub struct SatR1cs(R1cs, FxHashMap); impl Arbitrary for SatR1cs { fn arbitrary(g: &mut Gen) -> Self { @@ -208,14 +209,18 @@ mod test { let n_vars = g.size() + 1; let vars: Vec<_> = (0..n_vars).map(|i| format!("v{i}")).collect(); let mut values: FxHashMap = Default::default(); - let mut r1cs = R1cs::new(field.clone()); + let mut var_values: FxHashMap = Default::default(); + let mut r1cs = R1cs::new(field.clone(), Default::default()); let mut rng = rand::rngs::StdRng::seed_from_u64(u64::arbitrary(g)); for v in &vars { - values.insert(v.clone(), Value::Field(field.random_v(&mut rng))); - r1cs.add_signal( + let var = r1cs.add_var( v.clone(), leaf_term(Op::Var(v.clone(), Sort::Field(field.clone()))), + VarType::FinalWit, ); + let val = field.random_v(&mut rng); + var_values.insert(var, val.clone()); + values.insert(v.into(), Value::Field(val)); } for _ in 0..(2 * g.size()) { let ac: isize = ::arbitrary(g) % m; @@ -236,7 +241,8 @@ mod test { } else { r1cs.zero() } + cc; - let off = r1cs.eval(&a, &values) * r1cs.eval(&b, &values) - r1cs.eval(&c, &values); + let off = r1cs.eval(&a, &var_values) * r1cs.eval(&b, &var_values) + - r1cs.eval(&c, &var_values); c += &off; r1cs.constraint(a, b, c); } diff --git a/src/target/r1cs/proof.rs b/src/target/r1cs/proof.rs new file mode 100644 index 00000000..e5b36092 --- /dev/null +++ b/src/target/r1cs/proof.rs @@ -0,0 +1,331 @@ +//! A trait for CirC-compatible proofs + +use std::fs::File; +use std::path::Path; + +use bincode::{deserialize_from, serialize_into}; +use fxhash::FxHashMap as HashMap; +use serde::{Deserialize, Serialize}; + +use super::{ProverData, VerifierData}; +use crate::ir::term::text::parse_value_map; +use crate::ir::term::Value; + +/// A trait for CirC-compatible proofs +pub trait ProofSystem { + /// A verifying key + type VerifyingKey: Serialize + for<'a> Deserialize<'a>; + /// A proving key + type ProvingKey: Serialize + for<'a> Deserialize<'a>; + /// A proof + type Proof: Serialize + for<'a> Deserialize<'a>; + /// Setup + fn setup(p_data: ProverData, v_data: VerifierData) -> (Self::ProvingKey, Self::VerifyingKey); + /// Proving + fn prove(pk: &Self::ProvingKey, witness: &HashMap) -> Self::Proof; + /// Verification + fn verify(vk: &Self::VerifyingKey, inst: &HashMap, pf: &Self::Proof) -> bool; + + /// Setup to files + fn setup_fs( + p_data: ProverData, + v_data: VerifierData, + pk_path: P1, + vk_path: P2, + ) -> std::io::Result<()> + where + P1: AsRef, + P2: AsRef, + { + let (pk, vk) = Self::setup(p_data, v_data); + let mut pk_file = File::create(pk_path)?; + let mut vk_file = File::create(vk_path)?; + serialize_into(&mut pk_file, &pk).unwrap(); + serialize_into(&mut vk_file, &vk).unwrap(); + Ok(()) + } + /// Prove to/from files + fn prove_fs(pk_path: P1, witness_path: P2, pf_path: P2) -> std::io::Result<()> + where + P1: AsRef, + P2: AsRef, + { + let pk_file = File::open(pk_path)?; + let mut pf_file = File::create(pf_path)?; + let witness_bytes = std::fs::read(witness_path)?; + let witness = parse_value_map(&witness_bytes); + let pk: Self::ProvingKey = deserialize_from(pk_file).unwrap(); + let pf = Self::prove(&pk, &witness); + serialize_into(&mut pf_file, &pf).unwrap(); + Ok(()) + } + /// Verify from files + fn verify_fs(vk_path: P1, instance_path: P2, pf_path: P2) -> std::io::Result + where + P1: AsRef, + P2: AsRef, + { + let vk_file = File::open(vk_path)?; + let pf_file = File::open(pf_path)?; + let instance_bytes = std::fs::read(instance_path)?; + let instance = parse_value_map(&instance_bytes); + let vk: Self::VerifyingKey = deserialize_from(vk_file).unwrap(); + let pf: Self::Proof = deserialize_from(pf_file).unwrap(); + Ok(Self::verify(&vk, &instance, &pf)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::cfg::CircCfg; + use crate::ir::term::*; + use crate::target::r1cs; + + #[allow(dead_code)] + fn test_setup_prove_verify( + cs: Computation, + p_input: HashMap, + v_input: HashMap, + ) { + let cfg = CircCfg::default(); + let r1cs = r1cs::trans::to_r1cs(&cs, &cfg); + let (p_data, v_data) = r1cs.finalize(&cs); + let (pk, vk) = PS::setup(p_data, v_data); + let pf = PS::prove(&pk, &p_input); + assert!(PS::verify(&vk, &v_input, &pf)); + } + + #[cfg(feature = "bellman")] + mod mirage { + use super::super::super::mirage::Mirage; + use super::*; + + #[test] + fn bool_np() { + let c = text::parse_computation( + b" + (computation + (metadata + (parties P) + (inputs (a bool (party 0)) (b bool (party 0)) (return bool)) + ) + (precompute + ((a bool) (b bool)) + ((return bool)) + (tuple (and a b)) + ) + (= (and a b) return) + )", + ); + let p_input = text::parse_value_map( + b" + (let ( + (a true) + (b true) + ) false; ignored + )", + ); + let v_input = text::parse_value_map( + b" + (let ( + (return true) + ) false; ignored + )", + ); + test_setup_prove_verify::>(c, p_input, v_input); + } + + #[test] + fn rand_perm() { + let c = text::parse_computation( + b" + (computation + (metadata + (parties P) + (inputs + (a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + ) + ) + (precompute () () (tuple)) + (= + (* (+ a0 c) (+ a1 c) (+ a2 c)) + (* (+ b0 c) (+ b1 c) (+ b2 c)) + ) + )", + ); + let p_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + (a0 #f1) + (a1 #f-1) + (a2 #f4) + (b0 #f-1) + (b1 #f1) + (b2 #f4) + ) false))"); + let v_input = text::parse_value_map( + b" + (let ( + ) false; ignored + )", + ); + test_setup_prove_verify::>(c, p_input, v_input); + } + + #[test] + fn rand_double_perm() { + let c = text::parse_computation( + b" + (computation + (metadata + (parties P) + (inputs + (a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + (d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + ) + ) + (precompute () () (tuple)) + (and + (= + (* (+ a0 c) (+ a1 c) (+ a2 c)) + (* (+ b0 c) (+ b1 c) (+ b2 c)) + ) + (= + (* (+ a0 d) (+ a1 d) (+ a2 d)) + (* (+ b0 d) (+ b1 d) (+ b2 d)) + ) + ) + )", + ); + let p_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + (a0 #f1) + (a1 #f-1) + (a2 #f4) + (b0 #f-1) + (b1 #f1) + (b2 #f4) + ) false))"); + let v_input = text::parse_value_map( + b" + (let ( + ) false; ignored + )", + ); + test_setup_prove_verify::>(c, p_input, v_input); + } + + #[test] + fn rand_double_perm_inst() { + let c = text::parse_computation( + b" + (computation + (metadata + (parties P) + (inputs + (a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513)) + (a1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513)) + (a2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513)) + (b0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b1 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (b2 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (c (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + (d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + ) + ) + (precompute () () (tuple)) + (and + (= + (* (+ a0 c) (+ a1 c) (+ a2 c)) + (* (+ b0 c) (+ b1 c) (+ b2 c)) + ) + (= + (* (+ a0 d) (+ a1 d) (+ a2 d)) + (* (+ b0 d) (+ b1 d) (+ b2 d)) + ) + ) + )", + ); + let p_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + (a0 #f1) + (a1 #f-1) + (a2 #f4) + (b0 #f-1) + (b1 #f1) + (b2 #f4) + ) false))"); + let v_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + (a0 #f1) + (a1 #f-1) + (a2 #f4) + ) false; ignored + ))", + ); + test_setup_prove_verify::>(c, p_input, v_input); + } + + #[test] + fn precomp_with_chall() { + let c = text::parse_computation( + b" + (computation + (metadata + (parties P) + (inputs + (a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0)) + (ha (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (party 0) (round 1)) + (d (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513) (random)) + ) + ) + (precompute ( + (a0 (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513)) + ) ( + (ha (mod 52435875175126190479447740508185965837690552500527637822603658699938581184513)) + ) (tuple + (* a0 d) + )) + (= + ha + (* a0 d) + ) + )", + ); + let p_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + (a0 #f1) + ) false))"); + let v_input = text::parse_value_map( + b" + (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 + (let ( + ) false; ignored + ))", + ); + test_setup_prove_verify::>(c, p_input, v_input); + } + } +} diff --git a/src/target/r1cs/spartan.rs b/src/target/r1cs/spartan.rs index 53de2088..5d25b7b6 100644 --- a/src/target/r1cs/spartan.rs +++ b/src/target/r1cs/spartan.rs @@ -1,4 +1,5 @@ //! Export circ R1cs to Spartan +use crate::target::r1cs::wit_comp::StagedWitCompEvaluator; use crate::target::r1cs::*; use bincode::{deserialize_from, serialize_into}; use curve25519_dalek::scalar::Scalar; @@ -55,7 +56,7 @@ pub fn verify>( let mut inp = Vec::new(); for v in &values { - let scalar = int_to_scalar(v); + let scalar = int_to_scalar(&v.i()); inp.push(scalar.to_bytes()); } let inputs = InputsAssignment::new(&inp).unwrap(); @@ -78,11 +79,11 @@ pub fn r1cs_to_spartan( // spartan format mapper: CirC -> Spartan let mut wit = Vec::new(); let mut inp = Vec::new(); - let mut trans: HashMap = HashMap::default(); // Circ -> spartan ids - let mut itrans: HashMap = HashMap::default(); // Circ -> spartan ids + let mut trans: HashMap = HashMap::default(); // Circ -> spartan ids + let mut itrans: HashMap = HashMap::default(); // spartan ids -> Circ // check modulus - let f_mod = prover_data.r1cs.modulus.modulus(); + let f_mod = prover_data.r1cs.field.modulus(); let s_mod = Integer::from_str_radix( "7237005577332262213973186563042994240857116359379907606001950938285454250989", 10, @@ -95,37 +96,22 @@ pub fn r1cs_to_spartan( let values = eval_inputs(inputs_map, prover_data); - let pf_input_order: Vec = (0..prover_data.r1cs.next_idx) - .filter(|i| prover_data.r1cs.public_idxs.contains(i)) - .collect(); + assert_eq!(values.len(), prover_data.r1cs.vars.len()); - for idx in &pf_input_order { - let sig = prover_data.r1cs.idxs_signals.get(idx).cloned().unwrap(); - - let scalar = match values.get(&sig.to_string()) { - Some(v) => val_to_scalar(v), - None => panic!("Input/witness variable does not have matching evaluation"), - }; + for (var, val) in prover_data.r1cs.vars.iter().zip(&values) { + let scalar = val_to_scalar(val); // input - itrans.insert(*idx, inp.len()); - inp.push(scalar.to_bytes()); - } - - for (sig, idx) in &prover_data.r1cs.signal_idxs { - let scalar = match values.get(&sig.to_string()) { - Some(v) => val_to_scalar(v), - None => panic!("Input/witness variable does not have matching evaluation"), - }; - - if !prover_data.r1cs.public_idxs.contains(idx) { - // witness - trans.insert(*idx, wit.len()); + itrans.insert(inp.len(), *var); + trans.insert(*var, inp.len()); + if let VarType::Inst = var.ty() { + inp.push(scalar.to_bytes()); + } else { wit.push(scalar.to_bytes()); } } - assert_eq!(wit.len() + inp.len(), prover_data.r1cs.next_idx); + assert_eq!(wit.len() + inp.len(), prover_data.r1cs.vars.len()); let num_vars = wit.len(); let const_id = wit.len(); @@ -135,10 +121,6 @@ pub fn r1cs_to_spartan( let num_inputs = inp.len(); let assn_inputs = InputsAssignment::new(&inp).unwrap(); - for (cid, sid) in itrans { - trans.insert(cid, sid + const_id + 1); - } - // circuit let mut m_a: Vec<(usize, usize, [u8; 32])> = Vec::new(); let mut m_b: Vec<(usize, usize, [u8; 32])> = Vec::new(); @@ -183,24 +165,22 @@ pub fn r1cs_to_spartan( ) } -fn eval_inputs( - inputs_map: &HashMap, - prover_data: &ProverData, -) -> HashMap { - for (input, sort) in &prover_data.precompute_inputs { - let value = inputs_map - .get(input) - .unwrap_or_else(|| panic!("No input for {}", input)); - let sort2 = value.sort(); - assert_eq!( - sort, &sort2, - "Sort mismatch for {input}. Expected\n\t{sort} but got\n\t{sort2}", - ); - } - let new_map = prover_data.precompute.eval(inputs_map); - prover_data.r1cs.check_all(&new_map); - - new_map +fn eval_inputs(inputs_map: &HashMap, prover_data: &ProverData) -> Vec { + let mut evaluator = StagedWitCompEvaluator::new(&prover_data.precompute); + let mut ffs = Vec::new(); + ffs.extend( + evaluator + .eval_stage(inputs_map.clone()) + .into_iter() + .cloned(), + ); + ffs.extend( + evaluator + .eval_stage(Default::default()) + .into_iter() + .cloned(), + ); + ffs } fn val_to_scalar(v: &Value) -> Scalar { @@ -228,7 +208,7 @@ fn int_to_scalar(i: &Integer) -> Scalar { } // circ Lc (const, monomials ) -> Vec -fn lc_to_v(lc: &Lc, const_id: usize, trans: &HashMap) -> Vec { +fn lc_to_v(lc: &Lc, const_id: usize, trans: &HashMap) -> Vec { let mut v: Vec = Vec::new(); for (k, m) in &lc.monomials { diff --git a/src/target/r1cs/trans.rs b/src/target/r1cs/trans.rs index 83126cce..a6322919 100644 --- a/src/target/r1cs/trans.rs +++ b/src/target/r1cs/trans.rs @@ -4,14 +4,12 @@ //! thesis](https://github.com/circify/circ/tree/master/doc/resources/braun-bs-thesis.pdf) //! is a good intro to how this process works. use crate::cfg::CircCfg; -use crate::ir::term::precomp::PreComp; use crate::ir::term::*; use crate::target::bitsize; use crate::target::r1cs::*; use circ_fields::FieldT; use circ_opt::FieldDivByZero; -use fxhash::FxHashSet; use log::debug; use rug::ops::Pow; use rug::Integer; @@ -39,29 +37,27 @@ enum EmbeddedTerm { } struct ToR1cs<'cfg> { - r1cs: R1cs, + r1cs: R1cs, cache: TermMap, - wit_ext: PreComp, - public_inputs: FxHashSet, next_idx: usize, zero: TermLc, one: TermLc, cfg: &'cfg CircCfg, field: FieldT, + used_vars: HashSet, } impl<'cfg> ToR1cs<'cfg> { - fn new(cfg: &'cfg CircCfg, public_inputs: FxHashSet) -> Self { + fn new(cfg: &'cfg CircCfg, precompute: precomp::PreComp, used_vars: HashSet) -> Self { let field = cfg.field().clone(); debug!("Starting R1CS back-end, field: {}", field); - let r1cs = R1cs::new(field.clone()); + let r1cs = R1cs::new(field.clone(), precompute); let zero = TermLc(pf_lit(field.new_v(0u8)), r1cs.zero()); let one = zero.clone() + 1; Self { r1cs, cache: TermMap::default(), - wit_ext: precomp::PreComp::new(), - public_inputs, + used_vars, next_idx: 0, zero, one, @@ -74,19 +70,24 @@ impl<'cfg> ToR1cs<'cfg> { /// If values are being recorded, `value` must be provided. /// /// `comp` is a term that computes the value. - fn fresh_var(&mut self, ctx: &D, comp: Term, public: bool) -> TermLc { - let n = format!("{}_n{}", ctx, self.next_idx); + fn fresh_var(&mut self, ctx: &D, comp: Term, ty: VarType) -> TermLc { + let n = if matches!(ty, VarType::Chall) { + format!("{ctx}") + } else { + format!("{ctx}_n{}", self.next_idx) + }; self.next_idx += 1; debug_assert!(matches!(check(&comp), Sort::Field(_))); - self.r1cs.add_signal(n.clone(), comp.clone()); - self.wit_ext.add_output(n.clone(), comp.clone()); - if public { - self.r1cs.publicize(&n); - } - debug!("fresh: {}", n); + self.r1cs.add_var(n.clone(), comp.clone(), ty); + debug!("fresh: {n}"); TermLc(comp, self.r1cs.signal_lc(&n)) } + /// Get a new witness. See [ToR1cs::fresh_var]. + fn fresh_wit(&mut self, ctx: &D, comp: Term) -> TermLc { + self.fresh_var(ctx, comp, VarType::FinalWit) + } + /// Enforce `x` to be bit-valued fn enforce_bit(&mut self, b: TermLc) { self.r1cs @@ -98,7 +99,7 @@ impl<'cfg> ToR1cs<'cfg> { fn fresh_bit(&mut self, ctx: &D, comp: Term) -> TermLc { debug_assert!(matches!(check(&comp), Sort::Bool)); let comp = term![Op::Ite; comp, self.one.0.clone(), self.zero.0.clone()]; - let v = self.fresh_var(ctx, comp, false); + let v = self.fresh_var(ctx, comp, VarType::FinalWit); //debug!("Fresh bit: {}", self.r1cs.format_lc(&v)); self.enforce_bit(v.clone()); v @@ -110,15 +111,13 @@ impl<'cfg> ToR1cs<'cfg> { let eqz = term![Op::Eq; x.0.clone(), self.zero.0.clone()]; // m * x - 1 + is_zero == 0 // is_zero * x == 0 - let m = self.fresh_var( + let m = self.fresh_wit( "is_zero_inv", term![Op::Ite; eqz.clone(), self.zero.0.clone(), term![PF_RECIP; x.0.clone()]], - false, ); - let is_zero = self.fresh_var( + let is_zero = self.fresh_wit( "is_zero", term![Op::Ite; eqz, self.one.0.clone(), self.zero.0.clone()], - false, ); self.r1cs .constraint(m.1, x.1.clone(), -is_zero.1.clone() + 1); @@ -208,7 +207,7 @@ impl<'cfg> ToR1cs<'cfg> { /// Return the product of `a` and `b`. fn mul(&mut self, a: TermLc, b: TermLc) -> TermLc { let mul_val = term![PF_MUL; a.0, b.0]; - let c = self.fresh_var("mul", mul_val, false); + let c = self.fresh_wit("mul", mul_val); self.r1cs.constraint(a.1, b.1, c.1.clone()); c } @@ -255,6 +254,48 @@ impl<'cfg> ToR1cs<'cfg> { self.mul(c, t - f) + f } + /// Embed this variable as + fn embed_var(&mut self, var: &Term, ty: VarType) { + assert!( + !self.cache.contains_key(var), + "already have var {}", + var.op() + ); + assert!(!matches!(ty, VarType::CWit), "Unimplemented"); + if !self.used_vars.contains(var.as_var_name()) { + return; + } + let public = matches!(ty, VarType::Inst); + match var.op() { + Op::Var(name, Sort::Bool) => { + let comp = term![Op::Ite; var.clone(), self.one.0.clone(), self.zero.0.clone()]; + let lc = self.fresh_var(name, comp, ty); + if !public { + self.enforce_bit(lc.clone()); + } + self.cache.insert(var.clone(), EmbeddedTerm::Bool(lc)); + } + Op::Var(name, Sort::BitVector(n_bits)) => { + let public = matches!(ty, VarType::Inst); + let lc = self.fresh_var( + name, + term![Op::UbvToPf(self.field.clone()); var.clone()], + ty, + ); + self.set_bv_uint(var.clone(), lc, *n_bits); + if !public { + self.get_bv_bits(var); + } + } + Op::Var(name, Sort::Field(f)) => { + assert_eq!(f, &self.field); + let lc = self.fresh_var(name, var.clone(), ty); + self.cache.insert(var.clone(), EmbeddedTerm::Field(lc)); + } + o => unreachable!("Unhandled variable operator {}", o), + } + } + fn embed(&mut self, t: Term) { debug!("Embed: {}", t); for c in PostOrderIter::new(t) { @@ -352,15 +393,7 @@ impl<'cfg> ToR1cs<'cfg> { // TODO: skip if already embedded if !self.cache.contains_key(&c) { let lc = match &c.op() { - Op::Var(name, Sort::Bool) => { - let public = self.public_inputs.contains(name); - let comp = term![Op::Ite; c.clone(), self.one.0.clone(), self.zero.0.clone()]; - let v = self.fresh_var(name, comp, public); - if !public { - self.enforce_bit(v.clone()); - } - v - } + Op::Var(..) => panic!("call embed_var instead"), Op::Const(Value::Bool(b)) => self.zero.clone() + *b as isize, Op::Eq => self.embed_eq(&c.cs()[0], &c.cs()[1]), Op::Ite => { @@ -563,18 +596,7 @@ impl<'cfg> ToR1cs<'cfg> { if let Sort::BitVector(n) = check(&bv) { if !self.cache.contains_key(&bv) { match &bv.op() { - Op::Var(name, Sort::BitVector(_)) => { - let public = self.public_inputs.contains(name); - let var = self.fresh_var( - name, - term![Op::UbvToPf(self.field.clone()); bv.clone()], - public, - ); - self.set_bv_uint(bv.clone(), var, n); - if !public { - self.get_bv_bits(&bv); - } - } + Op::Var(..) => panic!("call embed_var instead"), Op::Const(Value::BitVector(b)) => { let bit_lcs = (0..b.width()) .map(|i| self.zero.clone() + b.uint().get_bit(i as u32) as isize) @@ -708,8 +730,8 @@ impl<'cfg> ToR1cs<'cfg> { let b_bv_term = term![Op::PfToBv(n); b.0.clone()]; let q_term = term![Op::UbvToPf(self.field.clone()); term![BV_UDIV; a_bv_term.clone(), b_bv_term.clone()]]; let r_term = term![Op::UbvToPf(self.field.clone()); term![BV_UREM; a_bv_term, b_bv_term]]; - let q = self.fresh_var("div_q", q_term, false); - let r = self.fresh_var("div_r", r_term, false); + let q = self.fresh_wit("div_q", q_term); + let r = self.fresh_wit("div_r", r_term); let qb = self.bitify("div_q", &q, n, false); let rb = self.bitify("div_r", &r, n, false); self.r1cs.constraint(q.1.clone(), b.1.clone(), (a - &r).1); @@ -880,10 +902,7 @@ impl<'cfg> ToR1cs<'cfg> { // TODO: skip if already embedded if !self.cache.contains_key(&c) { let lc = match &c.op() { - Op::Var(name, Sort::Field(_)) => { - let public = self.public_inputs.contains(name); - self.fresh_var(name, c.clone(), public) - } + Op::Var(..) => panic!("call embed_var instead"), Op::Const(Value::Field(r)) => TermLc( c.clone(), self.r1cs.constant(r.as_ty_ref(&self.r1cs.modulus)), @@ -914,7 +933,7 @@ impl<'cfg> ToR1cs<'cfg> { FieldDivByZero::Incomplete => { // ix = 1 let x = self.get_pf(&c.cs()[0]).clone(); - let inv_x = self.fresh_var("recip", term![PF_RECIP; x.0.clone()], false); + let inv_x = self.fresh_wit("recip", term![PF_RECIP; x.0.clone()]); self.r1cs .constraint(x.1, inv_x.1.clone(), self.r1cs.zero() + 1); inv_x @@ -923,7 +942,7 @@ impl<'cfg> ToR1cs<'cfg> { // ixx = x let x = self.get_pf(&c.cs()[0]).clone(); let x2 = self.mul(x.clone(), x.clone()); - let inv_x = self.fresh_var("recip", term![PF_RECIP; x.0.clone()], false); + let inv_x = self.fresh_wit("recip", term![PF_RECIP; x.0.clone()]); self.r1cs.constraint(x2.1, inv_x.1.clone(), x.1); inv_x } @@ -933,15 +952,13 @@ impl<'cfg> ToR1cs<'cfg> { // zi = 0 let x = self.get_pf(&c.cs()[0]).clone(); let eqz = term![Op::Eq; x.0.clone(), self.zero.0.clone()]; - let i = self.fresh_var( + let i = self.fresh_wit( "is_zero_inv", term![Op::Ite; eqz.clone(), self.zero.0.clone(), term![PF_RECIP; x.0.clone()]], - false, ); - let z = self.fresh_var( + let z = self.fresh_wit( "is_zero", term![Op::Ite; eqz, self.one.0.clone(), self.zero.0.clone()], - false, ); self.r1cs .constraint(i.1.clone(), x.1.clone(), -z.1.clone() + 1); @@ -974,34 +991,40 @@ impl<'cfg> ToR1cs<'cfg> { /// /// * Prover data (including the R1CS instance) /// * Verifier data -pub fn to_r1cs(mut cs: Computation, cfg: &CircCfg) -> (ProverData, VerifierData) { - let assertions = cs.outputs.clone(); - let metadata = cs.metadata.clone(); - let public_inputs = metadata.public_input_names_set(); +pub fn to_r1cs(cs: &Computation, cfg: &CircCfg) -> R1cs { + let public_inputs = cs.metadata.public_input_names_set(); debug!("public inputs: {:?}", public_inputs); - let mut converter = ToR1cs::new(cfg, public_inputs); + let used_vars = extras::free_variables(term(Op::Tuple, cs.outputs.clone())); + let mut converter = ToR1cs::new(cfg, cs.precomputes.clone(), used_vars); debug!( "Term count: {}", - assertions + cs.outputs .iter() .map(|c| PostOrderIter::new(c.clone()).count()) .sum::() ); debug!("declaring inputs"); - for i in metadata.ordered_public_inputs() { - debug!("input {}", i); - converter.embed(i); + let vars = cs.metadata.interactive_vars(); + for i in &vars.instances { + converter.embed_var(i, VarType::Inst); + } + for round in &vars.rounds { + for w in &round.witnesses { + converter.embed_var(w, VarType::RoundWit); + } + for c in &round.challenges { + converter.embed_var(c, VarType::Chall); + } + converter.r1cs.end_round() + } + for w in &vars.final_witnesses { + converter.embed_var(w, VarType::FinalWit); } debug!("Printing assertions"); - for c in assertions { - converter.assert(c); + for c in &cs.outputs { + converter.assert(c.clone()); } - debug!("r1cs public inputs: {:?}", converter.r1cs.public_idxs,); - cs.precomputes = cs.precomputes.sequential_compose(&converter.wit_ext); - let r1cs = converter.r1cs; - let verifier_data = r1cs.verifier_data(&cs); - let prover_data = r1cs.prover_data(&cs); - (prover_data, verifier_data) + converter.r1cs } #[cfg(test)] @@ -1015,14 +1038,14 @@ pub mod test { use fxhash::FxHashMap; use quickcheck_macros::quickcheck; - fn to_r1cs_dflt(cs: Computation) -> (ProverData, VerifierData) { - to_r1cs(cs, &CircCfg::default()) + fn to_r1cs_dflt(cs: Computation) -> R1cs { + to_r1cs(&cs, &CircCfg::default()) } - fn to_r1cs_mod17(cs: Computation) -> (ProverData, VerifierData) { + fn to_r1cs_mod17(cs: Computation) -> R1cs { let mut opt = crate::cfg::CircOpt::default(); opt.field.custom_modulus = "17".into(); - to_r1cs(cs, &CircCfg::from(opt)) + to_r1cs(&cs, &CircCfg::from(opt)) } fn init() { @@ -1048,10 +1071,8 @@ pub mod test { leaf_term(Op::Var("b".to_owned(), Sort::Bool)), ], ); - let (pd, _) = to_r1cs_mod17(cs); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs_mod17(cs); + r1cs.check_all(&values); } #[quickcheck] @@ -1062,10 +1083,8 @@ pub mod test { term![Op::Not; t] }; let cs = Computation::from_constraint_system_parts(vec![t], Vec::new()); - let (pd, _) = to_r1cs_dflt(cs); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs_dflt(cs); + r1cs.check_all(&values); } #[quickcheck] @@ -1076,10 +1095,8 @@ pub mod test { crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs); crate::ir::opt::tuple::eliminate_tuples(&mut cs); let cfg = CircCfg::default(); - let (pd, _) = to_r1cs(cs, &cfg); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs(&cs, &cfg); + r1cs.check_all(&values); } #[quickcheck] @@ -1088,12 +1105,10 @@ pub mod test { let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; let cs = Computation::from_constraint_system_parts(vec![t], Vec::new()); let cfg = CircCfg::default(); - let (pd, _) = to_r1cs(cs, &cfg); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); - let r1cs2 = reduce_linearities(pd.r1cs, &cfg); - r1cs2.check_all(&extended_values); + let r1cs = to_r1cs(&cs, &cfg); + r1cs.check_all(&values); + let r1cs2 = reduce_linearities(r1cs, &cfg); + r1cs2.check_all(&values); } #[quickcheck] @@ -1104,12 +1119,10 @@ pub mod test { crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs); crate::ir::opt::tuple::eliminate_tuples(&mut cs); let cfg = CircCfg::default(); - let (pd, _) = to_r1cs(cs, &cfg); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); - let r1cs2 = reduce_linearities(pd.r1cs, &cfg); - r1cs2.check_all(&extended_values); + let r1cs = to_r1cs(&cs, &cfg); + r1cs.check_all(&values); + let r1cs2 = reduce_linearities(r1cs, &cfg); + r1cs2.check_all(&values); } #[test] @@ -1126,10 +1139,8 @@ pub mod test { term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]]], vec![leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))], ); - let (pd, _) = to_r1cs_dflt(cs); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs_dflt(cs); + r1cs.check_all(&values); } #[test] @@ -1143,12 +1154,10 @@ pub mod test { let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; let cs = Computation::from_constraint_system_parts(vec![t], vec![]); let cfg = CircCfg::default(); - let (pd, _) = to_r1cs(cs, &cfg); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); - let r1cs2 = reduce_linearities(pd.r1cs, &cfg); - r1cs2.check_all(&extended_values); + let r1cs = to_r1cs(&cs, &cfg); + r1cs.check_all(&values); + let r1cs2 = reduce_linearities(r1cs, &cfg); + r1cs2.check_all(&values); } fn pf_dflt(i: isize) -> Term { @@ -1158,10 +1167,8 @@ pub mod test { fn const_test(term: Term) { let mut cs = Computation::new(); cs.assert(term); - let (pd, _) = to_r1cs_dflt(cs); - let precomp = pd.precompute; - let extended_values = precomp.eval(&Default::default()); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs_dflt(cs); + r1cs.check_all(&Default::default()); } #[test] @@ -1246,6 +1253,7 @@ pub mod test { #[test] fn pf2bv_lit() { + init(); const_test(term![ Op::Eq; term![Op::PfToBv(4); pf_dflt(8)], @@ -1282,9 +1290,7 @@ pub mod test { ], ); crate::ir::opt::tuple::eliminate_tuples(&mut cs); - let (pd, _) = to_r1cs_mod17(cs); - let precomp = pd.precompute; - let extended_values = precomp.eval(&values); - pd.r1cs.check_all(&extended_values); + let r1cs = to_r1cs_mod17(cs); + r1cs.check_all(&values); } } diff --git a/src/target/r1cs/wit_comp.rs b/src/target/r1cs/wit_comp.rs new file mode 100644 index 00000000..71660b75 --- /dev/null +++ b/src/target/r1cs/wit_comp.rs @@ -0,0 +1,326 @@ +//! A multi-stage R1CS witness evaluator. + +use crate::ir::term::*; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use serde::{Deserialize, Serialize}; + +use log::trace; + +/// A witness computation that proceeds in stages. +/// +/// In each stage: +/// * it takes a partial assignment +/// * it returns a vector of field values +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct StagedWitComp { + vars: HashSet, + stages: Vec, + steps: Vec<(Op, usize)>, + step_args: Vec, + ouput_steps: Vec, + // we don't serialize the cache; it's just used during construction, and terms are expensive to + // serialize. + #[serde(skip)] + term_to_step: TermMap, +} + +/// Specifies a stage. +#[derive(Debug, Serialize, Deserialize)] +pub struct Stage { + inputs: HashMap, + num_outputs: usize, +} + +/// Builder interface +impl StagedWitComp { + /// Add a new stage. + #[allow(clippy::uninlined_format_args)] + pub fn add_stage(&mut self, inputs: HashMap, output_values: Vec) { + let stage = Stage { + inputs, + num_outputs: output_values.len(), + }; + for input in stage.inputs.keys() { + debug_assert!(!self.vars.contains(input), "Duplicate input {}", input); + } + self.vars.extend(stage.inputs.keys().cloned()); + self.stages.push(stage); + let already_have: TermSet = self.term_to_step.keys().cloned().collect(); + for t in PostOrderIter::from_roots_and_skips(output_values.clone(), already_have) { + self.add_step(t); + } + for t in output_values { + self.ouput_steps.push(*self.term_to_step.get(&t).unwrap()); + } + } + + fn add_step(&mut self, term: Term) { + debug_assert!(!self.term_to_step.contains_key(&term)); + let step_idx = self.steps.len(); + if let Op::Var(name, _) = term.op() { + debug_assert!(self.vars.contains(name)); + } + for child in term.cs() { + let child_step = self.term_to_step.get(child).unwrap(); + self.step_args.push(*child_step); + } + self.steps.push((term.op().clone(), self.step_args.len())); + self.term_to_step.insert(term, step_idx); + } + + /// How many stages are there? + pub fn stage_sizes(&self) -> impl Iterator + '_ { + self.stages.iter().map(|s| s.num_outputs) + } + + /// How many inputs are there for this stage? + pub fn num_stage_inputs(&self, n: usize) -> usize { + self.stages[n].inputs.len() + } +} + +/// Evaluator interface +impl StagedWitComp { + fn step_args(&self, step_idx: usize) -> impl Iterator + '_ { + assert!(step_idx < self.steps.len()); + let args_end = self.steps[step_idx].1; + let args_start = if step_idx == 0 { + 0 + } else { + self.steps[step_idx - 1].1 + }; + (args_start..args_end).map(move |step_arg_idx| self.step_args[step_arg_idx]) + } +} + +/// Evaluates a staged witness computation. +#[derive(Debug)] +pub struct StagedWitCompEvaluator<'a> { + comp: &'a StagedWitComp, + variable_values: HashMap, + step_values: Vec, + stages_evaluated: usize, + outputs_evaluted: usize, +} + +impl<'a> StagedWitCompEvaluator<'a> { + /// Create an empty witness computation. + pub fn new(comp: &'a StagedWitComp) -> Self { + Self { + comp, + variable_values: Default::default(), + step_values: Default::default(), + stages_evaluated: Default::default(), + outputs_evaluted: 0, + } + } + /// Have all stages been evaluated? + pub fn is_done(&self) -> bool { + self.stages_evaluated == self.comp.stages.len() + } + fn eval_step(&mut self) { + let next_step_idx = self.step_values.len(); + assert!(next_step_idx < self.comp.steps.len()); + let op = &self.comp.steps[next_step_idx].0; + let args: Vec<&Value> = self + .comp + .step_args(next_step_idx) + .map(|i| &self.step_values[i]) + .collect(); + let value = eval_op(op, &args, &self.variable_values); + trace!( + "Eval step {}: {} on {:?} -> {}", + next_step_idx, + op, + args, + value + ); + self.step_values.push(value); + } + /// Evaluate one stage. + pub fn eval_stage(&mut self, inputs: HashMap) -> Vec<&Value> { + trace!( + "Beginning stage {}/{}", + self.stages_evaluated, + self.comp.stages.len() + ); + debug_assert!(self.stages_evaluated < self.comp.stages.len()); + let stage = &self.comp.stages[self.stages_evaluated]; + let num_outputs = stage.num_outputs; + self.variable_values.extend(inputs); + if num_outputs > 0 { + let max_step = (0..num_outputs) + .map(|i| { + let new_output_i = i + self.outputs_evaluted; + self.comp.ouput_steps[new_output_i] + }) + .max() + .unwrap(); + while self.step_values.len() <= max_step { + self.eval_step(); + } + } + self.outputs_evaluted += num_outputs; + self.stages_evaluated += 1; + let mut out = Vec::new(); + for output_step in + &self.comp.ouput_steps[self.outputs_evaluted - num_outputs..self.outputs_evaluted] + { + out.push(&self.step_values[*output_step]); + } + out + } +} + +#[cfg(test)] +mod test { + + use rug::Integer; + + use super::*; + use circ_fields::FieldT; + + fn mk_inputs(v: Vec<(String, Sort)>) -> HashMap { + v.into_iter().collect() + } + + #[test] + fn one_const() { + let mut comp = StagedWitComp::default(); + let field = FieldT::from(Integer::from(7)); + comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]); + + let mut evaluator = StagedWitCompEvaluator::new(&comp); + + let output = evaluator.eval_stage(Default::default()); + let ex_output: &[usize] = &[0]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + assert!(evaluator.is_done()); + } + + #[test] + fn many_const() { + let mut comp = StagedWitComp::default(); + let field = FieldT::from(Integer::from(7)); + comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]); + comp.add_stage( + mk_inputs(vec![]), + vec![pf_lit(field.new_v(1)), pf_lit(field.new_v(4))], + ); + comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(6))]); + comp.add_stage(mk_inputs(vec![]), vec![pf_lit(field.new_v(0))]); + + let mut evaluator = StagedWitCompEvaluator::new(&comp); + + let output = evaluator.eval_stage(Default::default()); + let ex_output: &[usize] = &[0]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + let output = evaluator.eval_stage(Default::default()); + let ex_output: &[usize] = &[1, 4]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + let output = evaluator.eval_stage(Default::default()); + let ex_output: &[usize] = &[6]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + let output = evaluator.eval_stage(Default::default()); + let ex_output: &[usize] = &[0]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + assert!(evaluator.is_done()); + } + + #[test] + fn vars_one_stage() { + let mut comp = StagedWitComp::default(); + let field = FieldT::from(Integer::from(7)); + comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]), + vec![ + leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), + term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + ]); + + let mut evaluator = StagedWitCompEvaluator::new(&comp); + + let output = evaluator.eval_stage( + vec![ + ("a".into(), Value::Bool(true)), + ("b".into(), Value::Field(field.new_v(5))), + ] + .into_iter() + .collect(), + ); + let ex_output: &[usize] = &[5, 1]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + assert!(evaluator.is_done()); + } + + #[test] + fn vars_many_stages() { + let mut comp = StagedWitComp::default(); + let field = FieldT::from(Integer::from(7)); + comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]), + vec![ + leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), + term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + ]); + comp.add_stage(mk_inputs(vec![("c".into(), Sort::Field(field.clone()))]), + vec![ + term![PF_ADD; + leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), + leaf_term(Op::Var("c".into(), Sort::Field(field.clone())))], + term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(0)), pf_lit(field.new_v(1))], + ]); + + let mut evaluator = StagedWitCompEvaluator::new(&comp); + + let output = evaluator.eval_stage( + vec![ + ("a".into(), Value::Bool(true)), + ("b".into(), Value::Field(field.new_v(5))), + ] + .into_iter() + .collect(), + ); + let ex_output: &[usize] = &[5, 1]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + let output = evaluator.eval_stage( + vec![("c".into(), Value::Field(field.new_v(3)))] + .into_iter() + .collect(), + ); + let ex_output: &[usize] = &[1, 1, 0]; + assert_eq!(output.len(), ex_output.len()); + for i in 0..ex_output.len() { + assert_eq!(output[i], &Value::Field(field.new_v(ex_output[i])), "{i}"); + } + + assert!(evaluator.is_done()); + } +} diff --git a/util.py b/util.py index 2da2d97c..f327ae66 100644 --- a/util.py +++ b/util.py @@ -4,8 +4,8 @@ from os import path # Gloable variables feature_path = ".features.txt" mode_path = ".mode.txt" -cargo_features = {"aby", "c", "lp", "r1cs", - "smt", "zok", "datalog", "bellman", "spartan", "kahip", "kahypar"} +cargo_features = {"aby", "c", "lp", "r1cs", "kahip", "kahypar", + "smt", "zok", "datalog", "bellman", "spartan"} # Environment variables ABY_SOURCE = "./../ABY"