diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..7ec1149 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,42 @@ +name: Publish Package to npmjs +on: + release: + types: [published] + workflow_dispatch: + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ github.ref_name }} + # Setup Node.js + - uses: actions/setup-node@v3 + with: + node-version: 18 + registry-url: "https://registry.npmjs.org" + # Setup Rust + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-2022-10-31 + - run: rustup component add rust-src --toolchain nightly-2022-10-31-x86_64-unknown-linux-gnu + - run: rustup target add x86_64-apple-darwin + # Install circom-secq + - uses: GuillaumeFalourd/clone-github-repo-action@v2 + with: + owner: "DanTehrani" + repository: "circom-secq" + - run: cd circom-secq && cargo build --release && cargo install --path circom + # Install wasm-pack + - uses: jetli/wasm-pack-action@v0.4.0 + with: + version: "0.10.3" + - run: cargo test --release + - run: yarn + - run: yarn build + - run: yarn test + - run: npm publish + working-directory: ./packages/lib + env: + NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index 7089b92..4736843 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,4 +5,7 @@ members = [ # "packages/poseidon", "packages/hoplite", "packages/hoplite_circuit", + "packages/poseidon", + "packages/Spartan-secq", + "packages/circuit_reader", ] \ No newline at end of file diff --git a/package.json b/package.json index 4f80948..c12a273 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,8 @@ "repository": "https://github.com/DanTehrani/spartan-wasm.git", "author": "Daniel Tehrani ", "scripts": { - "build": "sh ./scripts/build.sh && lerna run build" + "build": "sh ./scripts/build.sh && lerna run build", + "test": "sh ./scripts/test.sh" }, "devDependencies": { "@types/jest": "^29.2.4", diff --git a/packages/Spartan-secq/CODE_OF_CONDUCT.md b/packages/Spartan-secq/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..f9ba8cf --- /dev/null +++ b/packages/Spartan-secq/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/packages/Spartan-secq/CONTRIBUTING.md b/packages/Spartan-secq/CONTRIBUTING.md new file mode 100644 index 0000000..d82cc36 --- /dev/null +++ b/packages/Spartan-secq/CONTRIBUTING.md @@ -0,0 +1,12 @@ +This project welcomes contributions and suggestions. Most contributions require you to +agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the +instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 
+ +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. \ No newline at end of file diff --git a/packages/Spartan-secq/Cargo.toml b/packages/Spartan-secq/Cargo.toml new file mode 100644 index 0000000..57d064a --- /dev/null +++ b/packages/Spartan-secq/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "spartan" +version = "0.7.1" +authors = ["Srinath Setty "] +edition = "2021" +description = "High-speed zkSNARKs without trusted setup" +documentation = "https://docs.rs/spartan/" +readme = "README.md" +repository = "https://github.com/microsoft/Spartan" +license-file = "LICENSE" +keywords = ["zkSNARKs", "cryptography", "proofs"] + +[dependencies] +num-bigint-dig = "^0.7" +secq256k1 = { path = "../secq256k1" } +merlin = "3.0.0" +rand = "0.7.3" +digest = "0.8.1" +sha3 = "0.8.2" +byteorder = "1.3.4" +rayon = { version = "1.3.0", optional = true } +serde = { version = "1.0.106", features = ["derive"] } +bincode = "1.2.1" +subtle = { version = "2.4", default-features = false } +rand_core = { version = "0.6", default-features = false } +zeroize = { version = "1", default-features = false } +itertools = "0.10.0" +colored = "2.0.0" +flate2 = "1.0.14" +thiserror = "1.0" +num-traits = "0.2.15" +hex-literal = { version = "0.3" } +multiexp = "0.2.2" + +[dev-dependencies] +criterion = "0.3.1" + +[lib] +name = "libspartan" +path = "src/lib.rs" +crate-type = ["cdylib", "rlib"] + +[[bin]] +name = "snark" +path = "profiler/snark.rs" + +[[bin]] +name = "nizk" +path = "profiler/nizk.rs" + +[[bench]] +name = "snark" +harness = false + +[[bench]] +name = "nizk" +harness = false diff --git a/packages/Spartan-secq/LICENSE b/packages/Spartan-secq/LICENSE new file mode 100644 index 0000000..9e841e7 --- /dev/null +++ b/packages/Spartan-secq/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/packages/Spartan-secq/README.md b/packages/Spartan-secq/README.md new file mode 100644 index 0000000..0961bee --- /dev/null +++ b/packages/Spartan-secq/README.md @@ -0,0 +1,10 @@ +## Fork of [Spartan](https://github.com/microsoft/Spartan) +_This fork is still under development._ + +Modify Spartan to operate over the **base field** of secp256k1. 
+
+### Changes from the original Spartan
+- Use the secq256k1 crate instead of curve25519-dalek
+- Modify values in scalar.rs (originally ristretto255.rs)
+
+Please refer to [spartan-ecdsa](https://github.com/personaelabs/spartan-ecdsa) for development status.
diff --git a/packages/Spartan-secq/SECURITY.md b/packages/Spartan-secq/SECURITY.md
new file mode 100644
index 0000000..e0dfff5
--- /dev/null
+++ b/packages/Spartan-secq/SECURITY.md
@@ -0,0 +1,41 @@
+
+
+## Security
+
+Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
+
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
+
+## Reporting Security Issues
+
+**Please do not report security vulnerabilities through public GitHub issues.**
+
+Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
+
+If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
+
+You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
+
+Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+  * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+  * Full paths of source file(s) related to the manifestation of the issue
+  * The location of the affected source code (tag/branch/commit or direct URL)
+  * Any special configuration required to reproduce the issue
+  * Step-by-step instructions to reproduce the issue
+  * Proof-of-concept or exploit code (if possible)
+  * Impact of the issue, including how an attacker might exploit the issue
+
+This information will help us triage your report more quickly.
+
+If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
+
+## Preferred Languages
+
+We prefer all communications to be in English.
+
+## Policy
+
+Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
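The two Criterion bench files that follow time proving and verification over synthetic R1CS instances of 2^10 to 2^16 constraints. For orientation, the prove/verify round trip they measure reduces to the following sketch; the instance size and transcript label are illustrative, mirroring profiler/nizk.rs:

```rust
use libspartan::{Instance, NIZKGens, NIZK};
use merlin::Transcript;

fn main() {
  // a synthetic R1CS instance with 2^10 constraints/variables and 10 inputs
  let (num_cons, num_vars, num_inputs) = (1024, 1024, 10);
  let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);

  // public generators sized to the instance
  let gens = NIZKGens::new(num_cons, num_vars, num_inputs);

  // prover and verifier must seed their transcripts with the same label
  let mut prover_transcript = Transcript::new(b"nizk_example");
  let proof = NIZK::prove(&inst, vars, &inputs, &gens, &mut prover_transcript);

  let mut verifier_transcript = Transcript::new(b"nizk_example");
  assert!(proof
    .verify(&inst, &inputs, &mut verifier_transcript, &gens)
    .is_ok());
}
```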
+ + diff --git a/packages/Spartan-secq/benches/nizk.rs b/packages/Spartan-secq/benches/nizk.rs new file mode 100644 index 0000000..77846cc --- /dev/null +++ b/packages/Spartan-secq/benches/nizk.rs @@ -0,0 +1,92 @@ +#![allow(clippy::assertions_on_result_states)] +extern crate byteorder; +extern crate core; +extern crate criterion; +extern crate digest; +extern crate libspartan; +extern crate merlin; +extern crate rand; +extern crate sha3; + +use libspartan::{Instance, NIZKGens, NIZK}; +use merlin::Transcript; + +use criterion::*; + +fn nizk_prove_benchmark(c: &mut Criterion) { + for &s in [10, 12, 16].iter() { + let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); + let mut group = c.benchmark_group("NIZK_prove_benchmark"); + group.plot_config(plot_config); + + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + let gens = NIZKGens::new(num_cons, num_vars, num_inputs); + + let name = format!("NIZK_prove_{}", num_vars); + group.bench_function(&name, move |b| { + b.iter(|| { + let mut prover_transcript = Transcript::new(b"example"); + NIZK::prove( + black_box(&inst), + black_box(vars.clone()), + black_box(&inputs), + black_box(&gens), + black_box(&mut prover_transcript), + ); + }); + }); + group.finish(); + } +} + +fn nizk_verify_benchmark(c: &mut Criterion) { + for &s in [10, 12, 16].iter() { + let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); + let mut group = c.benchmark_group("NIZK_verify_benchmark"); + group.plot_config(plot_config); + + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + let gens = NIZKGens::new(num_cons, num_vars, num_inputs); + + // produce a proof of satisfiability + let mut prover_transcript = Transcript::new(b"example"); + let proof = NIZK::prove(&inst, vars, &inputs, &gens, &mut prover_transcript); + + let name = format!("NIZK_verify_{}", num_cons); + group.bench_function(&name, move |b| { + b.iter(|| { + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify( + black_box(&inst), + black_box(&inputs), + black_box(&mut verifier_transcript), + black_box(&gens) + ) + .is_ok()); + }); + }); + group.finish(); + } +} + +fn set_duration() -> Criterion { + Criterion::default().sample_size(10) +} + +criterion_group! 
{ +name = benches_nizk; +config = set_duration(); +targets = nizk_prove_benchmark, nizk_verify_benchmark +} + +criterion_main!(benches_nizk); diff --git a/packages/Spartan-secq/benches/snark.rs b/packages/Spartan-secq/benches/snark.rs new file mode 100644 index 0000000..9b6c67e --- /dev/null +++ b/packages/Spartan-secq/benches/snark.rs @@ -0,0 +1,131 @@ +#![allow(clippy::assertions_on_result_states)] +extern crate libspartan; +extern crate merlin; + +use libspartan::{Instance, SNARKGens, SNARK}; +use merlin::Transcript; + +use criterion::*; + +fn snark_encode_benchmark(c: &mut Criterion) { + for &s in [10, 12, 16].iter() { + let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); + let mut group = c.benchmark_group("SNARK_encode_benchmark"); + group.plot_config(plot_config); + + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + let (inst, _vars, _inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // produce public parameters + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons); + + // produce a commitment to R1CS instance + let name = format!("SNARK_encode_{}", num_cons); + group.bench_function(&name, move |b| { + b.iter(|| { + SNARK::encode(black_box(&inst), black_box(&gens)); + }); + }); + group.finish(); + } +} + +fn snark_prove_benchmark(c: &mut Criterion) { + for &s in [10, 12, 16].iter() { + let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); + let mut group = c.benchmark_group("SNARK_prove_benchmark"); + group.plot_config(plot_config); + + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // produce public parameters + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons); + + // produce a commitment to R1CS instance + let (comm, decomm) = SNARK::encode(&inst, &gens); + + // produce a proof + let name = format!("SNARK_prove_{}", num_cons); + group.bench_function(&name, move |b| { + b.iter(|| { + let mut prover_transcript = Transcript::new(b"example"); + SNARK::prove( + black_box(&inst), + black_box(&comm), + black_box(&decomm), + black_box(vars.clone()), + black_box(&inputs), + black_box(&gens), + black_box(&mut prover_transcript), + ); + }); + }); + group.finish(); + } +} + +fn snark_verify_benchmark(c: &mut Criterion) { + for &s in [10, 12, 16].iter() { + let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic); + let mut group = c.benchmark_group("SNARK_verify_benchmark"); + group.plot_config(plot_config); + + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // produce public parameters + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons); + + // produce a commitment to R1CS instance + let (comm, decomm) = SNARK::encode(&inst, &gens); + + // produce a proof of satisfiability + let mut prover_transcript = Transcript::new(b"example"); + let proof = SNARK::prove( + &inst, + &comm, + &decomm, + vars, + &inputs, + &gens, + &mut prover_transcript, + ); + + // verify the proof + let name = format!("SNARK_verify_{}", num_cons); + group.bench_function(&name, move |b| { + b.iter(|| { + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify( + black_box(&comm), + 
black_box(&inputs),
+            black_box(&mut verifier_transcript),
+            black_box(&gens)
+          )
+          .is_ok());
+      });
+    });
+    group.finish();
+  }
+}
+
+fn set_duration() -> Criterion {
+  Criterion::default().sample_size(10)
+}
+
+criterion_group! {
+name = benches_snark;
+config = set_duration();
+targets = snark_encode_benchmark, snark_prove_benchmark, snark_verify_benchmark
+}
+
+criterion_main!(benches_snark);
diff --git a/packages/Spartan-secq/examples/cubic.rs b/packages/Spartan-secq/examples/cubic.rs
new file mode 100644
index 0000000..8dd1dc8
--- /dev/null
+++ b/packages/Spartan-secq/examples/cubic.rs
@@ -0,0 +1,147 @@
+//! Demonstrates how to produce a proof for the canonical cubic equation: `x^3 + x + 5 = y`.
+//! The example is described in detail [here].
+//!
+//! The R1CS for this problem consists of the following 4 constraints:
+//! `Z0 * Z0 - Z1 = 0`
+//! `Z1 * Z0 - Z2 = 0`
+//! `(Z2 + Z0) * 1 - Z3 = 0`
+//! `(Z3 + 5) * 1 - I0 = 0`
+//!
+//! [here]: https://medium.com/@VitalikButerin/quadratic-arithmetic-programs-from-zero-to-hero-f6d558cea649
+#![allow(clippy::assertions_on_result_states)]
+use libspartan::{InputsAssignment, Instance, SNARKGens, VarsAssignment, SNARK};
+use merlin::Transcript;
+use rand_core::OsRng;
+use secq256k1::elliptic_curve::Field;
+use secq256k1::Scalar;
+
+#[allow(non_snake_case)]
+fn produce_r1cs() -> (
+  usize,
+  usize,
+  usize,
+  usize,
+  Instance,
+  VarsAssignment,
+  InputsAssignment,
+) {
+  // parameters of the R1CS instance
+  let num_cons = 4;
+  let num_vars = 4;
+  let num_inputs = 1;
+  let num_non_zero_entries = 8;
+
+  // We will encode the above constraints into three matrices, where
+  // the coefficients in the matrix are in the little-endian byte order
+  let mut A: Vec<(usize, usize, [u8; 32])> = Vec::new();
+  let mut B: Vec<(usize, usize, [u8; 32])> = Vec::new();
+  let mut C: Vec<(usize, usize, [u8; 32])> = Vec::new();
+
+  let one: [u8; 32] = Scalar::ONE.to_bytes().into();
+
+  // R1CS is a set of three sparse matrices A, B, C, where there is a row for every
+  // constraint and a column for every entry in z = (vars, 1, inputs)
+  // An R1CS instance is satisfiable iff:
+  // Az \circ Bz = Cz, where z = (vars, 1, inputs)
+
+  // constraint 0 entries in (A,B,C)
+  // constraint 0 is Z0 * Z0 - Z1 = 0.
+  A.push((0, 0, one));
+  B.push((0, 0, one));
+  C.push((0, 1, one));
+
+  // constraint 1 entries in (A,B,C)
+  // constraint 1 is Z1 * Z0 - Z2 = 0.
+  A.push((1, 1, one));
+  B.push((1, 0, one));
+  C.push((1, 2, one));
+
+  // constraint 2 entries in (A,B,C)
+  // constraint 2 is (Z2 + Z0) * 1 - Z3 = 0.
+  A.push((2, 2, one));
+  A.push((2, 0, one));
+  B.push((2, num_vars, one));
+  C.push((2, 3, one));
+
+  // constraint 3 entries in (A,B,C)
+  // constraint 3 is (Z3 + 5) * 1 - I0 = 0.
+ A.push((3, 3, one)); + A.push((3, num_vars, Scalar::from(5u32).to_bytes().into())); + B.push((3, num_vars, one)); + C.push((3, num_vars + 1, one)); + + let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C).unwrap(); + + // compute a satisfying assignment + let mut csprng: OsRng = OsRng; + let z0 = Scalar::random(&mut csprng); + let z1 = z0 * z0; // constraint 0 + let z2 = z1 * z0; // constraint 1 + let z3 = z2 + z0; // constraint 2 + let i0 = z3 + Scalar::from(5u32); // constraint 3 + + // create a VarsAssignment + let mut vars: Vec<[u8; 32]> = vec![Scalar::ZERO.to_bytes().into(); num_vars]; + vars[0] = z0.to_bytes().into(); + vars[1] = z1.to_bytes().into(); + vars[2] = z2.to_bytes().into(); + vars[3] = z3.to_bytes().into(); + let assignment_vars = VarsAssignment::new(&vars).unwrap(); + + // create an InputsAssignment + let mut inputs: Vec<[u8; 32]> = vec![Scalar::ZERO.to_bytes().into(); num_inputs]; + inputs[0] = i0.to_bytes().into(); + let assignment_inputs = InputsAssignment::new(&inputs).unwrap(); + + // check if the instance we created is satisfiable + let res = inst.is_sat(&assignment_vars, &assignment_inputs); + assert!(res.unwrap(), "should be satisfied"); + + ( + num_cons, + num_vars, + num_inputs, + num_non_zero_entries, + inst, + assignment_vars, + assignment_inputs, + ) +} + +fn main() { + // produce an R1CS instance + let ( + num_cons, + num_vars, + num_inputs, + num_non_zero_entries, + inst, + assignment_vars, + assignment_inputs, + ) = produce_r1cs(); + + // produce public parameters + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_non_zero_entries); + + // create a commitment to the R1CS instance + let (comm, decomm) = SNARK::encode(&inst, &gens); + + // produce a proof of satisfiability + let mut prover_transcript = Transcript::new(b"snark_example"); + let proof = SNARK::prove( + &inst, + &comm, + &decomm, + assignment_vars, + &assignment_inputs, + &gens, + &mut prover_transcript, + ); + + // verify the proof of satisfiability + let mut verifier_transcript = Transcript::new(b"snark_example"); + assert!(proof + .verify(&comm, &assignment_inputs, &mut verifier_transcript, &gens) + .is_ok()); + println!("proof verification successful!"); +} diff --git a/packages/Spartan-secq/profiler/nizk.rs b/packages/Spartan-secq/profiler/nizk.rs new file mode 100644 index 0000000..e2d3a15 --- /dev/null +++ b/packages/Spartan-secq/profiler/nizk.rs @@ -0,0 +1,52 @@ +#![allow(non_snake_case)] +#![allow(clippy::assertions_on_result_states)] + +extern crate flate2; +extern crate libspartan; +extern crate merlin; +extern crate rand; + +use flate2::{write::ZlibEncoder, Compression}; +use libspartan::{Instance, NIZKGens, NIZK}; +use merlin::Transcript; + +fn print(msg: &str) { + let star = "* "; + println!("{:indent$}{}{}", "", star, msg, indent = 2); +} + +pub fn main() { + // the list of number of variables (and constraints) in an R1CS instance + let inst_sizes = vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]; + + println!("Profiler:: NIZK"); + for &s in inst_sizes.iter() { + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + + // produce a synthetic R1CSInstance + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // produce public generators + let gens = NIZKGens::new(num_cons, num_vars, num_inputs); + + // produce a proof of satisfiability + let mut prover_transcript = Transcript::new(b"nizk_example"); + let proof = NIZK::prove(&inst, vars, &inputs, &gens, &mut 
prover_transcript); + + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + bincode::serialize_into(&mut encoder, &proof).unwrap(); + let proof_encoded = encoder.finish().unwrap(); + let msg_proof_len = format!("NIZK::proof_compressed_len {:?}", proof_encoded.len()); + print(&msg_proof_len); + + // verify the proof of satisfiability + let mut verifier_transcript = Transcript::new(b"nizk_example"); + assert!(proof + .verify(&inst, &inputs, &mut verifier_transcript, &gens) + .is_ok()); + + println!(); + } +} diff --git a/packages/Spartan-secq/profiler/snark.rs b/packages/Spartan-secq/profiler/snark.rs new file mode 100644 index 0000000..b347480 --- /dev/null +++ b/packages/Spartan-secq/profiler/snark.rs @@ -0,0 +1,62 @@ +#![allow(non_snake_case)] +#![allow(clippy::assertions_on_result_states)] + +extern crate flate2; +extern crate libspartan; +extern crate merlin; + +use flate2::{write::ZlibEncoder, Compression}; +use libspartan::{Instance, SNARKGens, SNARK}; +use merlin::Transcript; + +fn print(msg: &str) { + let star = "* "; + println!("{:indent$}{}{}", "", star, msg, indent = 2); +} + +pub fn main() { + // the list of number of variables (and constraints) in an R1CS instance + let inst_sizes = vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]; + + println!("Profiler:: SNARK"); + for &s in inst_sizes.iter() { + let num_vars = (2_usize).pow(s as u32); + let num_cons = num_vars; + let num_inputs = 10; + + // produce a synthetic R1CSInstance + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // produce public generators + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons); + + // create a commitment to R1CSInstance + let (comm, decomm) = SNARK::encode(&inst, &gens); + + // produce a proof of satisfiability + let mut prover_transcript = Transcript::new(b"snark_example"); + let proof = SNARK::prove( + &inst, + &comm, + &decomm, + vars, + &inputs, + &gens, + &mut prover_transcript, + ); + + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + bincode::serialize_into(&mut encoder, &proof).unwrap(); + let proof_encoded = encoder.finish().unwrap(); + let msg_proof_len = format!("SNARK::proof_compressed_len {:?}", proof_encoded.len()); + print(&msg_proof_len); + + // verify the proof of satisfiability + let mut verifier_transcript = Transcript::new(b"snark_example"); + assert!(proof + .verify(&comm, &inputs, &mut verifier_transcript, &gens) + .is_ok()); + + println!(); + } +} diff --git a/packages/Spartan-secq/rustfmt.toml b/packages/Spartan-secq/rustfmt.toml new file mode 100644 index 0000000..7b20d96 --- /dev/null +++ b/packages/Spartan-secq/rustfmt.toml @@ -0,0 +1,4 @@ +edition = "2018" +tab_spaces = 2 +newline_style = "Unix" +use_try_shorthand = true diff --git a/packages/Spartan-secq/src/bin/mont_params.rs b/packages/Spartan-secq/src/bin/mont_params.rs new file mode 100644 index 0000000..4e261ef --- /dev/null +++ b/packages/Spartan-secq/src/bin/mont_params.rs @@ -0,0 +1,54 @@ +use hex_literal::hex; +use num_bigint_dig::{BigInt, BigUint, ModInverse, ToBigInt}; +use num_traits::{FromPrimitive, ToPrimitive}; +use std::ops::Neg; + +fn get_words(n: &BigUint) -> [u64; 4] { + let mut words = [0u64; 4]; + for i in 0..4 { + let word = n.clone() >> (64 * i) & BigUint::from(0xffffffffffffffffu64); + words[i] = word.to_u64().unwrap(); + } + words +} + +fn render_hex(label: String, words: &[u64; 4]) { + println!("// {}", label); + for word in words { + println!("0x{:016x},", word); + } +} + +fn main() 
{
+  let modulus = BigUint::from_bytes_be(&hex!(
+    "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"
+  ));
+
+  let r = BigUint::from_u8(2)
+    .unwrap()
+    .modpow(&BigUint::from_u64(256).unwrap(), &modulus);
+
+  let r2 = BigUint::from_u8(2)
+    .unwrap()
+    .modpow(&BigUint::from_u64(512).unwrap(), &modulus);
+
+  let r3 = BigUint::from_u8(2)
+    .unwrap()
+    .modpow(&BigUint::from_u64(768).unwrap(), &modulus);
+
+  let two_pow_64 = BigUint::from_u128(18446744073709551616u128).unwrap();
+  let one = BigInt::from_u8(1).unwrap();
+
+  let inv = modulus
+    .clone()
+    .mod_inverse(&two_pow_64)
+    .unwrap()
+    .neg()
+    .modpow(&one, &two_pow_64.to_bigint().unwrap());
+
+  render_hex("Modulus".to_string(), &get_words(&modulus));
+  render_hex("R".to_string(), &get_words(&r));
+  render_hex("R2".to_string(), &get_words(&r2));
+  render_hex("R3".to_string(), &get_words(&r3));
+  render_hex("INV".to_string(), &get_words(&inv.to_biguint().unwrap()));
+}
diff --git a/packages/Spartan-secq/src/commitments.rs b/packages/Spartan-secq/src/commitments.rs
new file mode 100644
index 0000000..ab6c05b
--- /dev/null
+++ b/packages/Spartan-secq/src/commitments.rs
@@ -0,0 +1,96 @@
+use super::group::{GroupElement, VartimeMultiscalarMul};
+use super::scalar::Scalar;
+use digest::{ExtendableOutput, Input};
+use secq256k1::AffinePoint;
+use sha3::Shake256;
+use std::io::Read;
+
+#[derive(Debug)]
+pub struct MultiCommitGens {
+  pub n: usize,
+  pub G: Vec<GroupElement>,
+  pub h: GroupElement,
+}
+
+impl MultiCommitGens {
+  pub fn new(n: usize, label: &[u8]) -> Self {
+    let mut shake = Shake256::default();
+    shake.input(label);
+    shake.input(AffinePoint::generator().compress().as_bytes());
+
+    let mut reader = shake.xof_result();
+    let mut gens: Vec<GroupElement> = Vec::new();
+    let mut uniform_bytes = [0u8; 128];
+    for _ in 0..n + 1 {
+      reader.read_exact(&mut uniform_bytes).unwrap();
+      gens.push(AffinePoint::from_uniform_bytes(&uniform_bytes));
+    }
+
+    MultiCommitGens {
+      n,
+      G: gens[..n].to_vec(),
+      h: gens[n],
+    }
+  }
+
+  pub fn clone(&self) -> MultiCommitGens {
+    MultiCommitGens {
+      n: self.n,
+      h: self.h,
+      G: self.G.clone(),
+    }
+  }
+
+  pub fn scale(&self, s: &Scalar) -> MultiCommitGens {
+    MultiCommitGens {
+      n: self.n,
+      h: self.h,
+      G: (0..self.n).map(|i| s * self.G[i]).collect(),
+    }
+  }
+
+  pub fn split_at(&self, mid: usize) -> (MultiCommitGens, MultiCommitGens) {
+    let (G1, G2) = self.G.split_at(mid);
+
+    (
+      MultiCommitGens {
+        n: G1.len(),
+        G: G1.to_vec(),
+        h: self.h,
+      },
+      MultiCommitGens {
+        n: G2.len(),
+        G: G2.to_vec(),
+        h: self.h,
+      },
+    )
+  }
+}
+
+pub trait Commitments {
+  fn commit(&self, blind: &Scalar, gens_n: &MultiCommitGens) -> GroupElement;
+}
+
+impl Commitments for Scalar {
+  fn commit(&self, blind: &Scalar, gens_n: &MultiCommitGens) -> GroupElement {
+    assert_eq!(gens_n.n, 1);
+    GroupElement::vartime_multiscalar_mul(
+      [*self, *blind].to_vec(),
+      [gens_n.G[0], gens_n.h].to_vec(),
+    )
+  }
+}
+
+impl Commitments for Vec<Scalar> {
+  fn commit(&self, blind: &Scalar, gens_n: &MultiCommitGens) -> GroupElement {
+    assert_eq!(gens_n.n, self.len());
+    GroupElement::vartime_multiscalar_mul((*self).clone(), gens_n.G.clone()) + blind * gens_n.h
+  }
+}
+
+impl Commitments for [Scalar] {
+  fn commit(&self, blind: &Scalar, gens_n: &MultiCommitGens) -> GroupElement {
+    assert_eq!(gens_n.n, self.len());
+    GroupElement::vartime_multiscalar_mul(self.to_vec(), gens_n.G.clone()) + blind * gens_n.h
+  }
+}
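The `Commitments` trait above implements Pedersen vector commitments, commit(x; r) = <x, G> + r * h, with generators derived from a Shake256 hash of the label. Such commitments are additively homomorphic, a property the proofs in this crate lean on throughout. A quick sanity sketch of that property, written as it might appear in a test inside this crate (the test name, label, and vector length are hypothetical):

```rust
#[cfg(test)]
mod tests {
  use super::*;
  use rand_core::OsRng;

  #[test]
  fn commitments_are_additively_homomorphic() {
    let mut csprng: OsRng = OsRng;
    let gens = MultiCommitGens::new(3, b"test-homomorphism");

    // two random vectors and blinds
    let x: Vec<Scalar> = (0..3).map(|_| Scalar::random(&mut csprng)).collect();
    let y: Vec<Scalar> = (0..3).map(|_| Scalar::random(&mut csprng)).collect();
    let (r, s) = (Scalar::random(&mut csprng), Scalar::random(&mut csprng));

    // commit(x; r) + commit(y; s) == commit(x + y; r + s)
    let lhs = x.commit(&r, &gens) + y.commit(&s, &gens);
    let sum: Vec<Scalar> = (0..3).map(|i| x[i] + y[i]).collect();
    let rhs = sum.commit(&(r + s), &gens);
    assert_eq!(lhs, rhs);
  }
}
```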
diff --git a/packages/Spartan-secq/src/dense_mlpoly.rs b/packages/Spartan-secq/src/dense_mlpoly.rs
new file mode 100644
index 0000000..6977304
--- /dev/null
+++ b/packages/Spartan-secq/src/dense_mlpoly.rs
@@ -0,0 +1,602 @@
+#![allow(clippy::too_many_arguments)]
+use super::commitments::{Commitments, MultiCommitGens};
+use super::errors::ProofVerifyError;
+use super::group::{CompressedGroup, GroupElement, VartimeMultiscalarMul};
+use super::math::Math;
+use super::nizk::{DotProductProofGens, DotProductProofLog};
+use super::random::RandomTape;
+use super::scalar::Scalar;
+use super::transcript::{AppendToTranscript, ProofTranscript};
+use crate::group::DecompressEncodedPoint;
+use core::ops::Index;
+use merlin::Transcript;
+use serde::{Deserialize, Serialize};
+
+#[cfg(feature = "multicore")]
+use rayon::prelude::*;
+
+#[derive(Debug)]
+pub struct DensePolynomial {
+  num_vars: usize, // the number of variables in the multilinear polynomial
+  len: usize,
+  Z: Vec<Scalar>, // evaluations of the polynomial in all the 2^num_vars Boolean inputs
+}
+
+pub struct PolyCommitmentGens {
+  pub gens: DotProductProofGens,
+}
+
+impl PolyCommitmentGens {
+  // the number of variables in the multilinear polynomial
+  pub fn new(num_vars: usize, label: &'static [u8]) -> PolyCommitmentGens {
+    let (_left, right) = EqPolynomial::compute_factored_lens(num_vars);
+    let gens = DotProductProofGens::new(right.pow2(), label);
+    PolyCommitmentGens { gens }
+  }
+}
+
+pub struct PolyCommitmentBlinds {
+  blinds: Vec<Scalar>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct PolyCommitment {
+  C: Vec<CompressedGroup>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ConstPolyCommitment {
+  C: CompressedGroup,
+}
+
+pub struct EqPolynomial {
+  r: Vec<Scalar>,
+}
+
+impl EqPolynomial {
+  pub fn new(r: Vec<Scalar>) -> Self {
+    EqPolynomial { r }
+  }
+
+  pub fn evaluate(&self, rx: &[Scalar]) -> Scalar {
+    assert_eq!(self.r.len(), rx.len());
+    (0..rx.len())
+      .map(|i| self.r[i] * rx[i] + (Scalar::one() - self.r[i]) * (Scalar::one() - rx[i]))
+      .product()
+  }
+
+  pub fn evals(&self) -> Vec<Scalar> {
+    let ell = self.r.len();
+
+    let mut evals: Vec<Scalar> = vec![Scalar::one(); ell.pow2()];
+    let mut size = 1;
+    for j in 0..ell {
+      // in each iteration, we double the size of chis
+      size *= 2;
+      for i in (0..size).rev().step_by(2) {
+        // copy each element from the prior iteration twice
+        let scalar = evals[i / 2];
+        evals[i] = scalar * self.r[j];
+        evals[i - 1] = scalar - evals[i];
+      }
+    }
+    evals
+  }
+
+  pub fn compute_factored_lens(ell: usize) -> (usize, usize) {
+    (ell / 2, ell - ell / 2)
+  }
+
+  pub fn compute_factored_evals(&self) -> (Vec<Scalar>, Vec<Scalar>) {
+    let ell = self.r.len();
+    let (left_num_vars, _right_num_vars) = EqPolynomial::compute_factored_lens(ell);
+
+    let L = EqPolynomial::new(self.r[..left_num_vars].to_vec()).evals();
+    let R = EqPolynomial::new(self.r[left_num_vars..ell].to_vec()).evals();
+
+    (L, R)
+  }
+}
+
+pub struct IdentityPolynomial {
+  size_point: usize,
+}
+
+impl IdentityPolynomial {
+  pub fn new(size_point: usize) -> Self {
+    IdentityPolynomial { size_point }
+  }
+
+  pub fn evaluate(&self, r: &[Scalar]) -> Scalar {
+    let len = r.len();
+    assert_eq!(len, self.size_point);
+    (0..len)
+      .map(|i| Scalar::from((len - i - 1).pow2() as u64) * r[i])
+      .sum()
+  }
+}
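`EqPolynomial::evals` fills the table of chi_r(i) over all 2^ell Boolean points with one doubling pass per variable, rather than recomputing each ell-term product from scratch. For ell = 2 the table is [(1 - r0)(1 - r1), (1 - r0)r1, r0(1 - r1), r0 r1], indexed with the first variable on the most significant bit. A small check of that fact, as it might appear in this file's tests (a hypothetical test; the concrete values are arbitrary):

```rust
#[test]
fn check_evals_two_vars() {
  // evals() over r = (r0, r1) lists the four products
  // chi_r(00), chi_r(01), chi_r(10), chi_r(11), with r0 on the high bit
  let (r0, r1) = (Scalar::from(4u64), Scalar::from(3u64));
  let one = Scalar::one();
  let evals = EqPolynomial::new(vec![r0, r1]).evals();
  assert_eq!(
    evals,
    vec![
      (one - r0) * (one - r1),
      (one - r0) * r1,
      r0 * (one - r1),
      r0 * r1
    ]
  );
}
```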
+
+impl DensePolynomial {
+  pub fn new(Z: Vec<Scalar>) -> Self {
+    DensePolynomial {
+      num_vars: Z.len().log_2(),
+      len: Z.len(),
+      Z,
+    }
+  }
+
+  pub fn get_num_vars(&self) -> usize {
+    self.num_vars
+  }
+
+  pub fn len(&self) -> usize {
+    self.len
+  }
+
+  pub fn clone(&self) -> DensePolynomial {
+    DensePolynomial::new(self.Z[0..self.len].to_vec())
+  }
+
+  pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) {
+    assert!(idx < self.len());
+    (
+      DensePolynomial::new(self.Z[..idx].to_vec()),
+      DensePolynomial::new(self.Z[idx..2 * idx].to_vec()),
+    )
+  }
+
+  #[cfg(feature = "multicore")]
+  fn commit_inner(&self, blinds: &[Scalar], gens: &MultiCommitGens) -> PolyCommitment {
+    let L_size = blinds.len();
+    let R_size = self.Z.len() / L_size;
+    assert_eq!(L_size * R_size, self.Z.len());
+    let C = (0..L_size)
+      .into_par_iter()
+      .map(|i| {
+        self.Z[R_size * i..R_size * (i + 1)]
+          .commit(&blinds[i], gens)
+          .compress()
+      })
+      .collect();
+    PolyCommitment { C }
+  }
+
+  #[cfg(not(feature = "multicore"))]
+  fn commit_inner(&self, blinds: &[Scalar], gens: &MultiCommitGens) -> PolyCommitment {
+    let L_size = blinds.len();
+    let R_size = self.Z.len() / L_size;
+    assert_eq!(L_size * R_size, self.Z.len());
+    let C = (0..L_size)
+      .map(|i| {
+        self.Z[R_size * i..R_size * (i + 1)]
+          .commit(&blinds[i], gens)
+          .compress()
+      })
+      .collect();
+    PolyCommitment { C }
+  }
+
+  pub fn commit(
+    &self,
+    gens: &PolyCommitmentGens,
+    random_tape: Option<&mut RandomTape>,
+  ) -> (PolyCommitment, PolyCommitmentBlinds) {
+    let n = self.Z.len();
+    let ell = self.get_num_vars();
+    assert_eq!(n, ell.pow2());
+
+    let (left_num_vars, right_num_vars) = EqPolynomial::compute_factored_lens(ell);
+    let L_size = left_num_vars.pow2();
+    let R_size = right_num_vars.pow2();
+    assert_eq!(L_size * R_size, n);
+
+    let blinds = if let Some(t) = random_tape {
+      PolyCommitmentBlinds {
+        blinds: t.random_vector(b"poly_blinds", L_size),
+      }
+    } else {
+      PolyCommitmentBlinds {
+        blinds: vec![Scalar::zero(); L_size],
+      }
+    };
+
+    (self.commit_inner(&blinds.blinds, &gens.gens.gens_n), blinds)
+  }
+
+  pub fn bound(&self, L: &[Scalar]) -> Vec<Scalar> {
+    let (left_num_vars, right_num_vars) = EqPolynomial::compute_factored_lens(self.get_num_vars());
+    let L_size = left_num_vars.pow2();
+    let R_size = right_num_vars.pow2();
+    (0..R_size)
+      .map(|i| (0..L_size).map(|j| L[j] * self.Z[j * R_size + i]).sum())
+      .collect()
+  }
+
+  pub fn bound_poly_var_top(&mut self, r: &Scalar) {
+    let n = self.len() / 2;
+    for i in 0..n {
+      self.Z[i] = self.Z[i] + r * (self.Z[i + n] - self.Z[i]);
+    }
+    self.num_vars -= 1;
+    self.len = n;
+  }
+
+  pub fn bound_poly_var_bot(&mut self, r: &Scalar) {
+    let n = self.len() / 2;
+    for i in 0..n {
+      self.Z[i] = self.Z[2 * i] + r * (self.Z[2 * i + 1] - self.Z[2 * i]);
+    }
+    self.num_vars -= 1;
+    self.len = n;
+  }
+
+  // returns Z(r) in O(n) time
+  pub fn evaluate(&self, r: &[Scalar]) -> Scalar {
+    // r must have a value for each variable
+    assert_eq!(r.len(), self.get_num_vars());
+    let chis = EqPolynomial::new(r.to_vec()).evals();
+    assert_eq!(chis.len(), self.Z.len());
+    DotProductProofLog::compute_dotproduct(&self.Z, &chis)
+  }
+
+  fn vec(&self) -> &Vec<Scalar> {
+    &self.Z
+  }
+
+  pub fn extend(&mut self, other: &DensePolynomial) {
+    // TODO: allow extension even when some vars are bound
+    assert_eq!(self.Z.len(), self.len);
+    let other_vec = other.vec();
+    assert_eq!(other_vec.len(), self.len);
+    self.Z.extend(other_vec);
+    self.num_vars += 1;
+    self.len *= 2;
+    assert_eq!(self.Z.len(), self.len);
+  }
+
+  pub fn merge<'a, I>(polys: I) -> DensePolynomial
+  where
+    I: IntoIterator<Item = &'a DensePolynomial>,
+  {
+    let mut Z: Vec<Scalar> = Vec::new();
+    for poly in polys.into_iter() {
+      Z.extend(poly.vec());
+    }
+
+    // pad the polynomial with zero polynomial at the end
+    Z.resize(Z.len().next_power_of_two(), Scalar::zero());
+
+    DensePolynomial::new(Z)
+  }
+
+  pub fn from_usize(Z: &[usize]) -> Self {
+    DensePolynomial::new(
+      (0..Z.len())
+        .map(|i| Scalar::from(Z[i] as u64))
+        .collect::<Vec<Scalar>>(),
+    )
+  }
+}
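`bound_poly_var_top` partially evaluates the polynomial in place: fixing the top (most significant) variable to r halves the table via Z'[i] = Z[i] + r(Z[i + n/2] - Z[i]). Binding variables one at a time this way must agree with a direct `evaluate`; a hypothetical test sketch of that invariant (table values arbitrary, OsRng as in the tests below):

```rust
#[test]
fn check_bound_poly_var_top() {
  let mut csprng: OsRng = OsRng;
  let (r0, r1) = (Scalar::random(&mut csprng), Scalar::random(&mut csprng));

  // a 2-variable polynomial given by its evaluation table over {0,1}^2
  let mut poly = DensePolynomial::new(vec![
    Scalar::from(1u64),
    Scalar::from(2u64),
    Scalar::from(3u64),
    Scalar::from(5u64),
  ]);
  let full = poly.evaluate(&[r0, r1]);

  // bind the top variable to r0; one variable and half the table remain
  poly.bound_poly_var_top(&r0);
  assert_eq!(poly.get_num_vars(), 1);
  assert_eq!(poly.evaluate(&[r1]), full);
}
```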
+
+impl Index<usize> for DensePolynomial {
+  type Output = Scalar;
+
+  #[inline(always)]
+  fn index(&self, _index: usize) -> &Scalar {
+    &(self.Z[_index])
+  }
+}
+
+impl AppendToTranscript for PolyCommitment {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) {
+    transcript.append_message(label, b"poly_commitment_begin");
+    for i in 0..self.C.len() {
+      transcript.append_point(b"poly_commitment_share", &self.C[i]);
+    }
+    transcript.append_message(label, b"poly_commitment_end");
+  }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct PolyEvalProof {
+  proof: DotProductProofLog,
+}
+
+impl PolyEvalProof {
+  fn protocol_name() -> &'static [u8] {
+    b"polynomial evaluation proof"
+  }
+
+  pub fn prove(
+    poly: &DensePolynomial,
+    blinds_opt: Option<&PolyCommitmentBlinds>,
+    r: &[Scalar],                  // point at which the polynomial is evaluated
+    Zr: &Scalar,                   // evaluation of \widetilde{Z}(r)
+    blind_Zr_opt: Option<&Scalar>, // specifies a blind for Zr
+    gens: &PolyCommitmentGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+  ) -> (PolyEvalProof, CompressedGroup) {
+    transcript.append_protocol_name(PolyEvalProof::protocol_name());
+
+    // assert vectors are of the right size
+    assert_eq!(poly.get_num_vars(), r.len());
+
+    let (left_num_vars, right_num_vars) = EqPolynomial::compute_factored_lens(r.len());
+    let L_size = left_num_vars.pow2();
+    let R_size = right_num_vars.pow2();
+
+    let default_blinds = PolyCommitmentBlinds {
+      blinds: vec![Scalar::zero(); L_size],
+    };
+    let blinds = blinds_opt.map_or(&default_blinds, |p| p);
+
+    assert_eq!(blinds.blinds.len(), L_size);
+
+    let zero = Scalar::zero();
+    let blind_Zr = blind_Zr_opt.map_or(&zero, |p| p);
+
+    // compute the L and R vectors
+    let eq = EqPolynomial::new(r.to_vec());
+    let (L, R) = eq.compute_factored_evals();
+    assert_eq!(L.len(), L_size);
+    assert_eq!(R.len(), R_size);
+
+    // compute the vector underneath L*Z and the L*blinds
+    // compute vector-matrix product between L and Z viewed as a matrix
+    let LZ = poly.bound(&L);
+    let LZ_blind: Scalar = (0..L.len()).map(|i| blinds.blinds[i] * L[i]).sum();
+
+    // a dot product proof of size R_size
+    let (proof, _C_LR, C_Zr_prime) = DotProductProofLog::prove(
+      &gens.gens,
+      transcript,
+      random_tape,
+      &LZ,
+      &LZ_blind,
+      &R,
+      Zr,
+      blind_Zr,
+    );
+
+    (PolyEvalProof { proof }, C_Zr_prime)
+  }
+
+  pub fn verify(
+    &self,
+    gens: &PolyCommitmentGens,
+    transcript: &mut Transcript,
+    r: &[Scalar],           // point at which the polynomial is evaluated
+    C_Zr: &CompressedGroup, // commitment to \widetilde{Z}(r)
+    comm: &PolyCommitment,
+  ) -> Result<(), ProofVerifyError> {
+    transcript.append_protocol_name(PolyEvalProof::protocol_name());
+
+    // compute L and R
+    let eq = EqPolynomial::new(r.to_vec());
+    let (L, R) = eq.compute_factored_evals();
+
+    // compute a weighted sum of commitments and L
+    let C_decompressed = comm.C.iter().map(|pt| pt.decompress().unwrap());
+
+    let C_LZ = GroupElement::vartime_multiscalar_mul(L, C_decompressed.collect()).compress();
+
+    self
+      .proof
+      .verify(R.len(), &gens.gens, transcript, &R, &C_LZ, C_Zr)
+  }
+
+  pub fn verify_plain(
+    &self,
+    gens: &PolyCommitmentGens,
+    transcript: &mut Transcript,
+    r: &[Scalar], // point at which the polynomial is evaluated
+    Zr: &Scalar,  // evaluation \widetilde{Z}(r)
+    comm: &PolyCommitment,
+  ) -> Result<(), ProofVerifyError> {
+    // compute a commitment to Zr with a blind of zero
+    let C_Zr = Zr.commit(&Scalar::zero(), &gens.gens.gens_1).compress();
+
+    self.verify(gens, transcript, r, &C_Zr, comm)
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  use super::super::scalar::ScalarFromPrimitives;
+  use super::*;
+  use rand_core::OsRng;
+
+  fn evaluate_with_LR(Z: &[Scalar], r: &[Scalar]) -> Scalar {
+    let eq = EqPolynomial::new(r.to_vec());
+    let (L, R) = eq.compute_factored_evals();
+
+    let ell = r.len();
+    // ensure ell is even
+    assert!(ell % 2 == 0);
+    // compute n = 2^\ell
+    let n = ell.pow2();
+    // compute m = sqrt(n) = 2^{\ell/2}
+    let m = n.square_root();
+
+    // compute vector-matrix product between L and Z viewed as a matrix
+    let LZ = (0..m)
+      .map(|i| (0..m).map(|j| L[j] * Z[j * m + i]).sum())
+      .collect::<Vec<Scalar>>();
+
+    // compute dot product between LZ and R
+    DotProductProofLog::compute_dotproduct(&LZ, &R)
+  }
+
+  #[test]
+  fn check_polynomial_evaluation() {
+    // Z = [1, 2, 1, 4]
+    let Z = vec![
+      Scalar::one(),
+      (2_usize).to_scalar(),
+      (1_usize).to_scalar(),
+      (4_usize).to_scalar(),
+    ];
+
+    // r = [4,3]
+    let r = vec![(4_usize).to_scalar(), (3_usize).to_scalar()];
+
+    let eval_with_LR = evaluate_with_LR(&Z, &r);
+    let poly = DensePolynomial::new(Z);
+
+    let eval = poly.evaluate(&r);
+    assert_eq!(eval, (28_usize).to_scalar());
+    assert_eq!(eval_with_LR, eval);
+  }
+
+  pub fn compute_factored_chis_at_r(r: &[Scalar]) -> (Vec<Scalar>, Vec<Scalar>) {
+    let mut L: Vec<Scalar> = Vec::new();
+    let mut R: Vec<Scalar> = Vec::new();
+
+    let ell = r.len();
+    assert!(ell % 2 == 0); // ensure ell is even
+    let n = ell.pow2();
+    let m = n.square_root();
+
+    // compute row vector L
+    for i in 0..m {
+      let mut chi_i = Scalar::one();
+      for j in 0..ell / 2 {
+        let bit_j = ((m * i) & (1 << (r.len() - j - 1))) > 0;
+        if bit_j {
+          chi_i *= r[j];
+        } else {
+          chi_i *= Scalar::one() - r[j];
+        }
+      }
+      L.push(chi_i);
+    }
+
+    // compute column vector R
+    for i in 0..m {
+      let mut chi_i = Scalar::one();
+      for j in ell / 2..ell {
+        let bit_j = (i & (1 << (r.len() - j - 1))) > 0;
+        if bit_j {
+          chi_i *= r[j];
+        } else {
+          chi_i *= Scalar::one() - r[j];
+        }
+      }
+      R.push(chi_i);
+    }
+    (L, R)
+  }
+
+  pub fn compute_chis_at_r(r: &[Scalar]) -> Vec<Scalar> {
+    let ell = r.len();
+    let n = ell.pow2();
+    let mut chis: Vec<Scalar> = Vec::new();
+    for i in 0..n {
+      let mut chi_i = Scalar::one();
+      for j in 0..r.len() {
+        let bit_j = (i & (1 << (r.len() - j - 1))) > 0;
+        if bit_j {
+          chi_i *= r[j];
+        } else {
+          chi_i *= Scalar::one() - r[j];
+        }
+      }
+      chis.push(chi_i);
+    }
+    chis
+  }
+
+  pub fn compute_outerproduct(L: Vec<Scalar>, R: Vec<Scalar>) -> Vec<Scalar> {
+    assert_eq!(L.len(), R.len());
+    (0..L.len())
+      .map(|i| (0..R.len()).map(|j| L[i] * R[j]).collect::<Vec<Scalar>>())
+      .collect::<Vec<Vec<Scalar>>>()
+      .into_iter()
+      .flatten()
+      .collect::<Vec<Scalar>>()
+  }
+
+  #[test]
+  fn check_memoized_chis() {
+    let mut csprng: OsRng = OsRng;
+
+    let s = 10;
+    let mut r: Vec<Scalar> = Vec::new();
+    for _i in 0..s {
+      r.push(Scalar::random(&mut csprng));
+    }
+    let chis = tests::compute_chis_at_r(&r);
+    let chis_m = EqPolynomial::new(r).evals();
+    assert_eq!(chis, chis_m);
+  }
+
+  #[test]
+  fn check_factored_chis() {
+    let mut csprng: OsRng = OsRng;
+
+    let s = 10;
+    let mut r: Vec<Scalar> = Vec::new();
+    for _i in 0..s {
+      r.push(Scalar::random(&mut csprng));
+    }
+    let chis = EqPolynomial::new(r.clone()).evals();
+    let (L, R) = EqPolynomial::new(r).compute_factored_evals();
+    let O = compute_outerproduct(L, R);
+    assert_eq!(chis, O);
+  }
+
+  #[test]
+  fn check_memoized_factored_chis() {
+    let mut csprng: OsRng = OsRng;
+
+    let s = 10;
+    let mut r: Vec<Scalar> = Vec::new();
+    for _i in 0..s {
+      r.push(Scalar::random(&mut csprng));
+    }
+    let (L, R) = tests::compute_factored_chis_at_r(&r);
+    let eq = EqPolynomial::new(r);
+    let (L2, R2) = eq.compute_factored_evals();
+    assert_eq!(L, L2);
+    assert_eq!(R, R2);
+  }
+
+  #[test]
+  fn check_polynomial_commit() {
+    let Z = vec![
+      (1_usize).to_scalar(),
+      (2_usize).to_scalar(),
+      (1_usize).to_scalar(),
+      (4_usize).to_scalar(),
+    ];
+    let poly = DensePolynomial::new(Z);
+
+    // r = [4,3]
+    let r = vec![(4_usize).to_scalar(), (3_usize).to_scalar()];
+    let eval = poly.evaluate(&r);
+    assert_eq!(eval, (28_usize).to_scalar());
+
+    let gens = PolyCommitmentGens::new(poly.get_num_vars(), b"test-two");
+    let (poly_commitment, blinds) = poly.commit(&gens, None);
+
+    let mut random_tape = RandomTape::new(b"proof");
+    let mut prover_transcript = Transcript::new(b"example");
+    let (proof, C_Zr) = PolyEvalProof::prove(
+      &poly,
+      Some(&blinds),
+      &r,
+      &eval,
+      None,
+      &gens,
+      &mut prover_transcript,
+      &mut random_tape,
+    );
+
+    let mut verifier_transcript = Transcript::new(b"example");
+
+    assert!(proof
+      .verify(&gens, &mut verifier_transcript, &r, &C_Zr, &poly_commitment)
+      .is_ok());
+  }
+}
diff --git a/packages/Spartan-secq/src/errors.rs b/packages/Spartan-secq/src/errors.rs
new file mode 100644
index 0000000..18e157b
--- /dev/null
+++ b/packages/Spartan-secq/src/errors.rs
@@ -0,0 +1,32 @@
+use core::fmt::Debug;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum ProofVerifyError {
+  #[error("Proof verification failed")]
+  InternalError,
+  #[error("Compressed group element failed to decompress: {0:?}")]
+  DecompressionError([u8; 32]),
+}
+
+impl Default for ProofVerifyError {
+  fn default() -> Self {
+    ProofVerifyError::InternalError
+  }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum R1CSError {
+  /// returned if the number of constraints is not a power of 2
+  NonPowerOfTwoCons,
+  /// returned if the number of variables is not a power of 2
+  NonPowerOfTwoVars,
+  /// returned if a wrong number of inputs in an assignment are supplied
+  InvalidNumberOfInputs,
+  /// returned if a wrong number of variables in an assignment are supplied
+  InvalidNumberOfVars,
+  /// returned if a [u8;32] does not parse into a valid Scalar in the field of secq256k1
+  InvalidScalar,
+  /// returned if the supplied row or col in (row,col,val) tuple is out of range
+  InvalidIndex,
+}
diff --git a/packages/Spartan-secq/src/group.rs b/packages/Spartan-secq/src/group.rs
new file mode 100644
index 0000000..e14091f
--- /dev/null
+++ b/packages/Spartan-secq/src/group.rs
@@ -0,0 +1,138 @@
+use secq256k1::{AffinePoint, ProjectivePoint};
+
+use super::errors::ProofVerifyError;
+use super::scalar::{Scalar, ScalarBytes, ScalarBytesFromScalar};
+use core::ops::{Mul, MulAssign};
+use multiexp::multiexp;
+
+pub type GroupElement = secq256k1::AffinePoint;
+pub type CompressedGroup = secq256k1::EncodedPoint;
+
+pub trait CompressedGroupExt {
+  type Group;
+  fn unpack(&self) -> Result<Self::Group, ProofVerifyError>;
+}
+
+impl CompressedGroupExt for CompressedGroup {
+  type Group = secq256k1::AffinePoint;
+  fn unpack(&self) -> Result<Self::Group, ProofVerifyError> {
+    let result = AffinePoint::decompress(*self);
+    if result.is_some().into() {
+      return Ok(result.unwrap());
+    } else {
+      Err(ProofVerifyError::DecompressionError(
+        (*self.to_bytes()).try_into().unwrap(),
+      ))
+    }
+  }
+}
+
+pub trait DecompressEncodedPoint {
+  fn decompress(&self) -> Option<GroupElement>;
+}
+
+impl DecompressEncodedPoint for CompressedGroup {
+  fn decompress(&self) -> Option<GroupElement> {
+    Some(self.unpack().unwrap())
+  }
+}
+
+impl<'b> MulAssign<&'b Scalar> for GroupElement {
+  fn mul_assign(&mut self, scalar: &'b Scalar) {
+    let result = (self as &GroupElement) * Scalar::decompress_scalar(scalar);
+    *self = result;
+  }
+}
+
+impl<'a, 'b> Mul<&'b Scalar> for &'a GroupElement {
+  type Output = GroupElement;
+  fn mul(self, scalar: &'b Scalar) -> GroupElement {
+    *self * Scalar::decompress_scalar(scalar)
+  }
+}
+
+impl<'a, 'b> Mul<&'b GroupElement> for &'a Scalar {
+  type Output = GroupElement;
+
+  fn mul(self, point: &'b GroupElement) -> GroupElement {
+    (*point * Scalar::decompress_scalar(self)).into()
+  }
+}
+
+macro_rules! define_mul_variants {
+  (LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => {
+    impl<'b> Mul<&'b $rhs> for $lhs {
+      type Output = $out;
+      fn mul(self, rhs: &'b $rhs) -> $out {
+        &self * rhs
+      }
+    }
+
+    impl<'a> Mul<$rhs> for &'a $lhs {
+      type Output = $out;
+      fn mul(self, rhs: $rhs) -> $out {
+        self * &rhs
+      }
+    }
+
+    impl Mul<$rhs> for $lhs {
+      type Output = $out;
+      fn mul(self, rhs: $rhs) -> $out {
+        &self * &rhs
+      }
+    }
+  };
+}
+
+macro_rules! define_mul_assign_variants {
+  (LHS = $lhs:ty, RHS = $rhs:ty) => {
+    impl MulAssign<$rhs> for $lhs {
+      fn mul_assign(&mut self, rhs: $rhs) {
+        *self *= &rhs;
+      }
+    }
+  };
+}
+
+define_mul_assign_variants!(LHS = GroupElement, RHS = Scalar);
+define_mul_variants!(LHS = GroupElement, RHS = Scalar, Output = GroupElement);
+define_mul_variants!(LHS = Scalar, RHS = GroupElement, Output = GroupElement);
+
+pub trait VartimeMultiscalarMul {
+  type Scalar;
+  fn vartime_multiscalar_mul(scalars: Vec<Self::Scalar>, points: Vec<GroupElement>) -> Self;
+}
+
+impl VartimeMultiscalarMul for GroupElement {
+  type Scalar = super::scalar::Scalar;
+  // TODO Borrow the arguments so we don't have to clone them, as it was in the original implementation
+  fn vartime_multiscalar_mul(scalars: Vec<Self::Scalar>, points: Vec<GroupElement>) -> Self {
+    let points: Vec<ProjectivePoint> = points.iter().map(|p| ProjectivePoint::from(p.0)).collect();
+
+    let pairs: Vec<(ScalarBytes, ProjectivePoint)> = scalars
+      .into_iter()
+      .enumerate()
+      .map(|(i, s)| (Scalar::decompress_scalar(&s), points[i]))
+      .collect();
+
+    let result = multiexp::<ProjectivePoint>(pairs.as_slice());
+
+    AffinePoint(result.to_affine())
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  use super::*;
+  #[test]
+  fn msm() {
+    let scalars = vec![Scalar::from(1), Scalar::from(2), Scalar::from(3)];
+    let points = vec![
+      GroupElement::generator(),
+      GroupElement::generator(),
+      GroupElement::generator(),
+    ];
+    let result = GroupElement::vartime_multiscalar_mul(scalars, points);
+
+    assert_eq!(result, GroupElement::generator() * Scalar::from(6));
+  }
+}
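`define_mul_variants!` exists so scalar/point multiplication works uniformly over borrowed and owned operands: only the `&GroupElement * &Scalar` case (and its mirror) is written by hand, and the macro derives the other three combinations. A small illustration of what the macro buys (a hypothetical snippet, not part of this diff):

```rust
// all four borrow/owned combinations compile and agree
let g = GroupElement::generator();
let s = Scalar::from(2u64);

let a = &g * &s; // hand-written base impl
let b = &g * s;  // generated: &GroupElement * Scalar
let c = g * &s;  // generated: GroupElement * &Scalar
let d = g * s;   // generated: GroupElement * Scalar
assert!(a == b && b == c && c == d);
```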
diff --git a/packages/Spartan-secq/src/lib.rs b/packages/Spartan-secq/src/lib.rs
new file mode 100644
index 0000000..293ff43
--- /dev/null
+++ b/packages/Spartan-secq/src/lib.rs
@@ -0,0 +1,751 @@
+#![allow(non_snake_case)]
+#![doc = include_str!("../README.md")]
+#![deny(missing_docs)]
+#![allow(clippy::assertions_on_result_states)]
+
+extern crate byteorder;
+extern crate core;
+extern crate digest;
+extern crate merlin;
+extern crate rand;
+extern crate sha3;
+
+#[cfg(feature = "multicore")]
+extern crate rayon;
+
+mod commitments;
+mod dense_mlpoly;
+mod errors;
+mod group;
+mod math;
+mod nizk;
+mod product_tree;
+mod r1csinstance;
+mod r1csproof;
+mod random;
+mod scalar;
+mod sparse_mlpoly;
+mod sumcheck;
+mod timer;
+mod transcript;
+mod unipoly;
+
+use core::cmp::max;
+use errors::{ProofVerifyError, R1CSError};
+use merlin::Transcript;
+use r1csinstance::{
+  R1CSCommitment, R1CSCommitmentGens, R1CSDecommitment, R1CSEvalProof, R1CSInstance,
+};
+use r1csproof::{R1CSGens, R1CSProof};
+use random::RandomTape;
+use scalar::Scalar;
+use serde::{Deserialize, Serialize};
+use timer::Timer;
+use transcript::{AppendToTranscript, ProofTranscript};
+
+/// `ComputationCommitment` holds a public preprocessed NP statement (e.g., R1CS)
+pub struct ComputationCommitment {
+  comm: R1CSCommitment,
+}
+
+/// `ComputationDecommitment` holds information to decommit `ComputationCommitment`
+pub struct ComputationDecommitment {
+  decomm: R1CSDecommitment,
+}
+
+/// `Assignment` holds an assignment of values to either the inputs or variables in an `Instance`
+#[derive(Serialize, Deserialize, Clone)]
+pub struct Assignment {
+  assignment: Vec<Scalar>,
+}
+
+impl Assignment {
+  /// Constructs a new `Assignment` from a vector
+  pub fn new(assignment: &[[u8; 32]]) -> Result<Assignment, R1CSError> {
+    let bytes_to_scalar = |vec: &[[u8; 32]]| -> Result<Vec<Scalar>, R1CSError> {
+      let mut vec_scalar: Vec<Scalar> = Vec::new();
+      for v in vec {
+        let val = Scalar::from_bytes(v);
+        if val.is_some().unwrap_u8() == 1 {
+          vec_scalar.push(val.unwrap());
+        } else {
+          return Err(R1CSError::InvalidScalar);
+        }
+      }
+      Ok(vec_scalar)
+    };
+
+    let assignment_scalar = bytes_to_scalar(assignment);
+
+    // check for any parsing errors
+    if assignment_scalar.is_err() {
+      return Err(R1CSError::InvalidScalar);
+    }
+
+    Ok(Assignment {
+      assignment: assignment_scalar.unwrap(),
+    })
+  }
+
+  /// pads Assignment to the specified length
+  fn pad(&self, len: usize) -> VarsAssignment {
+    // check that the new length is higher than current length
+    assert!(len > self.assignment.len());
+
+    let padded_assignment = {
+      let mut padded_assignment = self.assignment.clone();
+      padded_assignment.extend(vec![Scalar::zero(); len - self.assignment.len()]);
+      padded_assignment
+    };
+
+    VarsAssignment {
+      assignment: padded_assignment,
+    }
+  }
+}
+
+/// `VarsAssignment` holds an assignment of values to variables in an `Instance`
+pub type VarsAssignment = Assignment;
+
+/// `InputsAssignment` holds an assignment of values to inputs in an `Instance`
+pub type InputsAssignment = Assignment;
+
+/// `Instance` holds the description of R1CS matrices and a hash of the matrices
+#[derive(Serialize, Deserialize)]
+pub struct Instance {
+  /// R1CS instance
+  pub inst: R1CSInstance,
+  digest: Vec<u8>,
+}
+
+impl Instance {
+  /// Constructs a new `Instance` and an associated satisfying assignment
+  pub fn new(
+    num_cons: usize,
+    num_vars: usize,
+    num_inputs: usize,
+    A: &[(usize, usize, [u8; 32])],
+    B: &[(usize, usize, [u8; 32])],
+    C: &[(usize, usize, [u8; 32])],
+  ) -> Result<Instance, R1CSError> {
+    let (num_vars_padded, num_cons_padded) = {
+      let num_vars_padded = {
+        let mut num_vars_padded = num_vars;
+
+        // ensure that num_inputs + 1 <= num_vars
+        num_vars_padded = max(num_vars_padded, num_inputs + 1);
+
+        // ensure that num_vars_padded is a power of two
+        if num_vars_padded.next_power_of_two() != num_vars_padded {
+          num_vars_padded = num_vars_padded.next_power_of_two();
+        }
+        num_vars_padded
+      };
+
+      let num_cons_padded = {
+        let mut num_cons_padded = num_cons;
+
+        // ensure that num_cons_padded is at least 2
+        if num_cons_padded == 0 || num_cons_padded == 1 {
+          num_cons_padded = 2;
+        }
+
+        // ensure that num_cons_padded is a power of 2
+        if num_cons_padded.next_power_of_two() != num_cons_padded {
+          num_cons_padded = num_cons_padded.next_power_of_two();
+        }
+        num_cons_padded
+      };
+
+      (num_vars_padded, num_cons_padded)
+    };
+
+    let bytes_to_scalar =
+      |tups: &[(usize, usize, [u8; 32])]| -> Result<Vec<(usize, usize, Scalar)>, R1CSError> {
+        let mut mat: Vec<(usize, usize, Scalar)> = Vec::new();
+        for &(row, col, val_bytes) in tups {
+          // row must be smaller than num_cons
+          if row >= num_cons {
+            return Err(R1CSError::InvalidIndex);
+          }
+
+          // col must be smaller than num_vars + 1 + num_inputs
+          if col >= num_vars + 1 + num_inputs {
+            return Err(R1CSError::InvalidIndex);
+          }
+
+          let val = Scalar::from_bytes(&val_bytes);
+          if val.is_some().unwrap_u8() == 1 {
+            // if col >= num_vars, it means that it is referencing a 1 or input in the satisfying
+            // assignment
+            if col >= num_vars {
+              mat.push((row, col + num_vars_padded - num_vars, val.unwrap()));
+            } else {
+              mat.push((row, col, val.unwrap()));
+            }
+          } else {
+            return Err(R1CSError::InvalidScalar);
+          }
+        }
+
+        // pad with additional constraints up until num_cons_padded if the original constraints were 0 or 1
+        // we do not need to pad otherwise because the dummy constraints are implicit in the sum-check protocol
+        if num_cons == 0 || num_cons == 1 {
+          for i in tups.len()..num_cons_padded {
+            mat.push((i, num_vars, Scalar::zero()));
+          }
+        }
+
+        Ok(mat)
+      };
+
+    let A_scalar = bytes_to_scalar(A);
+    if A_scalar.is_err() {
+      return Err(A_scalar.err().unwrap());
+    }
+
+    let B_scalar = bytes_to_scalar(B);
+    if B_scalar.is_err() {
+      return Err(B_scalar.err().unwrap());
+    }
+
+    let C_scalar = bytes_to_scalar(C);
+    if C_scalar.is_err() {
+      return Err(C_scalar.err().unwrap());
+    }
+
+    let inst = R1CSInstance::new(
+      num_cons_padded,
+      num_vars_padded,
+      num_inputs,
+      &A_scalar.unwrap(),
+      &B_scalar.unwrap(),
+      &C_scalar.unwrap(),
+    );
+
+    let digest = inst.get_digest();
+
+    Ok(Instance { inst, digest })
+  }
+
+  /// Checks if a given R1CSInstance is satisfiable with a given variables and inputs assignments
+  pub fn is_sat(
+    &self,
+    vars: &VarsAssignment,
+    inputs: &InputsAssignment,
+  ) -> Result<bool, R1CSError> {
+    if vars.assignment.len() > self.inst.get_num_vars() {
+      return Err(R1CSError::InvalidNumberOfVars);
+    }
+
+    if inputs.assignment.len() != self.inst.get_num_inputs() {
+      return Err(R1CSError::InvalidNumberOfInputs);
+    }
+
+    // we might need to pad variables
+    let padded_vars = {
+      let num_padded_vars = self.inst.get_num_vars();
+      let num_vars = vars.assignment.len();
+      if num_padded_vars > num_vars {
+        vars.pad(num_padded_vars)
+      } else {
+        vars.clone()
+      }
+    };
+
+    Ok(
+      self
+        .inst
+        .is_sat(&padded_vars.assignment, &inputs.assignment),
+    )
+  }
+
+  /// Constructs a new synthetic R1CS `Instance` and an associated satisfying assignment
+  pub fn produce_synthetic_r1cs(
+    num_cons: usize,
+    num_vars: usize,
+    num_inputs: usize,
+  ) -> (Instance, VarsAssignment, InputsAssignment) {
+    let (inst, vars, inputs) = R1CSInstance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs);
+    let digest = inst.get_digest();
+    (
+      Instance { inst, digest },
+      VarsAssignment { assignment: vars },
+      InputsAssignment { assignment: inputs },
+    )
+  }
+}
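The column remapping in `Instance::new` above is easy to misread: columns at or beyond `num_vars` refer to the constant 1 and the inputs, so they are shifted by `num_vars_padded - num_vars` once variables are padded (both to satisfy `num_inputs + 1 <= num_vars` and to reach a power of two). A minimal sketch of a single constraint `Z0 * Z0 - I0 = 0` that exercises exactly this path, following the conventions of examples/cubic.rs (the concrete values are arbitrary):

```rust
use libspartan::{InputsAssignment, Instance, VarsAssignment};
use secq256k1::elliptic_curve::Field;
use secq256k1::Scalar;

fn main() {
  // one constraint, one variable Z0, one input I0;
  // columns: 0 = Z0, 1 = constant one, 2 = I0
  let one: [u8; 32] = Scalar::ONE.to_bytes().into();
  let a = vec![(0, 0, one)]; // Z0
  let b = vec![(0, 0, one)]; // * Z0
  let c = vec![(0, 2, one)]; // = I0

  // num_vars is padded from 1 to 2 internally; the constant and I0 columns move with it
  let inst = Instance::new(1, 1, 1, &a, &b, &c).unwrap();

  // Z0 = 3, I0 = 9; is_sat pads the short variable assignment for us
  let vars = VarsAssignment::new(&[Scalar::from(3u32).to_bytes().into()]).unwrap();
  let inputs = InputsAssignment::new(&[Scalar::from(9u32).to_bytes().into()]).unwrap();
  assert!(inst.is_sat(&vars, &inputs).unwrap());
}
```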
+
+/// `SNARKGens` holds public parameters for producing and verifying proofs with the Spartan SNARK
+pub struct SNARKGens {
+  gens_r1cs_sat: R1CSGens,
+  gens_r1cs_eval: R1CSCommitmentGens,
+}
+
+impl SNARKGens {
+  /// Constructs a new `SNARKGens` given the size of the R1CS statement;
+  /// `num_nz_entries` specifies the maximum number of non-zero entries in any of the three R1CS matrices
+  pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize, num_nz_entries: usize) -> Self {
+    let num_vars_padded = {
+      let mut num_vars_padded = max(num_vars, num_inputs + 1);
+      if num_vars_padded != num_vars_padded.next_power_of_two() {
+        num_vars_padded = num_vars_padded.next_power_of_two();
+      }
+      num_vars_padded
+    };
+
+    let gens_r1cs_sat = R1CSGens::new(b"gens_r1cs_sat", num_cons, num_vars_padded);
+    let gens_r1cs_eval = R1CSCommitmentGens::new(
+      b"gens_r1cs_eval",
+      num_cons,
num_vars_padded, + num_inputs, + num_nz_entries, + ); + SNARKGens { + gens_r1cs_sat, + gens_r1cs_eval, + } + } +} + +/// `SNARK` holds a proof produced by Spartan SNARK +#[derive(Serialize, Deserialize, Debug)] +pub struct SNARK { + r1cs_sat_proof: R1CSProof, + inst_evals: (Scalar, Scalar, Scalar), + r1cs_eval_proof: R1CSEvalProof, +} + +impl SNARK { + fn protocol_name() -> &'static [u8] { + b"Spartan SNARK proof" + } + + /// A public computation to create a commitment to an R1CS instance + pub fn encode( + inst: &Instance, + gens: &SNARKGens, + ) -> (ComputationCommitment, ComputationDecommitment) { + let timer_encode = Timer::new("SNARK::encode"); + let (comm, decomm) = inst.inst.commit(&gens.gens_r1cs_eval); + timer_encode.stop(); + ( + ComputationCommitment { comm }, + ComputationDecommitment { decomm }, + ) + } + + /// A method to produce a SNARK proof of the satisfiability of an R1CS instance + pub fn prove( + inst: &Instance, + comm: &ComputationCommitment, + decomm: &ComputationDecommitment, + vars: VarsAssignment, + inputs: &InputsAssignment, + gens: &SNARKGens, + transcript: &mut Transcript, + ) -> Self { + let timer_prove = Timer::new("SNARK::prove"); + + // we create a Transcript object seeded with a random Scalar + // to aid the prover produce its randomness + let mut random_tape = RandomTape::new(b"proof"); + + transcript.append_protocol_name(SNARK::protocol_name()); + comm.comm.append_to_transcript(b"comm", transcript); + + let (r1cs_sat_proof, rx, ry) = { + let (proof, rx, ry) = { + // we might need to pad variables + let padded_vars = { + let num_padded_vars = inst.inst.get_num_vars(); + let num_vars = vars.assignment.len(); + if num_padded_vars > num_vars { + vars.pad(num_padded_vars) + } else { + vars + } + }; + + R1CSProof::prove( + &inst.inst, + padded_vars.assignment, + &inputs.assignment, + &gens.gens_r1cs_sat, + transcript, + &mut random_tape, + ) + }; + + let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len())); + + (proof, rx, ry) + }; + + // We send evaluations of A, B, C at r = (rx, ry) as claims + // to enable the verifier complete the first sum-check + let timer_eval = Timer::new("eval_sparse_polys"); + let inst_evals = { + let (Ar, Br, Cr) = inst.inst.evaluate(&rx, &ry); + Ar.append_to_transcript(b"Ar_claim", transcript); + Br.append_to_transcript(b"Br_claim", transcript); + Cr.append_to_transcript(b"Cr_claim", transcript); + (Ar, Br, Cr) + }; + timer_eval.stop(); + + let r1cs_eval_proof = { + let proof = R1CSEvalProof::prove( + &decomm.decomm, + &rx, + &ry, + &inst_evals, + &gens.gens_r1cs_eval, + transcript, + &mut random_tape, + ); + + let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + Timer::print(&format!("len_r1cs_eval_proof {:?}", proof_encoded.len())); + proof + }; + + timer_prove.stop(); + SNARK { + r1cs_sat_proof, + inst_evals, + r1cs_eval_proof, + } + } + + /// A method to verify the SNARK proof of the satisfiability of an R1CS instance + pub fn verify( + &self, + comm: &ComputationCommitment, + input: &InputsAssignment, + transcript: &mut Transcript, + gens: &SNARKGens, + ) -> Result<(), ProofVerifyError> { + let timer_verify = Timer::new("SNARK::verify"); + transcript.append_protocol_name(SNARK::protocol_name()); + + // append a commitment to the computation to the transcript + comm.comm.append_to_transcript(b"comm", transcript); + + let timer_sat_proof = Timer::new("verify_sat_proof"); + assert_eq!(input.assignment.len(), comm.comm.get_num_inputs()); + let 
(rx, ry) = self.r1cs_sat_proof.verify( + comm.comm.get_num_vars(), + comm.comm.get_num_cons(), + &input.assignment, + &self.inst_evals, + transcript, + &gens.gens_r1cs_sat, + )?; + timer_sat_proof.stop(); + + let timer_eval_proof = Timer::new("verify_eval_proof"); + let (Ar, Br, Cr) = &self.inst_evals; + Ar.append_to_transcript(b"Ar_claim", transcript); + Br.append_to_transcript(b"Br_claim", transcript); + Cr.append_to_transcript(b"Cr_claim", transcript); + self.r1cs_eval_proof.verify( + &comm.comm, + &rx, + &ry, + &self.inst_evals, + &gens.gens_r1cs_eval, + transcript, + )?; + timer_eval_proof.stop(); + timer_verify.stop(); + Ok(()) + } +} + +/// `NIZKGens` holds public parameters for producing and verifying proofs with the Spartan NIZK +pub struct NIZKGens { + gens_r1cs_sat: R1CSGens, +} + +impl NIZKGens { + /// Constructs a new `NIZKGens` given the size of the R1CS statement + pub fn new(num_cons: usize, num_vars: usize, num_inputs: usize) -> Self { + let num_vars_padded = { + let mut num_vars_padded = max(num_vars, num_inputs + 1); + if num_vars_padded != num_vars_padded.next_power_of_two() { + num_vars_padded = num_vars_padded.next_power_of_two(); + } + num_vars_padded + }; + + let gens_r1cs_sat = R1CSGens::new(b"gens_r1cs_sat", num_cons, num_vars_padded); + NIZKGens { gens_r1cs_sat } + } +} + +/// `NIZK` holds a proof produced by Spartan NIZK +#[derive(Serialize, Deserialize, Debug)] +pub struct NIZK { + r1cs_sat_proof: R1CSProof, + r: (Vec, Vec), +} + +impl NIZK { + fn protocol_name() -> &'static [u8] { + b"Spartan NIZK proof" + } + + /// A method to produce a NIZK proof of the satisfiability of an R1CS instance + pub fn prove( + inst: &Instance, + vars: VarsAssignment, + input: &InputsAssignment, + gens: &NIZKGens, + transcript: &mut Transcript, + ) -> Self { + let timer_prove = Timer::new("NIZK::prove"); + // we create a Transcript object seeded with a random Scalar + // to aid the prover produce its randomness + let mut random_tape = RandomTape::new(b"proof"); + + transcript.append_protocol_name(NIZK::protocol_name()); + transcript.append_message(b"R1CSInstanceDigest", &inst.digest); + + let (r1cs_sat_proof, rx, ry) = { + // we might need to pad variables + let padded_vars = { + let num_padded_vars = inst.inst.get_num_vars(); + let num_vars = vars.assignment.len(); + if num_padded_vars > num_vars { + vars.pad(num_padded_vars) + } else { + vars + } + }; + + let (proof, rx, ry) = R1CSProof::prove( + &inst.inst, + padded_vars.assignment, + &input.assignment, + &gens.gens_r1cs_sat, + transcript, + &mut random_tape, + ); + let proof_encoded: Vec = bincode::serialize(&proof).unwrap(); + Timer::print(&format!("len_r1cs_sat_proof {:?}", proof_encoded.len())); + (proof, rx, ry) + }; + + timer_prove.stop(); + NIZK { + r1cs_sat_proof, + r: (rx, ry), + } + } + + /// A method to verify a NIZK proof of the satisfiability of an R1CS instance + pub fn verify( + &self, + inst: &Instance, + input: &InputsAssignment, + transcript: &mut Transcript, + gens: &NIZKGens, + ) -> Result<(), ProofVerifyError> { + let timer_verify = Timer::new("NIZK::verify"); + + transcript.append_protocol_name(NIZK::protocol_name()); + transcript.append_message(b"R1CSInstanceDigest", &inst.digest); + + // We send evaluations of A, B, C at r = (rx, ry) as claims + // to enable the verifier complete the first sum-check + let timer_eval = Timer::new("eval_sparse_polys"); + let (claimed_rx, claimed_ry) = &self.r; + let inst_evals = inst.inst.evaluate(claimed_rx, claimed_ry); + timer_eval.stop(); + + let timer_sat_proof = 
Timer::new("verify_sat_proof"); + assert_eq!(input.assignment.len(), inst.inst.get_num_inputs()); + let (rx, ry) = self.r1cs_sat_proof.verify( + inst.inst.get_num_vars(), + inst.inst.get_num_cons(), + &input.assignment, + &inst_evals, + transcript, + &gens.gens_r1cs_sat, + )?; + + // verify if claimed rx and ry are correct + assert_eq!(rx, *claimed_rx); + assert_eq!(ry, *claimed_ry); + timer_sat_proof.stop(); + timer_verify.stop(); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn check_snark() { + let num_vars = 256; + let num_cons = num_vars; + let num_inputs = 10; + + // produce public generators + let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_cons); + + // produce a synthetic R1CSInstance + let (inst, vars, inputs) = Instance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + // create a commitment to R1CSInstance + let (comm, decomm) = SNARK::encode(&inst, &gens); + + // produce a proof + let mut prover_transcript = Transcript::new(b"example"); + let proof = SNARK::prove( + &inst, + &comm, + &decomm, + vars, + &inputs, + &gens, + &mut prover_transcript, + ); + + // verify the proof + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(&comm, &inputs, &mut verifier_transcript, &gens) + .is_ok()); + } + + #[test] + pub fn check_r1cs_invalid_index() { + let num_cons = 4; + let num_vars = 8; + let num_inputs = 1; + + let zero: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, + ]; + + let A = vec![(0, 0, zero)]; + let B = vec![(100, 1, zero)]; + let C = vec![(1, 1, zero)]; + + let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C); + assert!(inst.is_err()); + assert_eq!(inst.err(), Some(R1CSError::InvalidIndex)); + } + + #[test] + pub fn check_r1cs_invalid_scalar() { + let num_cons = 4; + let num_vars = 8; + let num_inputs = 1; + + let zero: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, + ]; + + let larger_than_mod = [255; 32]; + + let A = vec![(0, 0, zero)]; + let B = vec![(1, 1, larger_than_mod)]; + let C = vec![(1, 1, zero)]; + + let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C); + assert!(inst.is_err()); + assert_eq!(inst.err(), Some(R1CSError::InvalidScalar)); + } + + #[test] + fn test_padded_constraints() { + // parameters of the R1CS instance + let num_cons = 1; + let num_vars = 0; + let num_inputs = 3; + let num_non_zero_entries = 3; + + // We will encode the above constraints into three matrices, where + // the coefficients in the matrix are in the little-endian byte order + let mut A: Vec<(usize, usize, [u8; 32])> = Vec::new(); + let mut B: Vec<(usize, usize, [u8; 32])> = Vec::new(); + let mut C: Vec<(usize, usize, [u8; 32])> = Vec::new(); + + // Create a^2 + b + 13 + A.push((0, num_vars + 2, Scalar::one().to_bytes())); // 1*a + B.push((0, num_vars + 2, Scalar::one().to_bytes())); // 1*a + C.push((0, num_vars + 1, Scalar::one().to_bytes())); // 1*z + C.push((0, num_vars, (-Scalar::from(13u64)).to_bytes())); // -13*1 + C.push((0, num_vars + 3, (-Scalar::one()).to_bytes())); // -1*b + + // Var Assignments (Z_0 = 16 is the only output) + let vars = vec![Scalar::zero().to_bytes(); num_vars]; + + // create an InputsAssignment (a = 1, b = 2) + let mut inputs = vec![Scalar::zero().to_bytes(); num_inputs]; + inputs[0] = Scalar::from(16u64).to_bytes(); + inputs[1] = Scalar::from(1u64).to_bytes(); + inputs[2] = 
Scalar::from(2u64).to_bytes();
+
+    let assignment_inputs = InputsAssignment::new(&inputs).unwrap();
+    let assignment_vars = VarsAssignment::new(&vars).unwrap();
+
+    // check that the instance is satisfiable
+    let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C).unwrap();
+    let res = inst.is_sat(&assignment_vars, &assignment_inputs);
+    assert!(res.unwrap(), "should be satisfied");
+
+    // SNARK public params
+    let gens = SNARKGens::new(num_cons, num_vars, num_inputs, num_non_zero_entries);
+
+    // create a commitment to the R1CS instance
+    let (comm, decomm) = SNARK::encode(&inst, &gens);
+
+    // produce a SNARK
+    let mut prover_transcript = Transcript::new(b"snark_example");
+    let proof = SNARK::prove(
+      &inst,
+      &comm,
+      &decomm,
+      assignment_vars.clone(),
+      &assignment_inputs,
+      &gens,
+      &mut prover_transcript,
+    );
+
+    // verify the SNARK
+    let mut verifier_transcript = Transcript::new(b"snark_example");
+    assert!(proof
+      .verify(&comm, &assignment_inputs, &mut verifier_transcript, &gens)
+      .is_ok());
+
+    // NIZK public params
+    let gens = NIZKGens::new(num_cons, num_vars, num_inputs);
+
+    // produce a NIZK
+    let mut prover_transcript = Transcript::new(b"nizk_example");
+    let proof = NIZK::prove(
+      &inst,
+      assignment_vars,
+      &assignment_inputs,
+      &gens,
+      &mut prover_transcript,
+    );
+
+    // verify the NIZK
+    let mut verifier_transcript = Transcript::new(b"nizk_example");
+    assert!(proof
+      .verify(&inst, &assignment_inputs, &mut verifier_transcript, &gens)
+      .is_ok());
+  }
+}
diff --git a/packages/Spartan-secq/src/math.rs b/packages/Spartan-secq/src/math.rs
new file mode 100644
index 0000000..33e9e14
--- /dev/null
+++ b/packages/Spartan-secq/src/math.rs
@@ -0,0 +1,36 @@
+pub trait Math {
+  fn square_root(self) -> usize;
+  fn pow2(self) -> usize;
+  fn get_bits(self, num_bits: usize) -> Vec<bool>;
+  fn log_2(self) -> usize;
+}
+
+impl Math for usize {
+  #[inline]
+  fn square_root(self) -> usize {
+    (self as f64).sqrt() as usize
+  }
+
+  #[inline]
+  fn pow2(self) -> usize {
+    let base: usize = 2;
+    base.pow(self as u32)
+  }
+
+  /// Returns the low num_bits bits of `self`, most significant bit first
+  fn get_bits(self, num_bits: usize) -> Vec<bool> {
+    (0..num_bits)
+      .map(|shift_amount| ((self & (1 << (num_bits - shift_amount - 1))) > 0))
+      .collect::<Vec<bool>>()
+  }
+
+  fn log_2(self) -> usize {
+    assert_ne!(self, 0);
+
+    if self.is_power_of_two() {
+      (1usize.leading_zeros() - self.leading_zeros()) as usize
+    } else {
+      (0usize.leading_zeros() - self.leading_zeros()) as usize
+    }
+  }
+}
diff --git a/packages/Spartan-secq/src/nizk/bullet.rs b/packages/Spartan-secq/src/nizk/bullet.rs
new file mode 100644
index 0000000..2767a21
--- /dev/null
+++ b/packages/Spartan-secq/src/nizk/bullet.rs
@@ -0,0 +1,267 @@
+//! This module is an adaptation of code from the bulletproofs crate.
+//! See NOTICE.md for more details
+#![allow(non_snake_case)]
+#![allow(clippy::type_complexity)]
+#![allow(clippy::too_many_arguments)]
+use super::super::errors::ProofVerifyError;
+use super::super::group::{CompressedGroup, GroupElement, VartimeMultiscalarMul};
+use super::super::math::Math;
+use super::super::scalar::Scalar;
+use super::super::transcript::ProofTranscript;
+use crate::group::DecompressEncodedPoint;
+use core::iter;
+use merlin::Transcript;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BulletReductionProof {
+  L_vec: Vec<CompressedGroup>,
+  R_vec: Vec<CompressedGroup>,
+}
+
+impl BulletReductionProof {
+  /// Create an inner-product proof.
+  ///
+  /// The proof is created with respect to the bases \\(G\\).
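+  /// (For \\(n = 2^k\\) input vectors the reduction runs \\(k\\) rounds and emits one
+  /// \\((L, R)\\) pair per round, so the proof size is logarithmic in \\(n\\).)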
+ /// + /// The `transcript` is passed in as a parameter so that the + /// challenges depend on the *entire* transcript (including parent + /// protocols). + /// + /// The lengths of the vectors must all be the same, and must all be + /// either 0 or a power of 2. + pub fn prove( + transcript: &mut Transcript, + Q: &GroupElement, + G_vec: &[GroupElement], + H: &GroupElement, + a_vec: &[Scalar], + b_vec: &[Scalar], + blind: &Scalar, + blinds_vec: &[(Scalar, Scalar)], + ) -> ( + BulletReductionProof, + GroupElement, + Scalar, + Scalar, + GroupElement, + Scalar, + ) { + // Create slices G, H, a, b backed by their respective + // vectors. This lets us reslice as we compress the lengths + // of the vectors in the main loop below. + let mut G = &mut G_vec.to_owned()[..]; + let mut a = &mut a_vec.to_owned()[..]; + let mut b = &mut b_vec.to_owned()[..]; + + // All of the input vectors must have a length that is a power of two. + let mut n = G.len(); + assert!(n.is_power_of_two()); + let lg_n = n.log_2(); + + // All of the input vectors must have the same length. + assert_eq!(G.len(), n); + assert_eq!(a.len(), n); + assert_eq!(b.len(), n); + assert_eq!(blinds_vec.len(), 2 * lg_n); + + let mut L_vec = Vec::with_capacity(lg_n); + let mut R_vec = Vec::with_capacity(lg_n); + let mut blinds_iter = blinds_vec.iter(); + let mut blind_fin = *blind; + + while n != 1 { + n /= 2; + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + + let c_L = inner_product(a_L, b_R); + let c_R = inner_product(a_R, b_L); + + let (blind_L, blind_R) = blinds_iter.next().unwrap(); + + let L = GroupElement::vartime_multiscalar_mul( + a_L + .iter() + .chain(iter::once(&c_L)) + .chain(iter::once(blind_L)) + .map(|s| *s) + .collect(), + G_R + .iter() + .chain(iter::once(Q)) + .chain(iter::once(H)) + .map(|s| *s) + .collect(), + ); + + let R = GroupElement::vartime_multiscalar_mul( + a_R + .iter() + .chain(iter::once(&c_R)) + .chain(iter::once(blind_R)) + .map(|s| *s) + .collect(), + G_L + .iter() + .chain(iter::once(Q)) + .chain(iter::once(H)) + .map(|s| *s) + .collect(), + ); + + transcript.append_point(b"L", &L.compress()); + transcript.append_point(b"R", &R.compress()); + + let u = transcript.challenge_scalar(b"u"); + let u_inv = u.invert().unwrap(); + + for i in 0..n { + a_L[i] = a_L[i] * u + u_inv * a_R[i]; + b_L[i] = b_L[i] * u_inv + u * b_R[i]; + G_L[i] = + GroupElement::vartime_multiscalar_mul([u_inv, u].to_vec(), [G_L[i], G_R[i]].to_vec()); + } + + blind_fin = blind_fin + blind_L * u * u + blind_R * u_inv * u_inv; + + L_vec.push(L.compress()); + R_vec.push(R.compress()); + + a = a_L; + b = b_L; + G = G_L; + } + + let Gamma_hat = GroupElement::vartime_multiscalar_mul( + [a[0], a[0] * b[0], blind_fin].to_vec(), + [G[0], *Q, *H].to_vec(), + ); + + ( + BulletReductionProof { L_vec, R_vec }, + Gamma_hat, + a[0], + b[0], + G[0], + blind_fin, + ) + } + + /// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and \\([s\_{i}]\\) for combined multiscalar multiplication + /// in a parent protocol. See [inner product protocol notes](index.html#verification-equation) for details. + /// The verifier must provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner product proof. 
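+  ///
+  /// Concretely, for \\(n = 2^k\\) with round challenges \\(u\_1, \ldots, u\_k\\), the scalars
+  /// are \\(s\_i = \prod\_{j=1}^{k} u\_j^{b(i,j)}\\), where \\(b(i,j) = +1\\) if the \\(j\\)-th
+  /// bit of \\(i\\) (most significant first) is 1 and \\(-1\\) otherwise; this is the standard
+  /// bulletproofs folding identity.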
+  fn verification_scalars(
+    &self,
+    n: usize,
+    transcript: &mut Transcript,
+  ) -> Result<(Vec<Scalar>, Vec<Scalar>, Vec<Scalar>), ProofVerifyError> {
+    let lg_n = self.L_vec.len();
+    if lg_n >= 32 {
+      // 4 billion multiplications should be enough for anyone
+      // and this check prevents overflow in 1 << lg_n below
+      return Err(ProofVerifyError::InternalError);
+    }
+    if n != (1 << lg_n) {
+      return Err(ProofVerifyError::InternalError);
+    }
+
+    // recompute the round challenges u_1, ..., u_k from the proof transcript
+    let mut challenges = Vec::with_capacity(lg_n);
+    for (L, R) in self.L_vec.iter().zip(self.R_vec.iter()) {
+      transcript.append_point(b"L", L);
+      transcript.append_point(b"R", R);
+      challenges.push(transcript.challenge_scalar(b"u"));
+    }
+
+    // compute the inverses 1/u_1, ..., 1/u_k and their product
+    let mut challenges_inv = challenges.clone();
+    for u in challenges_inv.iter_mut() {
+      *u = u.invert().unwrap();
+    }
+    let allinv: Scalar = challenges_inv.iter().copied().product();
+
+    // square the challenges and their inverses
+    for i in 0..lg_n {
+      challenges[i] = challenges[i] * challenges[i];
+      challenges_inv[i] = challenges_inv[i] * challenges_inv[i];
+    }
+    let challenges_sq = challenges;
+    let challenges_inv_sq = challenges_inv;
+
+    // compute the s values inductively: s_0 = prod_j (1/u_j), and s_i reuses
+    // s_{i - 2^lg(i)} with one challenge flipped from u^{-1} to u
+    let mut s = Vec::with_capacity(n);
+    s.push(allinv);
+    for i in 1..n {
+      let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
+      let k = 1 << lg_i;
+      // the challenges are stored in "creation order" as [u_k, ..., u_1],
+      // so u_{lg(i)+1} is indexed by (lg_n - 1) - lg_i
+      let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
+      s.push(s[i - k] * u_lg_i_sq);
+    }
+
+    Ok((challenges_sq, challenges_inv_sq, s))
+  }
+
+  pub fn verify(
+    &self,
+    n: usize,
+    a: &[Scalar],
+    transcript: &mut Transcript,
+    Gamma: &GroupElement,
+    G: &[GroupElement],
+  ) -> Result<(GroupElement, GroupElement, Scalar), ProofVerifyError> {
+    let (u_sq, u_inv_sq, s) = self.verification_scalars(n, transcript)?;
+
+    let Ls = self
+      .L_vec
+      .iter()
+      .map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
+      .collect::<Result<Vec<GroupElement>, _>>()?;
+
+    let Rs = self
+      .R_vec
+      .iter()
+      .map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
+      .collect::<Result<Vec<GroupElement>, _>>()?;
+
+    let G_hat = GroupElement::vartime_multiscalar_mul(s.clone(), G.to_vec());
+    let a_hat = inner_product(a, &s);
+
+    let Gamma_hat = GroupElement::vartime_multiscalar_mul(
+      u_sq
+        .iter()
+        .chain(u_inv_sq.iter())
+        .chain(iter::once(&Scalar::one()))
+        .map(|s| *s)
+        .collect(),
+      Ls.iter()
+        .chain(Rs.iter())
+        .chain(iter::once(Gamma))
+        .map(|p| *p)
+        .collect(),
+    );
+
+    Ok((G_hat, Gamma_hat, a_hat))
+  }
+}
+
+/// Computes an inner product of two vectors
+/// \\[
+/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i.
+/// \\]
+/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
+pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
+  assert!(
+    a.len() == b.len(),
+    "inner_product(a,b): lengths of vectors do not match"
+  );
+  let mut out = Scalar::zero();
+  for i in 0..a.len() {
+    out += a[i] * b[i];
+  }
+  out
+}
diff --git a/packages/Spartan-secq/src/nizk/mod.rs b/packages/Spartan-secq/src/nizk/mod.rs
new file mode 100644
index 0000000..f0db69e
--- /dev/null
+++ b/packages/Spartan-secq/src/nizk/mod.rs
@@ -0,0 +1,735 @@
+#![allow(clippy::too_many_arguments)]
+use super::commitments::{Commitments, MultiCommitGens};
+use super::errors::ProofVerifyError;
+use super::group::{CompressedGroup, CompressedGroupExt};
+use super::math::Math;
+use super::random::RandomTape;
+use super::scalar::Scalar;
+use super::transcript::{AppendToTranscript, ProofTranscript};
+use crate::group::DecompressEncodedPoint;
+use merlin::Transcript;
+use serde::{Deserialize, Serialize};
+
+mod bullet;
+use bullet::BulletReductionProof;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct KnowledgeProof {
+  alpha: CompressedGroup,
+  z1: Scalar,
+  z2: Scalar,
+}
+
+impl KnowledgeProof {
+  fn protocol_name() -> &'static [u8] {
+    b"knowledge proof"
+  }
+
+  pub fn prove(
+    gens_n: &MultiCommitGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+    x: &Scalar,
+    r: &Scalar,
+  ) -> (KnowledgeProof, CompressedGroup) {
+    transcript.append_protocol_name(KnowledgeProof::protocol_name());
+
+    // produce two random Scalars
+    let t1 = random_tape.random_scalar(b"t1");
+    let t2 = random_tape.random_scalar(b"t2");
+
+    let C = x.commit(r, gens_n).compress();
+    C.append_to_transcript(b"C", transcript);
+
+    let alpha = t1.commit(&t2, gens_n).compress();
+    alpha.append_to_transcript(b"alpha", transcript);
+
+    let c = transcript.challenge_scalar(b"c");
+
+    let z1 = x * c + t1;
+    let z2 = r * c + t2;
+
+    (KnowledgeProof { alpha, z1, z2 }, C)
+  }
+
+  pub fn verify(
+    &self,
+    gens_n: &MultiCommitGens,
+    transcript: &mut Transcript,
+    C: &CompressedGroup,
+  ) -> Result<(), ProofVerifyError> {
+    transcript.append_protocol_name(KnowledgeProof::protocol_name());
+    C.append_to_transcript(b"C", transcript);
+    self.alpha.append_to_transcript(b"alpha", transcript);
+
+    let c =
transcript.challenge_scalar(b"c"); + + let lhs = self.z1.commit(&self.z2, gens_n).compress(); + let rhs = (c * C.unpack()? + self.alpha.unpack()?).compress(); + + if lhs == rhs { + Ok(()) + } else { + Err(ProofVerifyError::InternalError) + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct EqualityProof { + alpha: CompressedGroup, + z: Scalar, +} + +impl EqualityProof { + fn protocol_name() -> &'static [u8] { + b"equality proof" + } + + pub fn prove( + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + v1: &Scalar, + s1: &Scalar, + v2: &Scalar, + s2: &Scalar, + ) -> (EqualityProof, CompressedGroup, CompressedGroup) { + transcript.append_protocol_name(EqualityProof::protocol_name()); + + // produce a random Scalar + let r = random_tape.random_scalar(b"r"); + + let C1 = v1.commit(s1, gens_n).compress(); + C1.append_to_transcript(b"C1", transcript); + + let C2 = v2.commit(s2, gens_n).compress(); + C2.append_to_transcript(b"C2", transcript); + + let alpha = (r * gens_n.h).compress(); + alpha.append_to_transcript(b"alpha", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let z = c * (s1 - s2) + r; + + (EqualityProof { alpha, z }, C1, C2) + } + + pub fn verify( + &self, + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + C1: &CompressedGroup, + C2: &CompressedGroup, + ) -> Result<(), ProofVerifyError> { + transcript.append_protocol_name(EqualityProof::protocol_name()); + C1.append_to_transcript(b"C1", transcript); + C2.append_to_transcript(b"C2", transcript); + self.alpha.append_to_transcript(b"alpha", transcript); + + let c = transcript.challenge_scalar(b"c"); + let rhs = { + let C = C1.unpack()? - C2.unpack()?; + (c * C + self.alpha.unpack()?).compress() + }; + + let lhs = (self.z * gens_n.h).compress(); + + if lhs == rhs { + Ok(()) + } else { + Err(ProofVerifyError::InternalError) + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ProductProof { + alpha: CompressedGroup, + beta: CompressedGroup, + delta: CompressedGroup, + z: [Scalar; 5], +} + +impl ProductProof { + fn protocol_name() -> &'static [u8] { + b"product proof" + } + + pub fn prove( + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + x: &Scalar, + rX: &Scalar, + y: &Scalar, + rY: &Scalar, + z: &Scalar, + rZ: &Scalar, + ) -> ( + ProductProof, + CompressedGroup, + CompressedGroup, + CompressedGroup, + ) { + transcript.append_protocol_name(ProductProof::protocol_name()); + + // produce five random Scalar + let b1 = random_tape.random_scalar(b"b1"); + let b2 = random_tape.random_scalar(b"b2"); + let b3 = random_tape.random_scalar(b"b3"); + let b4 = random_tape.random_scalar(b"b4"); + let b5 = random_tape.random_scalar(b"b5"); + + let X = x.commit(rX, gens_n).compress(); + X.append_to_transcript(b"X", transcript); + + let Y = y.commit(rY, gens_n).compress(); + Y.append_to_transcript(b"Y", transcript); + + let Z = z.commit(rZ, gens_n).compress(); + Z.append_to_transcript(b"Z", transcript); + + let alpha = b1.commit(&b2, gens_n).compress(); + alpha.append_to_transcript(b"alpha", transcript); + + let beta = b3.commit(&b4, gens_n).compress(); + beta.append_to_transcript(b"beta", transcript); + + let delta = { + let gens_X = &MultiCommitGens { + n: 1, + G: vec![X.decompress().unwrap()], + h: gens_n.h, + }; + b3.commit(&b5, gens_X).compress() + }; + delta.append_to_transcript(b"delta", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let z1 = b1 + c * x; + let z2 = b2 + c * rX; + let z3 = b3 + c 
* y; + let z4 = b4 + c * rY; + let z5 = b5 + c * (rZ - rX * y); + let z = [z1, z2, z3, z4, z5]; + + ( + ProductProof { + alpha, + beta, + delta, + z, + }, + X, + Y, + Z, + ) + } + + fn check_equality( + P: &CompressedGroup, + X: &CompressedGroup, + c: &Scalar, + gens_n: &MultiCommitGens, + z1: &Scalar, + z2: &Scalar, + ) -> bool { + let lhs = (P.decompress().unwrap() + c * X.decompress().unwrap()).compress(); + let rhs = z1.commit(z2, gens_n).compress(); + + lhs == rhs + } + + pub fn verify( + &self, + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + X: &CompressedGroup, + Y: &CompressedGroup, + Z: &CompressedGroup, + ) -> Result<(), ProofVerifyError> { + transcript.append_protocol_name(ProductProof::protocol_name()); + + X.append_to_transcript(b"X", transcript); + Y.append_to_transcript(b"Y", transcript); + Z.append_to_transcript(b"Z", transcript); + self.alpha.append_to_transcript(b"alpha", transcript); + self.beta.append_to_transcript(b"beta", transcript); + self.delta.append_to_transcript(b"delta", transcript); + + let z1 = self.z[0]; + let z2 = self.z[1]; + let z3 = self.z[2]; + let z4 = self.z[3]; + let z5 = self.z[4]; + + let c = transcript.challenge_scalar(b"c"); + + if ProductProof::check_equality(&self.alpha, X, &c, gens_n, &z1, &z2) + && ProductProof::check_equality(&self.beta, Y, &c, gens_n, &z3, &z4) + && ProductProof::check_equality( + &self.delta, + Z, + &c, + &MultiCommitGens { + n: 1, + G: vec![X.unpack()?], + h: gens_n.h, + }, + &z3, + &z5, + ) + { + Ok(()) + } else { + Err(ProofVerifyError::InternalError) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DotProductProof { + delta: CompressedGroup, + beta: CompressedGroup, + z: Vec, + z_delta: Scalar, + z_beta: Scalar, +} + +impl DotProductProof { + fn protocol_name() -> &'static [u8] { + b"dot product proof" + } + + pub fn compute_dotproduct(a: &[Scalar], b: &[Scalar]) -> Scalar { + assert_eq!(a.len(), b.len()); + (0..a.len()).map(|i| a[i] * b[i]).sum() + } + + pub fn prove( + gens_1: &MultiCommitGens, + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + x_vec: &[Scalar], + blind_x: &Scalar, + a_vec: &[Scalar], + y: &Scalar, + blind_y: &Scalar, + ) -> (DotProductProof, CompressedGroup, CompressedGroup) { + transcript.append_protocol_name(DotProductProof::protocol_name()); + + let n = x_vec.len(); + assert_eq!(x_vec.len(), a_vec.len()); + assert_eq!(gens_n.n, a_vec.len()); + assert_eq!(gens_1.n, 1); + + // produce randomness for the proofs + let d_vec = random_tape.random_vector(b"d_vec", n); + let r_delta = random_tape.random_scalar(b"r_delta"); + let r_beta = random_tape.random_scalar(b"r_beta"); + + let Cx = x_vec.commit(blind_x, gens_n).compress(); + Cx.append_to_transcript(b"Cx", transcript); + + let Cy = y.commit(blind_y, gens_1).compress(); + Cy.append_to_transcript(b"Cy", transcript); + + a_vec.append_to_transcript(b"a", transcript); + + let delta = d_vec.commit(&r_delta, gens_n).compress(); + delta.append_to_transcript(b"delta", transcript); + + let dotproduct_a_d = DotProductProof::compute_dotproduct(a_vec, &d_vec); + + let beta = dotproduct_a_d.commit(&r_beta, gens_1).compress(); + beta.append_to_transcript(b"beta", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let z = (0..d_vec.len()) + .map(|i| c * x_vec[i] + d_vec[i]) + .collect::>(); + + let z_delta = c * blind_x + r_delta; + let z_beta = c * blind_y + r_beta; + + ( + DotProductProof { + delta, + beta, + z, + z_delta, + z_beta, + }, + Cx, + Cy, + ) + } + + pub fn verify( + 
&self, + gens_1: &MultiCommitGens, + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + a: &[Scalar], + Cx: &CompressedGroup, + Cy: &CompressedGroup, + ) -> Result<(), ProofVerifyError> { + assert_eq!(gens_n.n, a.len()); + assert_eq!(gens_1.n, 1); + + transcript.append_protocol_name(DotProductProof::protocol_name()); + Cx.append_to_transcript(b"Cx", transcript); + Cy.append_to_transcript(b"Cy", transcript); + a.append_to_transcript(b"a", transcript); + self.delta.append_to_transcript(b"delta", transcript); + self.beta.append_to_transcript(b"beta", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let mut result = + c * Cx.unpack()? + self.delta.unpack()? == self.z.commit(&self.z_delta, gens_n); + + let dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, a); + result &= c * Cy.unpack()? + self.beta.unpack()? == dotproduct_z_a.commit(&self.z_beta, gens_1); + + if result { + Ok(()) + } else { + Err(ProofVerifyError::InternalError) + } + } +} + +pub struct DotProductProofGens { + n: usize, + pub gens_n: MultiCommitGens, + pub gens_1: MultiCommitGens, +} + +impl DotProductProofGens { + pub fn new(n: usize, label: &[u8]) -> Self { + let (gens_n, gens_1) = MultiCommitGens::new(n + 1, label).split_at(n); + DotProductProofGens { n, gens_n, gens_1 } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DotProductProofLog { + bullet_reduction_proof: BulletReductionProof, + delta: CompressedGroup, + beta: CompressedGroup, + z1: Scalar, + z2: Scalar, +} + +impl DotProductProofLog { + fn protocol_name() -> &'static [u8] { + b"dot product proof (log)" + } + + pub fn compute_dotproduct(a: &[Scalar], b: &[Scalar]) -> Scalar { + assert_eq!(a.len(), b.len()); + (0..a.len()).map(|i| a[i] * b[i]).sum() + } + + pub fn prove( + gens: &DotProductProofGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + x_vec: &[Scalar], + blind_x: &Scalar, + a_vec: &[Scalar], + y: &Scalar, + blind_y: &Scalar, + ) -> (DotProductProofLog, CompressedGroup, CompressedGroup) { + transcript.append_protocol_name(DotProductProofLog::protocol_name()); + + let n = x_vec.len(); + assert_eq!(x_vec.len(), a_vec.len()); + assert_eq!(gens.n, n); + + // produce randomness for generating a proof + let d = random_tape.random_scalar(b"d"); + let r_delta = random_tape.random_scalar(b"r_delta"); + let r_beta = random_tape.random_scalar(b"r_delta"); + let blinds_vec = { + let v1 = random_tape.random_vector(b"blinds_vec_1", 2 * n.log_2()); + let v2 = random_tape.random_vector(b"blinds_vec_2", 2 * n.log_2()); + (0..v1.len()) + .map(|i| (v1[i], v2[i])) + .collect::>() + }; + + let Cx = x_vec.commit(blind_x, &gens.gens_n).compress(); + Cx.append_to_transcript(b"Cx", transcript); + + let Cy = y.commit(blind_y, &gens.gens_1).compress(); + Cy.append_to_transcript(b"Cy", transcript); + + a_vec.append_to_transcript(b"a", transcript); + + // sample a random base and scale the generator used for + // the output of the inner product + let r = transcript.challenge_scalar(b"r"); + let gens_1_scaled = gens.gens_1.scale(&r); + + let blind_Gamma = blind_x + r * blind_y; + let (bullet_reduction_proof, _Gamma_hat, x_hat, a_hat, g_hat, rhat_Gamma) = + BulletReductionProof::prove( + transcript, + &gens_1_scaled.G[0], + &gens.gens_n.G, + &gens.gens_n.h, + x_vec, + a_vec, + &blind_Gamma, + &blinds_vec, + ); + let y_hat = x_hat * a_hat; + + let delta = { + let gens_hat = MultiCommitGens { + n: 1, + G: vec![g_hat], + h: gens.gens_1.h, + }; + d.commit(&r_delta, &gens_hat).compress() + }; + 
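+    // delta commits to the masking scalar d under the folded base g_hat, and
+    // beta (below) commits to the same d under the scaled output base, so a
+    // single challenge c binds both the folded witness y_hat and its blinds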
delta.append_to_transcript(b"delta", transcript); + + let beta = d.commit(&r_beta, &gens_1_scaled).compress(); + beta.append_to_transcript(b"beta", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let z1 = d + c * y_hat; + let z2 = a_hat * (c * rhat_Gamma + r_beta) + r_delta; + + ( + DotProductProofLog { + bullet_reduction_proof, + delta, + beta, + z1, + z2, + }, + Cx, + Cy, + ) + } + + pub fn verify( + &self, + n: usize, + gens: &DotProductProofGens, + transcript: &mut Transcript, + a: &[Scalar], + Cx: &CompressedGroup, + Cy: &CompressedGroup, + ) -> Result<(), ProofVerifyError> { + assert_eq!(gens.n, n); + assert_eq!(a.len(), n); + + transcript.append_protocol_name(DotProductProofLog::protocol_name()); + Cx.append_to_transcript(b"Cx", transcript); + Cy.append_to_transcript(b"Cy", transcript); + a.append_to_transcript(b"a", transcript); + + // sample a random base and scale the generator used for + // the output of the inner product + let r = transcript.challenge_scalar(b"r"); + let gens_1_scaled = gens.gens_1.scale(&r); + + let Gamma = Cx.unpack()? + r * Cy.unpack()?; + + let (g_hat, Gamma_hat, a_hat) = + self + .bullet_reduction_proof + .verify(n, a, transcript, &Gamma, &gens.gens_n.G)?; + self.delta.append_to_transcript(b"delta", transcript); + self.beta.append_to_transcript(b"beta", transcript); + + let c = transcript.challenge_scalar(b"c"); + + let c_s = &c; + let beta_s = self.beta.unpack()?; + let a_hat_s = &a_hat; + let delta_s = self.delta.unpack()?; + let z1_s = &self.z1; + let z2_s = &self.z2; + + let lhs = ((Gamma_hat * c_s + beta_s) * a_hat_s + delta_s).compress(); + let rhs = ((g_hat + gens_1_scaled.G[0] * a_hat_s) * z1_s + gens_1_scaled.h * z2_s).compress(); + + assert_eq!(lhs, rhs); + + if lhs == rhs { + Ok(()) + } else { + Err(ProofVerifyError::InternalError) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_core::OsRng; + #[test] + fn check_knowledgeproof() { + let mut csprng: OsRng = OsRng; + + let gens_1 = MultiCommitGens::new(1, b"test-knowledgeproof"); + + let x = Scalar::random(&mut csprng); + let r = Scalar::random(&mut csprng); + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, committed_value) = + KnowledgeProof::prove(&gens_1, &mut prover_transcript, &mut random_tape, &x, &r); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(&gens_1, &mut verifier_transcript, &committed_value) + .is_ok()); + } + + #[test] + fn check_equalityproof() { + let mut csprng: OsRng = OsRng; + + let gens_1 = MultiCommitGens::new(1, b"test-equalityproof"); + let v1 = Scalar::random(&mut csprng); + let v2 = v1; + let s1 = Scalar::random(&mut csprng); + let s2 = Scalar::random(&mut csprng); + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, C1, C2) = EqualityProof::prove( + &gens_1, + &mut prover_transcript, + &mut random_tape, + &v1, + &s1, + &v2, + &s2, + ); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(&gens_1, &mut verifier_transcript, &C1, &C2) + .is_ok()); + } + + #[test] + fn check_productproof() { + let mut csprng: OsRng = OsRng; + + let gens_1 = MultiCommitGens::new(1, b"test-productproof"); + let x = Scalar::random(&mut csprng); + let rX = Scalar::random(&mut csprng); + let y = Scalar::random(&mut csprng); + let rY = Scalar::random(&mut csprng); + let z = x * y; + let rZ = Scalar::random(&mut csprng); + + let mut 
random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, X, Y, Z) = ProductProof::prove( + &gens_1, + &mut prover_transcript, + &mut random_tape, + &x, + &rX, + &y, + &rY, + &z, + &rZ, + ); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(&gens_1, &mut verifier_transcript, &X, &Y, &Z) + .is_ok()); + } + + #[test] + fn check_dotproductproof() { + let mut csprng: OsRng = OsRng; + + let n = 1024; + + let gens_1 = MultiCommitGens::new(1, b"test-two"); + let gens_1024 = MultiCommitGens::new(n, b"test-1024"); + + let mut x: Vec = Vec::new(); + let mut a: Vec = Vec::new(); + for _ in 0..n { + x.push(Scalar::random(&mut csprng)); + a.push(Scalar::random(&mut csprng)); + } + let y = DotProductProofLog::compute_dotproduct(&x, &a); + let r_x = Scalar::random(&mut csprng); + let r_y = Scalar::random(&mut csprng); + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, Cx, Cy) = DotProductProof::prove( + &gens_1, + &gens_1024, + &mut prover_transcript, + &mut random_tape, + &x, + &r_x, + &a, + &y, + &r_y, + ); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(&gens_1, &gens_1024, &mut verifier_transcript, &a, &Cx, &Cy) + .is_ok()); + } + + #[test] + fn check_dotproductproof_log() { + let mut csprng: OsRng = OsRng; + + let n = 1024; + + let gens = DotProductProofGens::new(n, b"test-1024"); + + let x: Vec = (0..n).map(|_i| Scalar::random(&mut csprng)).collect(); + let a: Vec = (0..n).map(|_i| Scalar::random(&mut csprng)).collect(); + let y = DotProductProof::compute_dotproduct(&x, &a); + + let r_x = Scalar::random(&mut csprng); + let r_y = Scalar::random(&mut csprng); + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, Cx, Cy) = DotProductProofLog::prove( + &gens, + &mut prover_transcript, + &mut random_tape, + &x, + &r_x, + &a, + &y, + &r_y, + ); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify(n, &gens, &mut verifier_transcript, &a, &Cx, &Cy) + .is_ok()); + } +} diff --git a/packages/Spartan-secq/src/product_tree.rs b/packages/Spartan-secq/src/product_tree.rs new file mode 100644 index 0000000..6e2f932 --- /dev/null +++ b/packages/Spartan-secq/src/product_tree.rs @@ -0,0 +1,486 @@ +#![allow(dead_code)] +use super::dense_mlpoly::DensePolynomial; +use super::dense_mlpoly::EqPolynomial; +use super::math::Math; +use super::scalar::Scalar; +use super::sumcheck::SumcheckInstanceProof; +use super::transcript::ProofTranscript; +use merlin::Transcript; +use serde::{Deserialize, Serialize}; + +#[derive(Debug)] +pub struct ProductCircuit { + left_vec: Vec, + right_vec: Vec, +} + +impl ProductCircuit { + fn compute_layer( + inp_left: &DensePolynomial, + inp_right: &DensePolynomial, + ) -> (DensePolynomial, DensePolynomial) { + let len = inp_left.len() + inp_right.len(); + let outp_left = (0..len / 4) + .map(|i| inp_left[i] * inp_right[i]) + .collect::>(); + let outp_right = (len / 4..len / 2) + .map(|i| inp_left[i] * inp_right[i]) + .collect::>(); + + ( + DensePolynomial::new(outp_left), + DensePolynomial::new(outp_right), + ) + } + + pub fn new(poly: &DensePolynomial) -> Self { + let mut left_vec: Vec = Vec::new(); + let mut right_vec: Vec = Vec::new(); + + let num_layers = poly.len().log_2(); + let (outp_left, outp_right) = poly.split(poly.len() / 2); + + left_vec.push(outp_left); + 
right_vec.push(outp_right); + + for i in 0..num_layers - 1 { + let (outp_left, outp_right) = ProductCircuit::compute_layer(&left_vec[i], &right_vec[i]); + left_vec.push(outp_left); + right_vec.push(outp_right); + } + + ProductCircuit { + left_vec, + right_vec, + } + } + + pub fn evaluate(&self) -> Scalar { + let len = self.left_vec.len(); + assert_eq!(self.left_vec[len - 1].get_num_vars(), 0); + assert_eq!(self.right_vec[len - 1].get_num_vars(), 0); + self.left_vec[len - 1][0] * self.right_vec[len - 1][0] + } +} + +pub struct DotProductCircuit { + left: DensePolynomial, + right: DensePolynomial, + weight: DensePolynomial, +} + +impl DotProductCircuit { + pub fn new(left: DensePolynomial, right: DensePolynomial, weight: DensePolynomial) -> Self { + assert_eq!(left.len(), right.len()); + assert_eq!(left.len(), weight.len()); + DotProductCircuit { + left, + right, + weight, + } + } + + pub fn evaluate(&self) -> Scalar { + (0..self.left.len()) + .map(|i| self.left[i] * self.right[i] * self.weight[i]) + .sum() + } + + pub fn split(&mut self) -> (DotProductCircuit, DotProductCircuit) { + let idx = self.left.len() / 2; + assert_eq!(idx * 2, self.left.len()); + let (l1, l2) = self.left.split(idx); + let (r1, r2) = self.right.split(idx); + let (w1, w2) = self.weight.split(idx); + ( + DotProductCircuit { + left: l1, + right: r1, + weight: w1, + }, + DotProductCircuit { + left: l2, + right: r2, + weight: w2, + }, + ) + } +} + +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct LayerProof { + pub proof: SumcheckInstanceProof, + pub claims: Vec, +} + +#[allow(dead_code)] +impl LayerProof { + pub fn verify( + &self, + claim: Scalar, + num_rounds: usize, + degree_bound: usize, + transcript: &mut Transcript, + ) -> (Scalar, Vec) { + self + .proof + .verify(claim, num_rounds, degree_bound, transcript) + .unwrap() + } +} + +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct LayerProofBatched { + pub proof: SumcheckInstanceProof, + pub claims_prod_left: Vec, + pub claims_prod_right: Vec, +} + +#[allow(dead_code)] +impl LayerProofBatched { + pub fn verify( + &self, + claim: Scalar, + num_rounds: usize, + degree_bound: usize, + transcript: &mut Transcript, + ) -> (Scalar, Vec) { + self + .proof + .verify(claim, num_rounds, degree_bound, transcript) + .unwrap() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProductCircuitEvalProof { + proof: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProductCircuitEvalProofBatched { + proof: Vec, + claims_dotp: (Vec, Vec, Vec), +} + +impl ProductCircuitEvalProof { + #![allow(dead_code)] + pub fn prove( + circuit: &mut ProductCircuit, + transcript: &mut Transcript, + ) -> (Self, Scalar, Vec) { + let mut proof: Vec = Vec::new(); + let num_layers = circuit.left_vec.len(); + + let mut claim = circuit.evaluate(); + let mut rand = Vec::new(); + for layer_id in (0..num_layers).rev() { + let len = circuit.left_vec[layer_id].len() + circuit.right_vec[layer_id].len(); + + let mut poly_C = DensePolynomial::new(EqPolynomial::new(rand.clone()).evals()); + assert_eq!(poly_C.len(), len / 2); + + let num_rounds_prod = poly_C.len().log_2(); + let comb_func_prod = |poly_A_comp: &Scalar, + poly_B_comp: &Scalar, + poly_C_comp: &Scalar| + -> Scalar { poly_A_comp * poly_B_comp * poly_C_comp }; + let (proof_prod, rand_prod, claims_prod) = SumcheckInstanceProof::prove_cubic( + &claim, + num_rounds_prod, + &mut circuit.left_vec[layer_id], + &mut circuit.right_vec[layer_id], + &mut poly_C, + comb_func_prod, + transcript, 
+ ); + + transcript.append_scalar(b"claim_prod_left", &claims_prod[0]); + transcript.append_scalar(b"claim_prod_right", &claims_prod[1]); + + // produce a random challenge + let r_layer = transcript.challenge_scalar(b"challenge_r_layer"); + claim = claims_prod[0] + r_layer * (claims_prod[1] - claims_prod[0]); + + let mut ext = vec![r_layer]; + ext.extend(rand_prod); + rand = ext; + + proof.push(LayerProof { + proof: proof_prod, + claims: claims_prod[0..claims_prod.len() - 1].to_vec(), + }); + } + + (ProductCircuitEvalProof { proof }, claim, rand) + } + + pub fn verify( + &self, + eval: Scalar, + len: usize, + transcript: &mut Transcript, + ) -> (Scalar, Vec) { + let num_layers = len.log_2(); + let mut claim = eval; + let mut rand: Vec = Vec::new(); + //let mut num_rounds = 0; + assert_eq!(self.proof.len(), num_layers); + for (num_rounds, i) in (0..num_layers).enumerate() { + let (claim_last, rand_prod) = self.proof[i].verify(claim, num_rounds, 3, transcript); + + let claims_prod = &self.proof[i].claims; + transcript.append_scalar(b"claim_prod_left", &claims_prod[0]); + transcript.append_scalar(b"claim_prod_right", &claims_prod[1]); + + assert_eq!(rand.len(), rand_prod.len()); + let eq: Scalar = (0..rand.len()) + .map(|i| { + rand[i] * rand_prod[i] + (Scalar::one() - rand[i]) * (Scalar::one() - rand_prod[i]) + }) + .product(); + assert_eq!(claims_prod[0] * claims_prod[1] * eq, claim_last); + + // produce a random challenge + let r_layer = transcript.challenge_scalar(b"challenge_r_layer"); + claim = (Scalar::one() - r_layer) * claims_prod[0] + r_layer * claims_prod[1]; + let mut ext = vec![r_layer]; + ext.extend(rand_prod); + rand = ext; + } + + (claim, rand) + } +} + +impl ProductCircuitEvalProofBatched { + pub fn prove( + prod_circuit_vec: &mut Vec<&mut ProductCircuit>, + dotp_circuit_vec: &mut Vec<&mut DotProductCircuit>, + transcript: &mut Transcript, + ) -> (Self, Vec) { + assert!(!prod_circuit_vec.is_empty()); + + let mut claims_dotp_final = (Vec::new(), Vec::new(), Vec::new()); + + let mut proof_layers: Vec = Vec::new(); + let num_layers = prod_circuit_vec[0].left_vec.len(); + let mut claims_to_verify = (0..prod_circuit_vec.len()) + .map(|i| prod_circuit_vec[i].evaluate()) + .collect::>(); + let mut rand = Vec::new(); + for layer_id in (0..num_layers).rev() { + // prepare paralell instance that share poly_C first + let len = prod_circuit_vec[0].left_vec[layer_id].len() + + prod_circuit_vec[0].right_vec[layer_id].len(); + + let mut poly_C_par = DensePolynomial::new(EqPolynomial::new(rand.clone()).evals()); + assert_eq!(poly_C_par.len(), len / 2); + + let num_rounds_prod = poly_C_par.len().log_2(); + let comb_func_prod = |poly_A_comp: &Scalar, + poly_B_comp: &Scalar, + poly_C_comp: &Scalar| + -> Scalar { poly_A_comp * poly_B_comp * poly_C_comp }; + + let mut poly_A_batched_par: Vec<&mut DensePolynomial> = Vec::new(); + let mut poly_B_batched_par: Vec<&mut DensePolynomial> = Vec::new(); + for prod_circuit in prod_circuit_vec.iter_mut() { + poly_A_batched_par.push(&mut prod_circuit.left_vec[layer_id]); + poly_B_batched_par.push(&mut prod_circuit.right_vec[layer_id]) + } + let poly_vec_par = ( + &mut poly_A_batched_par, + &mut poly_B_batched_par, + &mut poly_C_par, + ); + + // prepare sequential instances that don't share poly_C + let mut poly_A_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); + let mut poly_B_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); + let mut poly_C_batched_seq: Vec<&mut DensePolynomial> = Vec::new(); + if layer_id == 0 && !dotp_circuit_vec.is_empty() { + 
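+        // the dot-product circuits are batched in only at the input layer
+        // (layer_id == 0), where their operand length matches len / 2: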
// add additional claims + for item in dotp_circuit_vec.iter() { + claims_to_verify.push(item.evaluate()); + assert_eq!(len / 2, item.left.len()); + assert_eq!(len / 2, item.right.len()); + assert_eq!(len / 2, item.weight.len()); + } + + for dotp_circuit in dotp_circuit_vec.iter_mut() { + poly_A_batched_seq.push(&mut dotp_circuit.left); + poly_B_batched_seq.push(&mut dotp_circuit.right); + poly_C_batched_seq.push(&mut dotp_circuit.weight); + } + } + let poly_vec_seq = ( + &mut poly_A_batched_seq, + &mut poly_B_batched_seq, + &mut poly_C_batched_seq, + ); + + // produce a fresh set of coeffs and a joint claim + let coeff_vec = + transcript.challenge_vector(b"rand_coeffs_next_layer", claims_to_verify.len()); + let claim = (0..claims_to_verify.len()) + .map(|i| claims_to_verify[i] * coeff_vec[i]) + .sum(); + + let (proof, rand_prod, claims_prod, claims_dotp) = SumcheckInstanceProof::prove_cubic_batched( + &claim, + num_rounds_prod, + poly_vec_par, + poly_vec_seq, + &coeff_vec, + comb_func_prod, + transcript, + ); + + let (claims_prod_left, claims_prod_right, _claims_eq) = claims_prod; + for i in 0..prod_circuit_vec.len() { + transcript.append_scalar(b"claim_prod_left", &claims_prod_left[i]); + transcript.append_scalar(b"claim_prod_right", &claims_prod_right[i]); + } + + if layer_id == 0 && !dotp_circuit_vec.is_empty() { + let (claims_dotp_left, claims_dotp_right, claims_dotp_weight) = claims_dotp; + for i in 0..dotp_circuit_vec.len() { + transcript.append_scalar(b"claim_dotp_left", &claims_dotp_left[i]); + transcript.append_scalar(b"claim_dotp_right", &claims_dotp_right[i]); + transcript.append_scalar(b"claim_dotp_weight", &claims_dotp_weight[i]); + } + claims_dotp_final = (claims_dotp_left, claims_dotp_right, claims_dotp_weight); + } + + // produce a random challenge to condense two claims into a single claim + let r_layer = transcript.challenge_scalar(b"challenge_r_layer"); + + claims_to_verify = (0..prod_circuit_vec.len()) + .map(|i| claims_prod_left[i] + r_layer * (claims_prod_right[i] - claims_prod_left[i])) + .collect::>(); + + let mut ext = vec![r_layer]; + ext.extend(rand_prod); + rand = ext; + + proof_layers.push(LayerProofBatched { + proof, + claims_prod_left, + claims_prod_right, + }); + } + + ( + ProductCircuitEvalProofBatched { + proof: proof_layers, + claims_dotp: claims_dotp_final, + }, + rand, + ) + } + + pub fn verify( + &self, + claims_prod_vec: &[Scalar], + claims_dotp_vec: &[Scalar], + len: usize, + transcript: &mut Transcript, + ) -> (Vec, Vec, Vec) { + let num_layers = len.log_2(); + let mut rand: Vec = Vec::new(); + //let mut num_rounds = 0; + assert_eq!(self.proof.len(), num_layers); + + let mut claims_to_verify = claims_prod_vec.to_owned(); + let mut claims_to_verify_dotp: Vec = Vec::new(); + for (num_rounds, i) in (0..num_layers).enumerate() { + if i == num_layers - 1 { + claims_to_verify.extend(claims_dotp_vec); + } + + // produce random coefficients, one for each instance + let coeff_vec = + transcript.challenge_vector(b"rand_coeffs_next_layer", claims_to_verify.len()); + + // produce a joint claim + let claim = (0..claims_to_verify.len()) + .map(|i| claims_to_verify[i] * coeff_vec[i]) + .sum(); + + let (claim_last, rand_prod) = self.proof[i].verify(claim, num_rounds, 3, transcript); + + let claims_prod_left = &self.proof[i].claims_prod_left; + let claims_prod_right = &self.proof[i].claims_prod_right; + assert_eq!(claims_prod_left.len(), claims_prod_vec.len()); + assert_eq!(claims_prod_right.len(), claims_prod_vec.len()); + + for i in 0..claims_prod_vec.len() { + 
transcript.append_scalar(b"claim_prod_left", &claims_prod_left[i]); + transcript.append_scalar(b"claim_prod_right", &claims_prod_right[i]); + } + + assert_eq!(rand.len(), rand_prod.len()); + let eq: Scalar = (0..rand.len()) + .map(|i| { + rand[i] * rand_prod[i] + (Scalar::one() - rand[i]) * (Scalar::one() - rand_prod[i]) + }) + .product(); + let mut claim_expected: Scalar = (0..claims_prod_vec.len()) + .map(|i| coeff_vec[i] * (claims_prod_left[i] * claims_prod_right[i] * eq)) + .sum(); + + // add claims from the dotp instances + if i == num_layers - 1 { + let num_prod_instances = claims_prod_vec.len(); + let (claims_dotp_left, claims_dotp_right, claims_dotp_weight) = &self.claims_dotp; + for i in 0..claims_dotp_left.len() { + transcript.append_scalar(b"claim_dotp_left", &claims_dotp_left[i]); + transcript.append_scalar(b"claim_dotp_right", &claims_dotp_right[i]); + transcript.append_scalar(b"claim_dotp_weight", &claims_dotp_weight[i]); + + claim_expected += coeff_vec[i + num_prod_instances] + * claims_dotp_left[i] + * claims_dotp_right[i] + * claims_dotp_weight[i]; + } + } + + assert_eq!(claim_expected, claim_last); + + // produce a random challenge + let r_layer = transcript.challenge_scalar(b"challenge_r_layer"); + + claims_to_verify = (0..claims_prod_left.len()) + .map(|i| claims_prod_left[i] + r_layer * (claims_prod_right[i] - claims_prod_left[i])) + .collect::>(); + + // add claims to verify for dotp circuit + if i == num_layers - 1 { + let (claims_dotp_left, claims_dotp_right, claims_dotp_weight) = &self.claims_dotp; + + for i in 0..claims_dotp_vec.len() / 2 { + // combine left claims + let claim_left = claims_dotp_left[2 * i] + + r_layer * (claims_dotp_left[2 * i + 1] - claims_dotp_left[2 * i]); + + let claim_right = claims_dotp_right[2 * i] + + r_layer * (claims_dotp_right[2 * i + 1] - claims_dotp_right[2 * i]); + + let claim_weight = claims_dotp_weight[2 * i] + + r_layer * (claims_dotp_weight[2 * i + 1] - claims_dotp_weight[2 * i]); + claims_to_verify_dotp.push(claim_left); + claims_to_verify_dotp.push(claim_right); + claims_to_verify_dotp.push(claim_weight); + } + } + + let mut ext = vec![r_layer]; + ext.extend(rand_prod); + rand = ext; + } + (claims_to_verify, claims_to_verify_dotp, rand) + } +} diff --git a/packages/Spartan-secq/src/r1csinstance.rs b/packages/Spartan-secq/src/r1csinstance.rs new file mode 100644 index 0000000..cc83214 --- /dev/null +++ b/packages/Spartan-secq/src/r1csinstance.rs @@ -0,0 +1,367 @@ +use crate::transcript::AppendToTranscript; + +use super::dense_mlpoly::DensePolynomial; +use super::errors::ProofVerifyError; +use super::math::Math; +use super::random::RandomTape; +use super::scalar::Scalar; +use super::sparse_mlpoly::{ + MultiSparseMatPolynomialAsDense, SparseMatEntry, SparseMatPolyCommitment, + SparseMatPolyCommitmentGens, SparseMatPolyEvalProof, SparseMatPolynomial, +}; +use super::timer::Timer; +use flate2::{write::ZlibEncoder, Compression}; +use merlin::Transcript; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct R1CSInstance { + num_cons: usize, + num_vars: usize, + num_inputs: usize, + A: SparseMatPolynomial, + B: SparseMatPolynomial, + C: SparseMatPolynomial, +} + +pub struct R1CSCommitmentGens { + gens: SparseMatPolyCommitmentGens, +} + +impl R1CSCommitmentGens { + pub fn new( + label: &'static [u8], + num_cons: usize, + num_vars: usize, + num_inputs: usize, + num_nz_entries: usize, + ) -> R1CSCommitmentGens { + assert!(num_inputs < num_vars); + let num_poly_vars_x = 
num_cons.log_2(); + let num_poly_vars_y = (2 * num_vars).log_2(); + let gens = + SparseMatPolyCommitmentGens::new(label, num_poly_vars_x, num_poly_vars_y, num_nz_entries, 3); + R1CSCommitmentGens { gens } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct R1CSCommitment { + num_cons: usize, + num_vars: usize, + num_inputs: usize, + comm: SparseMatPolyCommitment, +} + +impl AppendToTranscript for R1CSCommitment { + fn append_to_transcript(&self, _label: &'static [u8], transcript: &mut Transcript) { + transcript.append_u64(b"num_cons", self.num_cons as u64); + transcript.append_u64(b"num_vars", self.num_vars as u64); + transcript.append_u64(b"num_inputs", self.num_inputs as u64); + self.comm.append_to_transcript(b"comm", transcript); + } +} + +pub struct R1CSDecommitment { + dense: MultiSparseMatPolynomialAsDense, +} + +impl R1CSCommitment { + pub fn get_num_cons(&self) -> usize { + self.num_cons + } + + pub fn get_num_vars(&self) -> usize { + self.num_vars + } + + pub fn get_num_inputs(&self) -> usize { + self.num_inputs + } +} + +impl R1CSInstance { + pub fn new( + num_cons: usize, + num_vars: usize, + num_inputs: usize, + A: &[(usize, usize, Scalar)], + B: &[(usize, usize, Scalar)], + C: &[(usize, usize, Scalar)], + ) -> R1CSInstance { + Timer::print(&format!("number_of_constraints {}", num_cons)); + Timer::print(&format!("number_of_variables {}", num_vars)); + Timer::print(&format!("number_of_inputs {}", num_inputs)); + Timer::print(&format!("number_non-zero_entries_A {}", A.len())); + Timer::print(&format!("number_non-zero_entries_B {}", B.len())); + Timer::print(&format!("number_non-zero_entries_C {}", C.len())); + + // check that num_cons is a power of 2 + assert_eq!(num_cons.next_power_of_two(), num_cons); + + // check that num_vars is a power of 2 + assert_eq!(num_vars.next_power_of_two(), num_vars); + + // check that number_inputs + 1 <= num_vars + assert!(num_inputs < num_vars); + + // no errors, so create polynomials + let num_poly_vars_x = num_cons.log_2(); + let num_poly_vars_y = (2 * num_vars).log_2(); + + let mat_A = (0..A.len()) + .map(|i| SparseMatEntry::new(A[i].0, A[i].1, A[i].2)) + .collect::>(); + let mat_B = (0..B.len()) + .map(|i| SparseMatEntry::new(B[i].0, B[i].1, B[i].2)) + .collect::>(); + let mat_C = (0..C.len()) + .map(|i| SparseMatEntry::new(C[i].0, C[i].1, C[i].2)) + .collect::>(); + + let poly_A = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, mat_A); + let poly_B = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, mat_B); + let poly_C = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, mat_C); + + R1CSInstance { + num_cons, + num_vars, + num_inputs, + A: poly_A, + B: poly_B, + C: poly_C, + } + } + + pub fn get_num_vars(&self) -> usize { + self.num_vars + } + + pub fn get_num_cons(&self) -> usize { + self.num_cons + } + + pub fn get_num_inputs(&self) -> usize { + self.num_inputs + } + + pub fn get_digest(&self) -> Vec { + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + bincode::serialize_into(&mut encoder, &self).unwrap(); + encoder.finish().unwrap() + } + + pub fn produce_synthetic_r1cs( + num_cons: usize, + num_vars: usize, + num_inputs: usize, + ) -> (R1CSInstance, Vec, Vec) { + Timer::print(&format!("number_of_constraints {}", num_cons)); + Timer::print(&format!("number_of_variables {}", num_vars)); + Timer::print(&format!("number_of_inputs {}", num_inputs)); + + let mut csprng: OsRng = OsRng; + + // assert num_cons and num_vars are power of 2 + assert_eq!((num_cons.log_2()).pow2(), 
num_cons); + assert_eq!((num_vars.log_2()).pow2(), num_vars); + + // num_inputs + 1 <= num_vars + assert!(num_inputs < num_vars); + + // z is organized as [vars,1,io] + let size_z = num_vars + num_inputs + 1; + + // produce a random satisfying assignment + let Z = { + let mut Z: Vec = (0..size_z) + .map(|_i| Scalar::random(&mut csprng)) + .collect::>(); + Z[num_vars] = Scalar::one(); // set the constant term to 1 + Z + }; + + // three sparse matrices + let mut A: Vec = Vec::new(); + let mut B: Vec = Vec::new(); + let mut C: Vec = Vec::new(); + let one = Scalar::one(); + for i in 0..num_cons { + let A_idx = i % size_z; + let B_idx = (i + 2) % size_z; + A.push(SparseMatEntry::new(i, A_idx, one)); + B.push(SparseMatEntry::new(i, B_idx, one)); + let AB_val = Z[A_idx] * Z[B_idx]; + + let C_idx = (i + 3) % size_z; + let C_val = Z[C_idx]; + + if C_val == Scalar::zero() { + C.push(SparseMatEntry::new(i, num_vars, AB_val)); + } else { + C.push(SparseMatEntry::new( + i, + C_idx, + AB_val * C_val.invert().unwrap(), + )); + } + } + + Timer::print(&format!("number_non-zero_entries_A {}", A.len())); + Timer::print(&format!("number_non-zero_entries_B {}", B.len())); + Timer::print(&format!("number_non-zero_entries_C {}", C.len())); + + let num_poly_vars_x = num_cons.log_2(); + let num_poly_vars_y = (2 * num_vars).log_2(); + let poly_A = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, A); + let poly_B = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, B); + let poly_C = SparseMatPolynomial::new(num_poly_vars_x, num_poly_vars_y, C); + + let inst = R1CSInstance { + num_cons, + num_vars, + num_inputs, + A: poly_A, + B: poly_B, + C: poly_C, + }; + + assert!(inst.is_sat(&Z[..num_vars], &Z[num_vars + 1..])); + + (inst, Z[..num_vars].to_vec(), Z[num_vars + 1..].to_vec()) + } + + pub fn is_sat(&self, vars: &[Scalar], input: &[Scalar]) -> bool { + assert_eq!(vars.len(), self.num_vars); + assert_eq!(input.len(), self.num_inputs); + + let z = { + let mut z = vars.to_vec(); + z.extend(&vec![Scalar::one()]); + z.extend(input); + z + }; + + // verify if Az * Bz - Cz = [0...] 
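To make this check concrete, here is a minimal, self-contained sketch of the same `Az ∘ Bz = Cz` test over plain `i64` values instead of this crate's `Scalar` field type, with the matrices given as `(row, col, val)` triples exactly as in `Instance::new`; the function name and the `i64` arithmetic are illustrative only:

```rust
/// Toy R1CS satisfiability check: z is ordered [vars..., 1, inputs...].
fn is_sat_toy(
  num_cons: usize,
  a: &[(usize, usize, i64)],
  b: &[(usize, usize, i64)],
  c: &[(usize, usize, i64)],
  z: &[i64],
) -> bool {
  // sparse matrix-vector product: one accumulated entry per constraint row
  let mul_vec = |m: &[(usize, usize, i64)]| -> Vec<i64> {
    let mut out = vec![0i64; num_cons];
    for &(row, col, val) in m {
      out[row] += val * z[col];
    }
    out
  };
  let (az, bz, cz) = (mul_vec(a), mul_vec(b), mul_vec(c));
  // satisfied iff the Hadamard product matches Cz entry-wise
  (0..num_cons).all(|i| az[i] * bz[i] == cz[i])
}

fn main() {
  // one constraint x * x = y over z = [x, y, 1] (two vars, constant 1, no inputs)
  let (a, b, c) = ([(0, 0, 1)], [(0, 0, 1)], [(0, 1, 1)]);
  assert!(is_sat_toy(1, &a, &b, &c, &[3, 9, 1]));
  assert!(!is_sat_toy(1, &a, &b, &c, &[3, 8, 1]));
}
```

The method below performs exactly this check, but over the crate's `Scalar` type and the padded assignment described above.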
+ let Az = self + .A + .multiply_vec(self.num_cons, self.num_vars + self.num_inputs + 1, &z); + let Bz = self + .B + .multiply_vec(self.num_cons, self.num_vars + self.num_inputs + 1, &z); + let Cz = self + .C + .multiply_vec(self.num_cons, self.num_vars + self.num_inputs + 1, &z); + + assert_eq!(Az.len(), self.num_cons); + assert_eq!(Bz.len(), self.num_cons); + assert_eq!(Cz.len(), self.num_cons); + let res: usize = (0..self.num_cons) + .map(|i| usize::from(Az[i] * Bz[i] != Cz[i])) + .sum(); + + res == 0 + } + + pub fn multiply_vec( + &self, + num_rows: usize, + num_cols: usize, + z: &[Scalar], + ) -> (DensePolynomial, DensePolynomial, DensePolynomial) { + assert_eq!(num_rows, self.num_cons); + assert_eq!(z.len(), num_cols); + assert!(num_cols > self.num_vars); + ( + DensePolynomial::new(self.A.multiply_vec(num_rows, num_cols, z)), + DensePolynomial::new(self.B.multiply_vec(num_rows, num_cols, z)), + DensePolynomial::new(self.C.multiply_vec(num_rows, num_cols, z)), + ) + } + + pub fn compute_eval_table_sparse( + &self, + num_rows: usize, + num_cols: usize, + evals: &[Scalar], + ) -> (Vec, Vec, Vec) { + assert_eq!(num_rows, self.num_cons); + assert!(num_cols > self.num_vars); + + let evals_A = self.A.compute_eval_table_sparse(evals, num_rows, num_cols); + let evals_B = self.B.compute_eval_table_sparse(evals, num_rows, num_cols); + let evals_C = self.C.compute_eval_table_sparse(evals, num_rows, num_cols); + + (evals_A, evals_B, evals_C) + } + + pub fn evaluate(&self, rx: &[Scalar], ry: &[Scalar]) -> (Scalar, Scalar, Scalar) { + let evals = SparseMatPolynomial::multi_evaluate(&[&self.A, &self.B, &self.C], rx, ry); + (evals[0], evals[1], evals[2]) + } + + pub fn commit(&self, gens: &R1CSCommitmentGens) -> (R1CSCommitment, R1CSDecommitment) { + let (comm, dense) = SparseMatPolynomial::multi_commit(&[&self.A, &self.B, &self.C], &gens.gens); + let r1cs_comm = R1CSCommitment { + num_cons: self.num_cons, + num_vars: self.num_vars, + num_inputs: self.num_inputs, + comm, + }; + + let r1cs_decomm = R1CSDecommitment { dense }; + + (r1cs_comm, r1cs_decomm) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct R1CSEvalProof { + proof: SparseMatPolyEvalProof, +} + +impl R1CSEvalProof { + pub fn prove( + decomm: &R1CSDecommitment, + rx: &[Scalar], // point at which the polynomial is evaluated + ry: &[Scalar], + evals: &(Scalar, Scalar, Scalar), + gens: &R1CSCommitmentGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> R1CSEvalProof { + let timer = Timer::new("R1CSEvalProof::prove"); + let proof = SparseMatPolyEvalProof::prove( + &decomm.dense, + rx, + ry, + &[evals.0, evals.1, evals.2], + &gens.gens, + transcript, + random_tape, + ); + timer.stop(); + + R1CSEvalProof { proof } + } + + pub fn verify( + &self, + comm: &R1CSCommitment, + rx: &[Scalar], // point at which the R1CS matrix polynomials are evaluated + ry: &[Scalar], + evals: &(Scalar, Scalar, Scalar), + gens: &R1CSCommitmentGens, + transcript: &mut Transcript, + ) -> Result<(), ProofVerifyError> { + self.proof.verify( + &comm.comm, + rx, + ry, + &[evals.0, evals.1, evals.2], + &gens.gens, + transcript, + ) + } +} diff --git a/packages/Spartan-secq/src/r1csproof.rs b/packages/Spartan-secq/src/r1csproof.rs new file mode 100644 index 0000000..b24570a --- /dev/null +++ b/packages/Spartan-secq/src/r1csproof.rs @@ -0,0 +1,608 @@ +#![allow(clippy::too_many_arguments)] +use super::commitments::{Commitments, MultiCommitGens}; +use super::dense_mlpoly::{ + DensePolynomial, EqPolynomial, PolyCommitment, PolyCommitmentGens, 
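+  // Overview of R1CSProof below: commit to the witness as a multilinear
+  // polynomial, then run two zero-knowledge sumchecks. Phase one proves,
+  // for a verifier challenge tau,
+  //
+  //   sum_x eq(tau, x) * (Az(x) * Bz(x) - Cz(x)) = 0,
+  //
+  // and phase two reduces the resulting claims about Az, Bz, Cz at the point rx
+  // to a single evaluation of Z at ry, checked against the witness commitment.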
PolyEvalProof, +}; +use super::errors::ProofVerifyError; +use super::group::{CompressedGroup, GroupElement, VartimeMultiscalarMul}; +use super::math::Math; +use super::nizk::{EqualityProof, KnowledgeProof, ProductProof}; +use super::r1csinstance::R1CSInstance; +use super::random::RandomTape; +use super::scalar::Scalar; +use super::sparse_mlpoly::{SparsePolyEntry, SparsePolynomial}; +use super::sumcheck::ZKSumcheckInstanceProof; +use super::timer::Timer; +use super::transcript::{AppendToTranscript, ProofTranscript}; +use crate::group::DecompressEncodedPoint; +use core::iter; +use merlin::Transcript; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct R1CSProof { + comm_vars: PolyCommitment, + sc_proof_phase1: ZKSumcheckInstanceProof, + claims_phase2: ( + CompressedGroup, + CompressedGroup, + CompressedGroup, + CompressedGroup, + ), + pok_claims_phase2: (KnowledgeProof, ProductProof), + proof_eq_sc_phase1: EqualityProof, + sc_proof_phase2: ZKSumcheckInstanceProof, + comm_vars_at_ry: CompressedGroup, + proof_eval_vars_at_ry: PolyEvalProof, + proof_eq_sc_phase2: EqualityProof, +} + +pub struct R1CSSumcheckGens { + gens_1: MultiCommitGens, + gens_3: MultiCommitGens, + gens_4: MultiCommitGens, +} + +// TODO: fix passing gens_1_ref +impl R1CSSumcheckGens { + pub fn new(label: &'static [u8], gens_1_ref: &MultiCommitGens) -> Self { + let gens_1 = gens_1_ref.clone(); + let gens_3 = MultiCommitGens::new(3, label); + let gens_4 = MultiCommitGens::new(4, label); + + R1CSSumcheckGens { + gens_1, + gens_3, + gens_4, + } + } +} + +pub struct R1CSGens { + gens_sc: R1CSSumcheckGens, + gens_pc: PolyCommitmentGens, +} + +impl R1CSGens { + pub fn new(label: &'static [u8], _num_cons: usize, num_vars: usize) -> Self { + let num_poly_vars = num_vars.log_2(); + let gens_pc = PolyCommitmentGens::new(num_poly_vars, label); + let gens_sc = R1CSSumcheckGens::new(label, &gens_pc.gens.gens_1); + R1CSGens { gens_sc, gens_pc } + } +} + +impl R1CSProof { + fn prove_phase_one( + num_rounds: usize, + evals_tau: &mut DensePolynomial, + evals_Az: &mut DensePolynomial, + evals_Bz: &mut DensePolynomial, + evals_Cz: &mut DensePolynomial, + gens: &R1CSSumcheckGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> (ZKSumcheckInstanceProof, Vec, Vec, Scalar) { + let comb_func = |poly_A_comp: &Scalar, + poly_B_comp: &Scalar, + poly_C_comp: &Scalar, + poly_D_comp: &Scalar| + -> Scalar { poly_A_comp * (poly_B_comp * poly_C_comp - poly_D_comp) }; + + let (sc_proof_phase_one, r, claims, blind_claim_postsc) = + ZKSumcheckInstanceProof::prove_cubic_with_additive_term( + &Scalar::zero(), // claim is zero + &Scalar::zero(), // blind for claim is also zero + num_rounds, + evals_tau, + evals_Az, + evals_Bz, + evals_Cz, + comb_func, + &gens.gens_1, + &gens.gens_4, + transcript, + random_tape, + ); + + (sc_proof_phase_one, r, claims, blind_claim_postsc) + } + + fn prove_phase_two( + num_rounds: usize, + claim: &Scalar, + blind_claim: &Scalar, + evals_z: &mut DensePolynomial, + evals_ABC: &mut DensePolynomial, + gens: &R1CSSumcheckGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> (ZKSumcheckInstanceProof, Vec, Vec, Scalar) { + let comb_func = + |poly_A_comp: &Scalar, poly_B_comp: &Scalar| -> Scalar { poly_A_comp * poly_B_comp }; + let (sc_proof_phase_two, r, claims, blind_claim_postsc) = ZKSumcheckInstanceProof::prove_quad( + claim, + blind_claim, + num_rounds, + evals_z, + evals_ABC, + comb_func, + &gens.gens_1, + &gens.gens_3, + transcript, + random_tape, + 
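+      // Informal statement of the claim handed to prove_quad here: for the point
+      // rx fixed by phase one and verifier challenges r_A, r_B, r_C, it proves
+      //
+      //   sum_y Z(y) * (r_A * A(rx, y) + r_B * B(rx, y) + r_C * C(rx, y)) = claim,
+      //
+      // where evals_z holds Z and evals_ABC holds the bracketed combination.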
); + + (sc_proof_phase_two, r, claims, blind_claim_postsc) + } + + fn protocol_name() -> &'static [u8] { + b"R1CS proof" + } + + pub fn prove( + inst: &R1CSInstance, + vars: Vec, + input: &[Scalar], + gens: &R1CSGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> (R1CSProof, Vec, Vec) { + let timer_prove = Timer::new("R1CSProof::prove"); + transcript.append_protocol_name(R1CSProof::protocol_name()); + + // we currently require the number of |inputs| + 1 to be at most number of vars + assert!(input.len() < vars.len()); + + input.append_to_transcript(b"input", transcript); + + let timer_commit = Timer::new("polycommit"); + let (poly_vars, comm_vars, blinds_vars) = { + // create a multilinear polynomial using the supplied assignment for variables + let poly_vars = DensePolynomial::new(vars.clone()); + + // produce a commitment to the satisfying assignment + let (comm_vars, blinds_vars) = poly_vars.commit(&gens.gens_pc, Some(random_tape)); + + // add the commitment to the prover's transcript + comm_vars.append_to_transcript(b"poly_commitment", transcript); + (poly_vars, comm_vars, blinds_vars) + }; + timer_commit.stop(); + + let timer_sc_proof_phase1 = Timer::new("prove_sc_phase_one"); + + // append input to variables to create a single vector z + let z = { + let num_inputs = input.len(); + let num_vars = vars.len(); + let mut z = vars; + z.extend(&vec![Scalar::one()]); // add constant term in z + z.extend(input); + z.extend(&vec![Scalar::zero(); num_vars - num_inputs - 1]); // we will pad with zeros + z + }; + + // derive the verifier's challenge tau + let (num_rounds_x, num_rounds_y) = (inst.get_num_cons().log_2(), z.len().log_2()); + let tau = transcript.challenge_vector(b"challenge_tau", num_rounds_x); + // compute the initial evaluation table for R(\tau, x) + let mut poly_tau = DensePolynomial::new(EqPolynomial::new(tau).evals()); + let (mut poly_Az, mut poly_Bz, mut poly_Cz) = + inst.multiply_vec(inst.get_num_cons(), z.len(), &z); + + let (sc_proof_phase1, rx, _claims_phase1, blind_claim_postsc1) = R1CSProof::prove_phase_one( + num_rounds_x, + &mut poly_tau, + &mut poly_Az, + &mut poly_Bz, + &mut poly_Cz, + &gens.gens_sc, + transcript, + random_tape, + ); + assert_eq!(poly_tau.len(), 1); + assert_eq!(poly_Az.len(), 1); + assert_eq!(poly_Bz.len(), 1); + assert_eq!(poly_Cz.len(), 1); + timer_sc_proof_phase1.stop(); + + let (tau_claim, Az_claim, Bz_claim, Cz_claim) = + (&poly_tau[0], &poly_Az[0], &poly_Bz[0], &poly_Cz[0]); + let (Az_blind, Bz_blind, Cz_blind, prod_Az_Bz_blind) = ( + random_tape.random_scalar(b"Az_blind"), + random_tape.random_scalar(b"Bz_blind"), + random_tape.random_scalar(b"Cz_blind"), + random_tape.random_scalar(b"prod_Az_Bz_blind"), + ); + + let (pok_Cz_claim, comm_Cz_claim) = { + KnowledgeProof::prove( + &gens.gens_sc.gens_1, + transcript, + random_tape, + Cz_claim, + &Cz_blind, + ) + }; + + let (proof_prod, comm_Az_claim, comm_Bz_claim, comm_prod_Az_Bz_claims) = { + let prod = Az_claim * Bz_claim; + ProductProof::prove( + &gens.gens_sc.gens_1, + transcript, + random_tape, + Az_claim, + &Az_blind, + Bz_claim, + &Bz_blind, + &prod, + &prod_Az_Bz_blind, + ) + }; + + comm_Az_claim.append_to_transcript(b"comm_Az_claim", transcript); + comm_Bz_claim.append_to_transcript(b"comm_Bz_claim", transcript); + comm_Cz_claim.append_to_transcript(b"comm_Cz_claim", transcript); + comm_prod_Az_Bz_claims.append_to_transcript(b"comm_prod_Az_Bz_claims", transcript); + + // prove the final step of sum-check #1 + let taus_bound_rx = tau_claim; + let 
blind_expected_claim_postsc1 = taus_bound_rx * (prod_Az_Bz_blind - Cz_blind); + let claim_post_phase1 = (Az_claim * Bz_claim - Cz_claim) * taus_bound_rx; + let (proof_eq_sc_phase1, _C1, _C2) = EqualityProof::prove( + &gens.gens_sc.gens_1, + transcript, + random_tape, + &claim_post_phase1, + &blind_expected_claim_postsc1, + &claim_post_phase1, + &blind_claim_postsc1, + ); + + let timer_sc_proof_phase2 = Timer::new("prove_sc_phase_two"); + // combine the three claims into a single claim + let r_A = transcript.challenge_scalar(b"challenege_Az"); + let r_B = transcript.challenge_scalar(b"challenege_Bz"); + let r_C = transcript.challenge_scalar(b"challenege_Cz"); + let claim_phase2 = r_A * Az_claim + r_B * Bz_claim + r_C * Cz_claim; + let blind_claim_phase2 = r_A * Az_blind + r_B * Bz_blind + r_C * Cz_blind; + + let evals_ABC = { + // compute the initial evaluation table for R(\tau, x) + let evals_rx = EqPolynomial::new(rx.clone()).evals(); + let (evals_A, evals_B, evals_C) = + inst.compute_eval_table_sparse(inst.get_num_cons(), z.len(), &evals_rx); + + assert_eq!(evals_A.len(), evals_B.len()); + assert_eq!(evals_A.len(), evals_C.len()); + (0..evals_A.len()) + .map(|i| r_A * evals_A[i] + r_B * evals_B[i] + r_C * evals_C[i]) + .collect::>() + }; + + // another instance of the sum-check protocol + let (sc_proof_phase2, ry, claims_phase2, blind_claim_postsc2) = R1CSProof::prove_phase_two( + num_rounds_y, + &claim_phase2, + &blind_claim_phase2, + &mut DensePolynomial::new(z), + &mut DensePolynomial::new(evals_ABC), + &gens.gens_sc, + transcript, + random_tape, + ); + timer_sc_proof_phase2.stop(); + + let timer_polyeval = Timer::new("polyeval"); + let eval_vars_at_ry = poly_vars.evaluate(&ry[1..]); + let blind_eval = random_tape.random_scalar(b"blind_eval"); + let (proof_eval_vars_at_ry, comm_vars_at_ry) = PolyEvalProof::prove( + &poly_vars, + Some(&blinds_vars), + &ry[1..], + &eval_vars_at_ry, + Some(&blind_eval), + &gens.gens_pc, + transcript, + random_tape, + ); + timer_polyeval.stop(); + + // prove the final step of sum-check #2 + let blind_eval_Z_at_ry = (Scalar::one() - ry[0]) * blind_eval; + let blind_expected_claim_postsc2 = claims_phase2[1] * blind_eval_Z_at_ry; + let claim_post_phase2 = claims_phase2[0] * claims_phase2[1]; + let (proof_eq_sc_phase2, _C1, _C2) = EqualityProof::prove( + &gens.gens_pc.gens.gens_1, + transcript, + random_tape, + &claim_post_phase2, + &blind_expected_claim_postsc2, + &claim_post_phase2, + &blind_claim_postsc2, + ); + + timer_prove.stop(); + + ( + R1CSProof { + comm_vars, + sc_proof_phase1, + claims_phase2: ( + comm_Az_claim, + comm_Bz_claim, + comm_Cz_claim, + comm_prod_Az_Bz_claims, + ), + pok_claims_phase2: (pok_Cz_claim, proof_prod), + proof_eq_sc_phase1, + sc_proof_phase2, + comm_vars_at_ry, + proof_eval_vars_at_ry, + proof_eq_sc_phase2, + }, + rx, + ry, + ) + } + + pub fn verify( + &self, + num_vars: usize, + num_cons: usize, + input: &[Scalar], + evals: &(Scalar, Scalar, Scalar), + transcript: &mut Transcript, + gens: &R1CSGens, + ) -> Result<(Vec, Vec), ProofVerifyError> { + transcript.append_protocol_name(R1CSProof::protocol_name()); + + input.append_to_transcript(b"input", transcript); + + let n = num_vars; + // add the commitment to the verifier's transcript + self + .comm_vars + .append_to_transcript(b"poly_commitment", transcript); + + let (num_rounds_x, num_rounds_y) = (num_cons.log_2(), (2 * num_vars).log_2()); + + // derive the verifier's challenge tau + let tau = transcript.challenge_vector(b"challenge_tau", num_rounds_x); + + // verify the 
first sum-check instance + let claim_phase1 = Scalar::zero() + .commit(&Scalar::zero(), &gens.gens_sc.gens_1) + .compress(); + let (comm_claim_post_phase1, rx) = self.sc_proof_phase1.verify( + &claim_phase1, + num_rounds_x, + 3, + &gens.gens_sc.gens_1, + &gens.gens_sc.gens_4, + transcript, + )?; + // perform the intermediate sum-check test with claimed Az, Bz, and Cz + let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2; + let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2; + + pok_Cz_claim.verify(&gens.gens_sc.gens_1, transcript, comm_Cz_claim)?; + proof_prod.verify( + &gens.gens_sc.gens_1, + transcript, + comm_Az_claim, + comm_Bz_claim, + comm_prod_Az_Bz_claims, + )?; + + comm_Az_claim.append_to_transcript(b"comm_Az_claim", transcript); + comm_Bz_claim.append_to_transcript(b"comm_Bz_claim", transcript); + comm_Cz_claim.append_to_transcript(b"comm_Cz_claim", transcript); + comm_prod_Az_Bz_claims.append_to_transcript(b"comm_prod_Az_Bz_claims", transcript); + + let taus_bound_rx: Scalar = (0..rx.len()) + .map(|i| rx[i] * tau[i] + (Scalar::one() - rx[i]) * (Scalar::one() - tau[i])) + .product(); + let expected_claim_post_phase1 = (taus_bound_rx + * (comm_prod_Az_Bz_claims.decompress().unwrap() - comm_Cz_claim.decompress().unwrap())) + .compress(); + + // verify proof that expected_claim_post_phase1 == claim_post_phase1 + self.proof_eq_sc_phase1.verify( + &gens.gens_sc.gens_1, + transcript, + &expected_claim_post_phase1, + &comm_claim_post_phase1, + )?; + + // derive three public challenges and then derive a joint claim + let r_A = transcript.challenge_scalar(b"challenege_Az"); + let r_B = transcript.challenge_scalar(b"challenege_Bz"); + let r_C = transcript.challenge_scalar(b"challenege_Cz"); + + // r_A * comm_Az_claim + r_B * comm_Bz_claim + r_C * comm_Cz_claim; + let comm_claim_phase2 = GroupElement::vartime_multiscalar_mul( + iter::once(r_A) + .chain(iter::once(r_B)) + .chain(iter::once(r_C)) + .collect(), + iter::once(&comm_Az_claim) + .chain(iter::once(&comm_Bz_claim)) + .chain(iter::once(&comm_Cz_claim)) + .map(|pt| pt.decompress().unwrap()) + .collect(), + ) + .compress(); + + // verify the joint claim with a sum-check protocol + let (comm_claim_post_phase2, ry) = self.sc_proof_phase2.verify( + &comm_claim_phase2, + num_rounds_y, + 2, + &gens.gens_sc.gens_1, + &gens.gens_sc.gens_3, + transcript, + )?; + + // verify Z(ry) proof against the initial commitment + self.proof_eval_vars_at_ry.verify( + &gens.gens_pc, + transcript, + &ry[1..], + &self.comm_vars_at_ry, + &self.comm_vars, + )?; + + let poly_input_eval = { + // constant term + let mut input_as_sparse_poly_entries = vec![SparsePolyEntry::new(0, Scalar::one())]; + //remaining inputs + input_as_sparse_poly_entries.extend( + (0..input.len()) + .map(|i| SparsePolyEntry::new(i + 1, input[i])) + .collect::>(), + ); + SparsePolynomial::new(n.log_2(), input_as_sparse_poly_entries).evaluate(&ry[1..]) + }; + + // compute commitment to eval_Z_at_ry = (Scalar::one() - ry[0]) * self.eval_vars_at_ry + ry[0] * poly_input_eval + let comm_eval_Z_at_ry = GroupElement::vartime_multiscalar_mul( + iter::once(Scalar::one() - ry[0]) + .chain(iter::once(ry[0])) + .map(|s| s) + .collect(), + iter::once(self.comm_vars_at_ry.decompress().unwrap()) + .chain(iter::once( + poly_input_eval.commit(&Scalar::zero(), &gens.gens_pc.gens.gens_1), + )) + .collect(), + ); + + // perform the final check in the second sum-check protocol + let (eval_A_r, eval_B_r, eval_C_r) = evals; + let expected_claim_post_phase2 = + 
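+    // The check being assembled here, informally: with (eA, eB, eC) the claimed
+    // evaluations of A, B, C at (rx, ry),
+    //
+    //   expected = (r_A * eA + r_B * eB + r_C * eC) * Commit(Z(ry))
+    //
+    // must match the commitment output by the phase-two sumcheck; the equality
+    // proof below checks this without opening either commitment.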
((r_A * eval_A_r + r_B * eval_B_r + r_C * eval_C_r) * comm_eval_Z_at_ry).compress(); + // verify proof that expected_claim_post_phase1 == claim_post_phase1 + self.proof_eq_sc_phase2.verify( + &gens.gens_sc.gens_1, + transcript, + &expected_claim_post_phase2, + &comm_claim_post_phase2, + )?; + + Ok((rx, ry)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_core::OsRng; + + fn produce_tiny_r1cs() -> (R1CSInstance, Vec, Vec) { + // three constraints over five variables Z1, Z2, Z3, Z4, and Z5 + // rounded to the nearest power of two + let num_cons = 128; + let num_vars = 256; + let num_inputs = 2; + + // encode the above constraints into three matrices + let mut A: Vec<(usize, usize, Scalar)> = Vec::new(); + let mut B: Vec<(usize, usize, Scalar)> = Vec::new(); + let mut C: Vec<(usize, usize, Scalar)> = Vec::new(); + + let one = Scalar::one(); + // constraint 0 entries + // (Z1 + Z2) * I0 - Z3 = 0; + A.push((0, 0, one)); + A.push((0, 1, one)); + B.push((0, num_vars + 1, one)); + C.push((0, 2, one)); + + // constraint 1 entries + // (Z1 + I1) * (Z3) - Z4 = 0 + A.push((1, 0, one)); + A.push((1, num_vars + 2, one)); + B.push((1, 2, one)); + C.push((1, 3, one)); + // constraint 3 entries + // Z5 * 1 - 0 = 0 + A.push((2, 4, one)); + B.push((2, num_vars, one)); + + let inst = R1CSInstance::new(num_cons, num_vars, num_inputs, &A, &B, &C); + + // compute a satisfying assignment + let mut csprng: OsRng = OsRng; + let i0 = Scalar::random(&mut csprng); + let i1 = Scalar::random(&mut csprng); + let z1 = Scalar::random(&mut csprng); + let z2 = Scalar::random(&mut csprng); + let z3 = (z1 + z2) * i0; // constraint 1: (Z1 + Z2) * I0 - Z3 = 0; + let z4 = (z1 + i1) * z3; // constraint 2: (Z1 + I1) * (Z3) - Z4 = 0 + let z5 = Scalar::zero(); //constraint 3 + + let mut vars = vec![Scalar::zero(); num_vars]; + vars[0] = z1; + vars[1] = z2; + vars[2] = z3; + vars[3] = z4; + vars[4] = z5; + + let mut input = vec![Scalar::zero(); num_inputs]; + input[0] = i0; + input[1] = i1; + + (inst, vars, input) + } + + #[test] + fn test_tiny_r1cs() { + let (inst, vars, input) = tests::produce_tiny_r1cs(); + let is_sat = inst.is_sat(&vars, &input); + assert!(is_sat); + } + + #[test] + fn test_synthetic_r1cs() { + let (inst, vars, input) = R1CSInstance::produce_synthetic_r1cs(1024, 1024, 10); + let is_sat = inst.is_sat(&vars, &input); + assert!(is_sat); + } + + #[test] + pub fn check_r1cs_proof() { + let num_vars = 1024; + let num_cons = num_vars; + let num_inputs = 10; + let (inst, vars, input) = R1CSInstance::produce_synthetic_r1cs(num_cons, num_vars, num_inputs); + + let gens = R1CSGens::new(b"test-m", num_cons, num_vars); + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let (proof, rx, ry) = R1CSProof::prove( + &inst, + vars, + &input, + &gens, + &mut prover_transcript, + &mut random_tape, + ); + + let inst_evals = inst.evaluate(&rx, &ry); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify( + inst.get_num_vars(), + inst.get_num_cons(), + &input, + &inst_evals, + &mut verifier_transcript, + &gens, + ) + .is_ok()); + } +} diff --git a/packages/Spartan-secq/src/random.rs b/packages/Spartan-secq/src/random.rs new file mode 100644 index 0000000..a0477e8 --- /dev/null +++ b/packages/Spartan-secq/src/random.rs @@ -0,0 +1,27 @@ +use super::scalar::Scalar; +use super::transcript::ProofTranscript; +use merlin::Transcript; +use rand_core::OsRng; +pub struct RandomTape { + tape: Transcript, +} + +impl RandomTape { + pub fn 
new(name: &'static [u8]) -> Self { + let tape = { + let mut rng = OsRng::default(); + let mut tape = Transcript::new(name); + tape.append_scalar(b"init_randomness", &Scalar::random(&mut rng)); + tape + }; + Self { tape } + } + + pub fn random_scalar(&mut self, label: &'static [u8]) -> Scalar { + self.tape.challenge_scalar(label) + } + + pub fn random_vector(&mut self, label: &'static [u8], len: usize) -> Vec { + self.tape.challenge_vector(label, len) + } +} diff --git a/packages/Spartan-secq/src/scalar/mod.rs b/packages/Spartan-secq/src/scalar/mod.rs new file mode 100644 index 0000000..e374c2f --- /dev/null +++ b/packages/Spartan-secq/src/scalar/mod.rs @@ -0,0 +1,46 @@ +use secq256k1::elliptic_curve::ops::Reduce; +use secq256k1::U256; + +mod scalar; + +pub type Scalar = scalar::Scalar; +pub type ScalarBytes = secq256k1::Scalar; + +pub trait ScalarFromPrimitives { + fn to_scalar(self) -> Scalar; +} + +impl ScalarFromPrimitives for usize { + #[inline] + fn to_scalar(self) -> Scalar { + (0..self).map(|_i| Scalar::one()).sum() + } +} + +impl ScalarFromPrimitives for bool { + #[inline] + fn to_scalar(self) -> Scalar { + if self { + Scalar::one() + } else { + Scalar::zero() + } + } +} + +pub trait ScalarBytesFromScalar { + fn decompress_scalar(s: &Scalar) -> ScalarBytes; + fn decompress_vector(s: &[Scalar]) -> Vec; +} + +impl ScalarBytesFromScalar for Scalar { + fn decompress_scalar(s: &Scalar) -> ScalarBytes { + ScalarBytes::from_uint_reduced(U256::from_le_slice(&s.to_bytes())) + } + + fn decompress_vector(s: &[Scalar]) -> Vec { + (0..s.len()) + .map(|i| Scalar::decompress_scalar(&s[i])) + .collect::>() + } +} diff --git a/packages/Spartan-secq/src/scalar/scalar.rs b/packages/Spartan-secq/src/scalar/scalar.rs new file mode 100755 index 0000000..95116fc --- /dev/null +++ b/packages/Spartan-secq/src/scalar/scalar.rs @@ -0,0 +1,1266 @@ +//! This module provides an implementation of the secq256k1's scalar field $\mathbb{F}_q$ +//! where `q = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f` +//! This module is an adaptation of code from the bls12-381 crate. +//! We modify various constants (MODULUS, R, R2, etc.) to appropriate values for secq256k1 and update tests +#![allow(clippy::all)] +use core::borrow::Borrow; +use core::convert::TryFrom; +use core::fmt; +use core::iter::{Product, Sum}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use hex_literal::hex; +use num_bigint_dig::{BigUint, ModInverse}; +use rand_core::{CryptoRng, RngCore}; +use serde::de::Visitor; +use serde::{Deserialize, Serialize}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; +use zeroize::Zeroize; + +// use crate::util::{adc, mac, sbb}; +/// Compute a + b + carry, returning the result and the new carry over. +#[inline(always)] +pub const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) { + let ret = (a as u128) + (b as u128) + (carry as u128); + (ret as u64, (ret >> 64) as u64) +} + +/// Compute a - (b + borrow), returning the result and the new borrow. +#[inline(always)] +pub const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) { + let ret = (a as u128).wrapping_sub((b as u128) + ((borrow >> 63) as u128)); + (ret as u64, (ret >> 64) as u64) +} + +/// Compute a + (b * c) + carry, returning the result and the new carry over. +#[inline(always)] +pub const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { + let ret = (a as u128) + ((b as u128) * (c as u128)) + (carry as u128); + (ret as u64, (ret >> 64) as u64) +} + +macro_rules! 
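+// Worked example (informal) for the carry-chain helpers above: mac(a, b, c, carry)
+// returns the low and high 64-bit halves of a + b*c + carry, computed in u128 so it
+// cannot overflow: even at the extreme a = b = c = carry = 2^64 - 1, the sum is
+// (2^64 - 1) + (2^64 - 1)^2 + (2^64 - 1) = 2^128 - 1, i.e., (lo, hi) = (u64::MAX, u64::MAX).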
impl_add_binop_specify_output { + ($lhs:ident, $rhs:ident, $output:ident) => { + impl<'b> Add<&'b $rhs> for $lhs { + type Output = $output; + + #[inline] + fn add(self, rhs: &'b $rhs) -> $output { + &self + rhs + } + } + + impl<'a> Add<$rhs> for &'a $lhs { + type Output = $output; + + #[inline] + fn add(self, rhs: $rhs) -> $output { + self + &rhs + } + } + + impl Add<$rhs> for $lhs { + type Output = $output; + + #[inline] + fn add(self, rhs: $rhs) -> $output { + &self + &rhs + } + } + }; +} + +macro_rules! impl_sub_binop_specify_output { + ($lhs:ident, $rhs:ident, $output:ident) => { + impl<'b> Sub<&'b $rhs> for $lhs { + type Output = $output; + + #[inline] + fn sub(self, rhs: &'b $rhs) -> $output { + &self - rhs + } + } + + impl<'a> Sub<$rhs> for &'a $lhs { + type Output = $output; + + #[inline] + fn sub(self, rhs: $rhs) -> $output { + self - &rhs + } + } + + impl Sub<$rhs> for $lhs { + type Output = $output; + + #[inline] + fn sub(self, rhs: $rhs) -> $output { + &self - &rhs + } + } + }; +} + +macro_rules! impl_binops_additive_specify_output { + ($lhs:ident, $rhs:ident, $output:ident) => { + impl_add_binop_specify_output!($lhs, $rhs, $output); + impl_sub_binop_specify_output!($lhs, $rhs, $output); + }; +} + +macro_rules! impl_binops_multiplicative_mixed { + ($lhs:ident, $rhs:ident, $output:ident) => { + impl<'b> Mul<&'b $rhs> for $lhs { + type Output = $output; + + #[inline] + fn mul(self, rhs: &'b $rhs) -> $output { + &self * rhs + } + } + + impl<'a> Mul<$rhs> for &'a $lhs { + type Output = $output; + + #[inline] + fn mul(self, rhs: $rhs) -> $output { + self * &rhs + } + } + + impl Mul<$rhs> for $lhs { + type Output = $output; + + #[inline] + fn mul(self, rhs: $rhs) -> $output { + &self * &rhs + } + } + }; +} + +macro_rules! impl_binops_additive { + ($lhs:ident, $rhs:ident) => { + impl_binops_additive_specify_output!($lhs, $rhs, $lhs); + + impl SubAssign<$rhs> for $lhs { + #[inline] + fn sub_assign(&mut self, rhs: $rhs) { + *self = &*self - &rhs; + } + } + + impl AddAssign<$rhs> for $lhs { + #[inline] + fn add_assign(&mut self, rhs: $rhs) { + *self = &*self + &rhs; + } + } + + impl<'b> SubAssign<&'b $rhs> for $lhs { + #[inline] + fn sub_assign(&mut self, rhs: &'b $rhs) { + *self = &*self - rhs; + } + } + + impl<'b> AddAssign<&'b $rhs> for $lhs { + #[inline] + fn add_assign(&mut self, rhs: &'b $rhs) { + *self = &*self + rhs; + } + } + }; +} + +macro_rules! impl_binops_multiplicative { + ($lhs:ident, $rhs:ident) => { + impl_binops_multiplicative_mixed!($lhs, $rhs, $lhs); + + impl MulAssign<$rhs> for $lhs { + #[inline] + fn mul_assign(&mut self, rhs: $rhs) { + *self = &*self * &rhs; + } + } + + impl<'b> MulAssign<&'b $rhs> for $lhs { + #[inline] + fn mul_assign(&mut self, rhs: &'b $rhs) { + *self = &*self * rhs; + } + } + }; +} + +/// Represents an element of the scalar field $\mathbb{F}_q$ of the secq256k1 elliptic +/// curve construction. +// The internal representation of this type is four 64-bit unsigned +// integers in little-endian order. `Scalar` values are always in +// Montgomery form; i.e., Scalar(a) = aR mod q, with R = 2^256. 
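+// Illustration (informal) of the Montgomery convention described above, with
+// R = 2^256 and q the modulus: entering Montgomery form is a multiplication by R^2
+// followed by a Montgomery reduction, and leaving it is a reduction alone:
+//
+//   to_mont(a)   = montgomery_reduce(a * R2) = a * R^2 / R = a * R (mod q)
+//   from_mont(x) = montgomery_reduce(x)      = a * R / R   = a     (mod q)
+//
+// This is exactly what `From<u64>` (multiply by R2) and `to_bytes`
+// (montgomery_reduce with zero high limbs) do below.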
+#[derive(Clone, Copy, Eq)] +pub struct Scalar(pub(crate) [u64; 5]); + +use serde::ser::SerializeSeq; +use serde::{Deserializer, Serializer}; + +impl Serialize for Scalar { + fn serialize(&self, serializer: S) -> Result { + let values = self.to_bytes(); + + let mut seq = serializer.serialize_seq(Some(values.len()))?; + for val in values.iter() { + seq.serialize_element(val)?; + } + + seq.end() + } +} + +struct U64ArrayVisitor; + +impl<'de> Visitor<'de> for U64ArrayVisitor { + type Value = Scalar; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence of 4 u64 values") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut result = [0u64; 4]; + + for i in 0..4 { + let mut val: u64 = 0; + for j in 0..8 { + val += (seq.next_element::().unwrap().unwrap() as u64) * 256u64.pow(j) + } + result[i] = val; + } + + Ok(Scalar::from_raw(result)) + } +} + +impl<'de> Deserialize<'de> for Scalar { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_seq(U64ArrayVisitor) + } +} + +impl fmt::Debug for Scalar { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let tmp = self.to_bytes(); + write!(f, "0x")?; + for &b in tmp.iter().rev() { + write!(f, "{:02x}", b)?; + } + Ok(()) + } +} + +impl From for Scalar { + fn from(val: u64) -> Scalar { + Scalar([val, 0, 0, 0, 0]) * R2 + } +} + +impl ConstantTimeEq for Scalar { + fn ct_eq(&self, other: &Self) -> Choice { + self.0[0].ct_eq(&other.0[0]) + & self.0[1].ct_eq(&other.0[1]) + & self.0[2].ct_eq(&other.0[2]) + & self.0[3].ct_eq(&other.0[3]) + } +} + +impl PartialEq for Scalar { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).unwrap_u8() == 1 + } +} + +impl ConditionallySelectable for Scalar { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Scalar([ + u64::conditional_select(&a.0[0], &b.0[0], choice), + u64::conditional_select(&a.0[1], &b.0[1], choice), + u64::conditional_select(&a.0[2], &b.0[2], choice), + u64::conditional_select(&a.0[3], &b.0[3], choice), + u64::conditional_select(&a.0[4], &b.0[4], choice), + ]) + } +} + +/// Constant representing the modulus +/// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f +const MODULUS: Scalar = Scalar([ + 0xfffffffefffffc2f, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0, +]); + +impl<'a> Neg for &'a Scalar { + type Output = Scalar; + + #[inline] + fn neg(self) -> Scalar { + self.neg() + } +} + +impl Neg for Scalar { + type Output = Scalar; + + #[inline] + fn neg(self) -> Scalar { + -&self + } +} + +impl<'a, 'b> Sub<&'b Scalar> for &'a Scalar { + type Output = Scalar; + + #[inline] + fn sub(self, rhs: &'b Scalar) -> Scalar { + self.sub(rhs) + } +} + +impl<'a, 'b> Add<&'b Scalar> for &'a Scalar { + type Output = Scalar; + + #[inline] + fn add(self, rhs: &'b Scalar) -> Scalar { + self.add(rhs) + } +} + +impl<'a, 'b> Mul<&'b Scalar> for &'a Scalar { + type Output = Scalar; + + #[inline] + fn mul(self, rhs: &'b Scalar) -> Scalar { + self.mul(rhs) + } +} + +impl_binops_additive!(Scalar, Scalar); +impl_binops_multiplicative!(Scalar, Scalar); + +/// INV = -(q^{-1} mod 2^64) mod 2^64 +const INV: u64 = 0xd838091dd2253531; + +/// R = 2^256 mod q +const R: Scalar = Scalar([ + 0x00000001000003d1, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + 0x0, +]); + +/// R^2 = 2^512 mod q +const R2: Scalar = Scalar([ + 0x000007a2000e90a1, + 0x0000000000000001, + 
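+  // Sanity check (informal): the modulus q here is the secp256k1 base-field prime,
+  // so R = 2^256 mod q = 2^256 - q = 0x1000003d1, matching the constant `R` above;
+  // squaring it gives 0x1000007a2000e90a1, matching the limbs of this constant `R2`.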
0x0000000000000000, + 0x0000000000000000, + 0, +]); + +/// R^3 = 2^768 mod q +const R3: Scalar = Scalar([ + 0x002bb1e33795f671, + 0x0000000100000b73, + 0x0000000000000000, + 0x0000000000000000, + 0x0, +]); + +impl Default for Scalar { + #[inline] + fn default() -> Self { + Self::zero() + } +} + +impl Product for Scalar +where + T: Borrow, +{ + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.fold(Scalar::one(), |acc, item| acc * item.borrow()) + } +} + +impl Sum for Scalar +where + T: Borrow, +{ + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.fold(Scalar::zero(), |acc, item| acc + item.borrow()) + } +} + +impl Zeroize for Scalar { + fn zeroize(&mut self) { + self.0 = [0u64; 5]; + } +} + +impl Scalar { + /// Returns zero, the additive identity. + #[inline] + pub const fn zero() -> Scalar { + Scalar([0, 0, 0, 0, 0]) + } + + /// Returns one, the multiplicative identity. + #[inline] + pub const fn one() -> Scalar { + R + } + + pub fn random(rng: &mut Rng) -> Self { + let mut limbs = [0u64; 8]; + for i in 0..8 { + limbs[i] = rng.next_u64(); + } + Scalar::from_u512(limbs) + } + + /// Doubles this field element. + #[inline] + pub const fn double(&self) -> Scalar { + // TODO: This can be achieved more efficiently with a bitshift. + self.add(self) + } + + /// Attempts to convert a little-endian byte representation of + /// a scalar into a `Scalar`, failing if the input is not canonical. + pub fn from_bytes(bytes: &[u8; 32]) -> CtOption { + let mut tmp = Scalar([0, 0, 0, 0, 0]); + + tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()); + tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); + tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); + tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); + + // Try to subtract the modulus + let (_, borrow) = sbb(tmp.0[0], MODULUS.0[0], 0); + let (_, borrow) = sbb(tmp.0[1], MODULUS.0[1], borrow); + let (_, borrow) = sbb(tmp.0[2], MODULUS.0[2], borrow); + let (_, borrow) = sbb(tmp.0[3], MODULUS.0[3], borrow); + + // If the element is smaller than MODULUS then the + // subtraction will underflow, producing a borrow value + // of 0xffff...ffff. Otherwise, it'll be zero. + let is_some = (borrow as u8) & 1; + + // Convert to Montgomery form by computing + // (a.R^0 * R^2) / R = a.R + tmp *= &R2; + + CtOption::new(tmp, Choice::from(is_some)) + } + + /// Converts an element of `Scalar` into a byte representation in + /// little-endian byte order. + pub fn to_bytes(&self) -> [u8; 32] { + // Turn into canonical form by computing + // (a.R) / R = a + let tmp = Scalar::montgomery_reduce( + self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], 0, 0, 0, 0, + ); + + let mut res = [0; 32]; + res[..8].copy_from_slice(&tmp.0[0].to_le_bytes()); + res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes()); + res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes()); + res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes()); + + res + } + + /// Converts a 512-bit little endian integer into + /// a `Scalar` by reducing by the modulus. 
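+  // How the canonicality check in `from_bytes` above works, informally: subtracting
+  // the modulus limb-by-limb with `sbb` leaves an all-ones borrow mask exactly when
+  // the input is smaller than the modulus. For example, for input 1 the subtraction
+  // 1 - q underflows, so is_some = 1 (canonical); for input q it yields 0 with no
+  // borrow, so is_some = 0 and the encoding is rejected.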
+ pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar { + Scalar::from_u512([ + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()), + u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[56..64]).unwrap()), + ]) + } + + fn from_u512(limbs: [u64; 8]) -> Scalar { + // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits + // with the higher bits multiplied by 2^256. Thus, we perform two reductions + // + // 1. the lower bits are multiplied by R^2, as normal + // 2. the upper bits are multiplied by R^2 * 2^256 = R^3 + // + // and computing their sum in the field. It remains to see that arbitrary 256-bit + // numbers can be placed into Montgomery form safely using the reduction. The + // reduction works so long as the product is less than R=2^256 multipled by + // the modulus. This holds because for any `c` smaller than the modulus, we have + // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the + // reduction always works so long as `c` is in the field; in this case it is either the + // constant `R2` or `R3`. + let d0 = Scalar([limbs[0], limbs[1], limbs[2], limbs[3], 0]); + let d1 = Scalar([limbs[4], limbs[5], limbs[6], limbs[7], 0]); + // Convert to Montgomery form + d0 * R2 + d1 * R3 + } + + /// Converts from an integer represented in little endian + /// into its (congruent) `Scalar` representation. + pub const fn from_raw(val: [u64; 4]) -> Self { + (&Scalar([val[0], val[1], val[2], val[3], 0])).mul(&R2) + } + + /// Squares this element. + #[inline] + pub const fn square(&self) -> Scalar { + let (r1, carry) = mac(0, self.0[0], self.0[1], 0); + let (r2, carry) = mac(0, self.0[0], self.0[2], carry); + let (r3, r4) = mac(0, self.0[0], self.0[3], carry); + + let (r3, carry) = mac(r3, self.0[1], self.0[2], 0); + let (r4, r5) = mac(r4, self.0[1], self.0[3], carry); + + let (r5, r6) = mac(r5, self.0[2], self.0[3], 0); + + let r7 = r6 >> 63; + let r6 = (r6 << 1) | (r5 >> 63); + let r5 = (r5 << 1) | (r4 >> 63); + let r4 = (r4 << 1) | (r3 >> 63); + let r3 = (r3 << 1) | (r2 >> 63); + let r2 = (r2 << 1) | (r1 >> 63); + let r1 = r1 << 1; + + let (r0, carry) = mac(0, self.0[0], self.0[0], 0); + let (r1, carry) = adc(0, r1, carry); + let (r2, carry) = mac(r2, self.0[1], self.0[1], carry); + let (r3, carry) = adc(0, r3, carry); + let (r4, carry) = mac(r4, self.0[2], self.0[2], carry); + let (r5, carry) = adc(0, r5, carry); + let (r6, carry) = mac(r6, self.0[3], self.0[3], carry); + let (r7, _) = adc(0, r7, carry); + + Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7, 0) + } + + /// Exponentiates `self` by `by`, where `by` is a + /// little-endian order integer exponent. + pub fn pow(&self, by: &[u64; 4]) -> Self { + let mut res = Self::one(); + for e in by.iter().rev() { + for i in (0..64).rev() { + res = res.square(); + let mut tmp = res; + tmp *= self; + res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into()); + } + } + res + } + + /// Exponentiates `self` by `by`, where `by` is a + /// little-endian order integer exponent. 
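+  // Note (informal): `pow` above is a constant-time square-and-multiply; it always
+  // performs the multiplication and uses conditional_assign to keep or discard the
+  // result, so the sequence of operations is independent of the exponent bits.
+  // `pow_vartime` below skips the multiplication when a bit is zero, which is
+  // faster but leaks information about the exponent through timing.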
+ /// + /// **This operation is variable time with respect + /// to the exponent.** If the exponent is fixed, + /// this operation is effectively constant time. + pub fn pow_vartime(&self, by: &[u64; 4]) -> Self { + let mut res = Self::one(); + for e in by.iter().rev() { + for i in (0..64).rev() { + res = res.square(); + + if ((*e >> i) & 1) == 1 { + res.mul_assign(self); + } + } + } + res + } + + pub fn invert(&self) -> CtOption { + let val = BigUint::from_bytes_le(&self.to_bytes()); + + let result = val.mod_inverse(&BigUint::from_bytes_be(&hex!( + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" + ))); + + if result.is_some() { + let mut result = result.unwrap().to_bytes_le().1.to_vec(); + result.resize(64, 0); + + let result_bytes: [u8; 64] = result.try_into().unwrap(); + + let result = Scalar::from_bytes_wide(&result_bytes); + + CtOption::new(result, Choice::from(1)) + } else { + CtOption::new(Scalar::zero(), Choice::from(0)) + } + } + + pub fn batch_invert(inputs: &mut [Scalar]) -> Scalar { + // This code is essentially identical to the FieldElement + // implementation, and is documented there. Unfortunately, + // it's not easy to write it generically, since here we want + // to use `UnpackedScalar`s internally, and `Scalar`s + // externally, but there's no corresponding distinction for + // field elements. + + use zeroize::Zeroizing; + + let n = inputs.len(); + let one = Scalar::one(); + + // Place scratch storage in a Zeroizing wrapper to wipe it when + // we pass out of scope. + let scratch_vec = vec![one; n]; + let mut scratch = Zeroizing::new(scratch_vec); + + // Keep an accumulator of all of the previous products + let mut acc = Scalar::one(); + + // Pass through the input vector, recording the previous + // products in the scratch space + for (input, scratch) in inputs.iter().zip(scratch.iter_mut()) { + *scratch = acc; + + acc = acc * input; + } + + // acc is nonzero iff all inputs are nonzero + debug_assert!(acc != Scalar::zero()); + + // Compute the inverse of all products + acc = acc.invert().unwrap(); + + // We need to return the product of all inverses later + let ret = acc; + + // Pass through the vector backwards to compute the inverses + // in place + for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter().rev()) { + let tmp = &acc * input.clone(); + *input = &acc * scratch; + acc = tmp; + } + + ret + } + + #[inline(always)] + const fn montgomery_reduce( + r0: u64, + r1: u64, + r2: u64, + r3: u64, + r4: u64, + r5: u64, + r6: u64, + r7: u64, + r8: u64, + ) -> Self { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . 
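+    // Informal summary of the rounds below: in each round, k = r_i * INV mod 2^64
+    // is chosen so that r_i + k * q has a zero lowest limb (INV = -q^{-1} mod 2^64),
+    // so adding k * MODULUS and dropping one limb divides by 2^64 without changing
+    // the value mod q. After four rounds the result is t / 2^256 mod q, i.e., a
+    // Montgomery reduction of the double-width input.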
+ + let k = r0.wrapping_mul(INV); + let (_, carry) = mac(r0, k, MODULUS.0[0], 0); + let (r1, carry) = mac(r1, k, MODULUS.0[1], carry); + let (r2, carry) = mac(r2, k, MODULUS.0[2], carry); + let (r3, carry) = mac(r3, k, MODULUS.0[3], carry); + let (r4, carry) = mac(r4, k, MODULUS.0[4], carry); + let (r5, carry2) = adc(r5, 0, carry); + + let k = r1.wrapping_mul(INV); + let (_, carry) = mac(r1, k, MODULUS.0[0], 0); + let (r2, carry) = mac(r2, k, MODULUS.0[1], carry); + let (r3, carry) = mac(r3, k, MODULUS.0[2], carry); + let (r4, carry) = mac(r4, k, MODULUS.0[3], carry); + let (r5, carry) = mac(r5, k, MODULUS.0[4], carry); + let (r6, carry2) = adc(r6, carry2, carry); + + let k = r2.wrapping_mul(INV); + let (_, carry) = mac(r2, k, MODULUS.0[0], 0); + let (r3, carry) = mac(r3, k, MODULUS.0[1], carry); + let (r4, carry) = mac(r4, k, MODULUS.0[2], carry); + let (r5, carry) = mac(r5, k, MODULUS.0[3], carry); + let (r6, carry) = mac(r6, k, MODULUS.0[4], carry); + let (r7, carry2) = adc(r7, carry2, carry); + + let k = r3.wrapping_mul(INV); + let (_, carry) = mac(r3, k, MODULUS.0[0], 0); + let (r4, carry) = mac(r4, k, MODULUS.0[1], carry); + let (r5, carry) = mac(r5, k, MODULUS.0[2], carry); + let (r6, carry) = mac(r6, k, MODULUS.0[3], carry); + let (r7, carry) = mac(r7, k, MODULUS.0[4], carry); + let (r8, _) = adc(r8, carry2, carry); + + // Result may be within MODULUS of the correct value + (&Scalar([r4, r5, r6, r7, r8])).sub(&MODULUS) + } + + /// Multiplies `rhs` by `self`, returning the result. + #[inline] + pub const fn mul(&self, rhs: &Self) -> Self { + // Schoolbook multiplication + + let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0); + let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry); + let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry); + let (r3, carry) = mac(0, self.0[0], rhs.0[3], carry); + let (r4, r5) = mac(0, self.0[0], rhs.0[4], carry); + + let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0); + let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry); + let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry); + let (r4, carry) = mac(r4, self.0[1], rhs.0[3], carry); + let (r5, r6) = mac(r5, self.0[1], rhs.0[4], carry); + + let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0); + let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry); + let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry); + let (r5, carry) = mac(r5, self.0[2], rhs.0[3], carry); + let (r6, r7) = mac(r6, self.0[2], rhs.0[4], carry); + + let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0); + let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry); + let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry); + let (r6, carry) = mac(r6, self.0[3], rhs.0[3], carry); + let (r7, r8) = mac(r7, self.0[3], rhs.0[4], carry); + + let (r4, carry) = mac(r4, self.0[4], rhs.0[0], 0); + let (r5, carry) = mac(r5, self.0[4], rhs.0[1], carry); + let (r6, carry) = mac(r6, self.0[4], rhs.0[2], carry); + let (r7, carry) = mac(r7, self.0[4], rhs.0[3], carry); + let (r8, _) = mac(r8, self.0[4], rhs.0[4], carry); + + Scalar::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7, r8) + } + + /// Subtracts `rhs` from `self`, returning the result. 
+ #[inline] + pub const fn sub(&self, rhs: &Self) -> Self { + let (d0, borrow) = sbb(self.0[0], rhs.0[0], 0); + let (d1, borrow) = sbb(self.0[1], rhs.0[1], borrow); + let (d2, borrow) = sbb(self.0[2], rhs.0[2], borrow); + let (d3, borrow) = sbb(self.0[3], rhs.0[3], borrow); + let (d4, borrow) = sbb(self.0[4], rhs.0[4], borrow); + + // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise + // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the modulus. + let (d0, carry) = adc(d0, MODULUS.0[0] & borrow, 0); + let (d1, carry) = adc(d1, MODULUS.0[1] & borrow, carry); + let (d2, carry) = adc(d2, MODULUS.0[2] & borrow, carry); + let (d3, carry) = adc(d3, MODULUS.0[3] & borrow, carry); + let (d4, _) = adc(d4, MODULUS.0[4] & borrow, carry); + + Scalar([d0, d1, d2, d3, d4]) + } + + /// Adds `rhs` to `self`, returning the result. + #[inline] + pub const fn add(&self, rhs: &Self) -> Self { + let (d0, carry) = adc(self.0[0], rhs.0[0], 0); + let (d1, carry) = adc(self.0[1], rhs.0[1], carry); + let (d2, carry) = adc(self.0[2], rhs.0[2], carry); + let (d3, carry) = adc(self.0[3], rhs.0[3], carry); + let (d4, _) = adc(self.0[4], rhs.0[4], carry); + + // Attempt to subtract the modulus, to ensure the value + // is smaller than the modulus. + (&Scalar([d0, d1, d2, d3, d4])).sub(&MODULUS) + } + + /// Negates `self`. + #[inline] + pub const fn neg(&self) -> Self { + // Subtract `self` from `MODULUS` to negate. Ignore the final + // borrow because it cannot underflow; self is guaranteed to + // be in the field. + let (d0, borrow) = sbb(MODULUS.0[0], self.0[0], 0); + let (d1, borrow) = sbb(MODULUS.0[1], self.0[1], borrow); + let (d2, borrow) = sbb(MODULUS.0[2], self.0[2], borrow); + let (d3, borrow) = sbb(MODULUS.0[3], self.0[3], borrow); + let (d4, _) = sbb(MODULUS.0[4], self.0[4], borrow); + + // `tmp` could be `MODULUS` if `self` was zero. Create a mask that is + // zero if `self` was zero, and `u64::max_value()` if self was nonzero. 
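+    // Example (informal): neg(0) computes q - 0 = q limb-wise, which is not the
+    // canonical encoding of zero; the mask below zeroes it out so that -0 == 0,
+    // while any nonzero input keeps q - self unchanged.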
+ let mask = + (((self.0[0] | self.0[1] | self.0[2] | self.0[3] | self.0[4]) == 0) as u64).wrapping_sub(1); + + Scalar([d0 & mask, d1 & mask, d2 & mask, d3 & mask, d4 & mask]) + } +} + +impl<'a> From<&'a Scalar> for [u8; 32] { + fn from(value: &'a Scalar) -> [u8; 32] { + value.to_bytes() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inv() { + // Compute -(q^{-1} mod 2^64) mod 2^64 by exponentiating + // by totient(2**64) - 1 + + let mut inv = 1u64; + for _ in 0..63 { + inv = inv.wrapping_mul(inv); + inv = inv.wrapping_mul(MODULUS.0[0]); + } + inv = inv.wrapping_neg(); + + assert_eq!(inv, INV); + } + + #[cfg(feature = "std")] + #[test] + fn test_debug() { + assert_eq!( + format!("{:?}", Scalar::zero()), + "0x0000000000000000000000000000000000000000000000000000000000000000" + ); + assert_eq!( + format!("{:?}", Scalar::one()), + "0x0000000000000000000000000000000000000000000000000000000000000001" + ); + assert_eq!( + format!("{:?}", R2), + "0x1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe" + ); + } + + #[test] + fn test_equality() { + assert_eq!(Scalar::zero(), Scalar::zero()); + assert_eq!(Scalar::one(), Scalar::one()); + assert_eq!(R2, R2); + + assert!(Scalar::zero() != Scalar::one()); + assert!(Scalar::one() != R2); + } + + #[test] + fn test_to_bytes() { + assert_eq!( + Scalar::zero().to_bytes(), + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ] + ); + + assert_eq!( + Scalar::one().to_bytes(), + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ] + ); + + /* + assert_eq!( + R2.to_bytes(), + [ + 29, 149, 152, 141, 116, 49, 236, 214, 112, 207, 125, 115, 244, 91, 239, 198, 254, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 15 + ] + ); + + assert_eq!( + (-&Scalar::one()).to_bytes(), + [ + 236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 16 + ] + ); + */ + } + + #[test] + fn test_from_bytes() { + assert_eq!( + Scalar::from_bytes(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ]) + .unwrap(), + Scalar::zero() + ); + + assert_eq!( + Scalar::from_bytes(&[ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ]) + .unwrap(), + Scalar::one() + ); + + assert_eq!( + Scalar::from_bytes(&[ + 209, 3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ]) + .unwrap(), + R2 + ); + + /* + // -1 should work + assert!( + Scalar::from_bytes(&[ + 236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 16 + ]) + .is_some() + .unwrap_u8() + == 1 + ); + + // modulus is invalid + assert!( + Scalar::from_bytes(&[ + 1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216, + 57, 51, 72, 125, 157, 41, 83, 167, 237, 115 + ]) + .is_none() + .unwrap_u8() + == 1 + ); + + // Anything larger than the modulus is invalid + assert!( + Scalar::from_bytes(&[ + 2, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216, + 57, 51, 72, 125, 157, 41, 83, 167, 237, 115 + ]) + .is_none() + .unwrap_u8() + == 1 + ); + assert!( + Scalar::from_bytes(&[ + 1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216, + 58, 51, 72, 125, 157, 41, 83, 167, 237, 115 + ]) + .is_none() + 
.unwrap_u8() + == 1 + ); + assert!( + Scalar::from_bytes(&[ + 1, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216, + 57, 51, 72, 125, 157, 41, 83, 167, 237, 116 + ]) + .is_none() + .unwrap_u8() + == 1 + ); + */ + } + + #[test] + fn test_from_u512_zero() { + assert_eq!( + Scalar::zero(), + Scalar::from_u512([ + MODULUS.0[0], + MODULUS.0[1], + MODULUS.0[2], + MODULUS.0[3], + 0, + 0, + 0, + 0 + ]) + ); + } + + #[test] + fn test_from_u512_r() { + assert_eq!(R, Scalar::from_u512([1, 0, 0, 0, 0, 0, 0, 0])); + } + + #[test] + fn test_from_u512_r2() { + assert_eq!(R2, Scalar::from_u512([0, 0, 0, 0, 1, 0, 0, 0])); + } + + #[test] + fn test_from_u512_max() { + let max_u64 = 0xffffffffffffffff; + assert_eq!( + R3 - R, + Scalar::from_u512([max_u64, max_u64, max_u64, max_u64, max_u64, max_u64, max_u64, max_u64]) + ); + } + + #[test] + fn test_from_bytes_wide_r2() { + assert_eq!( + R2, + Scalar::from_bytes_wide(&[ + 209, 3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]) + ); + } + + #[test] + fn test_from_bytes_wide_negative_one() { + assert_eq!( + -&Scalar::one(), + Scalar::from_bytes_wide(&[ + 46, 252, 255, 255, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]) + ); + } + + #[test] + fn test_from_bytes_wide_maximum() { + assert_eq!( + Scalar::from_raw([0x000007a2000e90a0, 0x1, 0, 0]), + Scalar::from_bytes_wide(&[0xff; 64]) + ); + } + + #[test] + fn test_zero() { + assert_eq!(Scalar::zero(), -&Scalar::zero()); + assert_eq!(Scalar::zero(), Scalar::zero() + Scalar::zero()); + assert_eq!(Scalar::zero(), Scalar::zero() - Scalar::zero()); + assert_eq!(Scalar::zero(), Scalar::zero() * Scalar::zero()); + } + + const LARGEST: Scalar = Scalar([ + 0xfffffffefffffc2e, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0, + ]); + + #[test] + fn test_addition() { + let mut tmp = LARGEST; + tmp += &LARGEST; + + let target = Scalar([ + 0xfffffffefffffc2d, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0, + ]); + + assert_eq!(tmp, target); + + let mut tmp = LARGEST; + tmp += &Scalar([1, 0, 0, 0, 0]); + + assert_eq!(tmp, Scalar::zero()); + } + + #[test] + fn test_negation() { + let tmp = -&LARGEST; + + assert_eq!(tmp, Scalar([1, 0, 0, 0, 0])); + + let tmp = -&Scalar::zero(); + assert_eq!(tmp, Scalar::zero()); + let tmp = -&Scalar([1, 0, 0, 0, 0]); + assert_eq!(tmp, LARGEST); + } + + #[test] + fn test_subtraction() { + let mut tmp = LARGEST; + tmp -= &LARGEST; + + assert_eq!(tmp, Scalar::zero()); + + let mut tmp = Scalar::zero(); + tmp -= &LARGEST; + + let mut tmp2 = MODULUS; + tmp2 -= &LARGEST; + + assert_eq!(tmp, tmp2); + } + + #[test] + fn test_multiplication() { + let mut cur = LARGEST; + + for _ in 0..100 { + let mut tmp = cur; + tmp *= &cur; + + let mut tmp2 = Scalar::zero(); + for b in cur + .to_bytes() + .iter() + .rev() + .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8)) + { + let tmp3 = tmp2; + tmp2.add_assign(&tmp3); + + if b { + tmp2.add_assign(&cur); + } + } + + assert_eq!(tmp, tmp2); + + cur.add_assign(&LARGEST); + } + } + + #[test] + fn test_squaring() { + let mut cur = LARGEST; + + for _ in 0..100 { + let mut tmp = cur; + tmp = tmp.square(); + + let mut tmp2 = 
Scalar::zero(); + for b in cur + .to_bytes() + .iter() + .rev() + .flat_map(|byte| (0..8).rev().map(move |i| ((byte >> i) & 1u8) == 1u8)) + { + let tmp3 = tmp2; + tmp2.add_assign(&tmp3); + + if b { + tmp2.add_assign(&cur); + } + } + + assert_eq!(tmp, tmp2); + + cur.add_assign(&LARGEST); + } + } + + #[test] + fn test_inversion() { + assert_eq!(Scalar::zero().invert().is_none().unwrap_u8(), 1); + assert_eq!(Scalar::one().invert().unwrap(), Scalar::one()); + assert_eq!((-&Scalar::one()).invert().unwrap(), -&Scalar::one()); + + let a = Scalar::from(123); + let result = a.invert().unwrap(); + println!("result {:?}", result); + + let mut tmp = R2; + + for _ in 0..100 { + let mut tmp2 = tmp.invert().unwrap(); + println!("tmp2 {:?}", tmp2); + tmp2.mul_assign(&tmp); + + assert_eq!(tmp2, Scalar::one()); + + tmp.add_assign(&R2); + } + } + + #[test] + fn test_invert_is_pow() { + let q_minus_2 = [ + 0xffff_fffe_ffff_fc2d, + 0xffff_ffff_ffff_ffff, + 0xffff_ffff_ffff_ffff, + 0xffff_ffff_ffff_ffff, + ]; + + let mut r1 = R; + let mut r2 = R; + let mut r3 = R; + + for _ in 0..100 { + r1 = r1.invert().unwrap(); + r2 = r2.pow_vartime(&q_minus_2); + r3 = r3.pow(&q_minus_2); + + assert_eq!(r1, r2); + assert_eq!(r2, r3); + // Add R so we check something different next time around + r1.add_assign(&R); + r2 = r1; + r3 = r1; + } + } + + #[test] + fn test_from_raw() { + assert_eq!( + Scalar::from_raw([ + 0x00000001000003d0, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ]), + Scalar::from_raw([0xffffffffffffffff; 4]) + ); + + assert_eq!( + Scalar::from_raw(MODULUS.0[..4].try_into().unwrap()), + Scalar::zero() + ); + + assert_eq!(Scalar::from_raw([1, 0, 0, 0]), R); + } + + #[test] + fn test_double() { + let a = Scalar::from_raw([ + 0x1fff3231233ffffd, + 0x4884b7fa00034802, + 0x998c4fefecbc4ff3, + 0x1824b159acc50562, + ]); + + assert_eq!(a.double(), a + a); + } +} diff --git a/packages/Spartan-secq/src/sparse_mlpoly.rs b/packages/Spartan-secq/src/sparse_mlpoly.rs new file mode 100644 index 0000000..b2cf6a1 --- /dev/null +++ b/packages/Spartan-secq/src/sparse_mlpoly.rs @@ -0,0 +1,1679 @@ +#![allow(clippy::type_complexity)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::needless_range_loop)] +use super::dense_mlpoly::DensePolynomial; +use super::dense_mlpoly::{ + EqPolynomial, IdentityPolynomial, PolyCommitment, PolyCommitmentGens, PolyEvalProof, +}; +use super::errors::ProofVerifyError; +use super::math::Math; +use super::product_tree::{DotProductCircuit, ProductCircuit, ProductCircuitEvalProofBatched}; +use super::random::RandomTape; +use super::scalar::Scalar; +use super::timer::Timer; +use super::transcript::{AppendToTranscript, ProofTranscript}; +use core::cmp::Ordering; +use merlin::Transcript; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct SparseMatEntry { + row: usize, + col: usize, + val: Scalar, +} + +impl SparseMatEntry { + pub fn new(row: usize, col: usize, val: Scalar) -> Self { + SparseMatEntry { row, col, val } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SparseMatPolynomial { + num_vars_x: usize, + num_vars_y: usize, + M: Vec, +} + +pub struct Derefs { + row_ops_val: Vec, + col_ops_val: Vec, + comb: DensePolynomial, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DerefsCommitment { + comm_ops_val: PolyCommitment, +} + +impl Derefs { + pub fn new(row_ops_val: Vec, col_ops_val: Vec) -> Self { + assert_eq!(row_ops_val.len(), col_ops_val.len()); + + let derefs = { + // combine all polynomials into a 
single polynomial (used below to produce a single commitment) + let comb = DensePolynomial::merge(row_ops_val.iter().chain(col_ops_val.iter())); + + Derefs { + row_ops_val, + col_ops_val, + comb, + } + }; + + derefs + } + + pub fn commit(&self, gens: &PolyCommitmentGens) -> DerefsCommitment { + let (comm_ops_val, _blinds) = self.comb.commit(gens, None); + DerefsCommitment { comm_ops_val } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DerefsEvalProof { + proof_derefs: PolyEvalProof, +} + +impl DerefsEvalProof { + fn protocol_name() -> &'static [u8] { + b"Derefs evaluation proof" + } + + fn prove_single( + joint_poly: &DensePolynomial, + r: &[Scalar], + evals: Vec, + gens: &PolyCommitmentGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> PolyEvalProof { + assert_eq!(joint_poly.get_num_vars(), r.len() + evals.len().log_2()); + + // append the claimed evaluations to transcript + evals.append_to_transcript(b"evals_ops_val", transcript); + + // n-to-1 reduction + let (r_joint, eval_joint) = { + let challenges = + transcript.challenge_vector(b"challenge_combine_n_to_one", evals.len().log_2()); + let mut poly_evals = DensePolynomial::new(evals); + for i in (0..challenges.len()).rev() { + poly_evals.bound_poly_var_bot(&challenges[i]); + } + assert_eq!(poly_evals.len(), 1); + let joint_claim_eval = poly_evals[0]; + let mut r_joint = challenges; + r_joint.extend(r); + + debug_assert_eq!(joint_poly.evaluate(&r_joint), joint_claim_eval); + (r_joint, joint_claim_eval) + }; + // decommit the joint polynomial at r_joint + eval_joint.append_to_transcript(b"joint_claim_eval", transcript); + let (proof_derefs, _comm_derefs_eval) = PolyEvalProof::prove( + joint_poly, + None, + &r_joint, + &eval_joint, + None, + gens, + transcript, + random_tape, + ); + + proof_derefs + } + + // evalues both polynomials at r and produces a joint proof of opening + pub fn prove( + derefs: &Derefs, + eval_row_ops_val_vec: &[Scalar], + eval_col_ops_val_vec: &[Scalar], + r: &[Scalar], + gens: &PolyCommitmentGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> Self { + transcript.append_protocol_name(DerefsEvalProof::protocol_name()); + + let evals = { + let mut evals = eval_row_ops_val_vec.to_owned(); + evals.extend(eval_col_ops_val_vec); + evals.resize(evals.len().next_power_of_two(), Scalar::zero()); + evals + }; + let proof_derefs = + DerefsEvalProof::prove_single(&derefs.comb, r, evals, gens, transcript, random_tape); + + DerefsEvalProof { proof_derefs } + } + + fn verify_single( + proof: &PolyEvalProof, + comm: &PolyCommitment, + r: &[Scalar], + evals: Vec, + gens: &PolyCommitmentGens, + transcript: &mut Transcript, + ) -> Result<(), ProofVerifyError> { + // append the claimed evaluations to transcript + evals.append_to_transcript(b"evals_ops_val", transcript); + + // n-to-1 reduction + let challenges = + transcript.challenge_vector(b"challenge_combine_n_to_one", evals.len().log_2()); + let mut poly_evals = DensePolynomial::new(evals); + for i in (0..challenges.len()).rev() { + poly_evals.bound_poly_var_bot(&challenges[i]); + } + assert_eq!(poly_evals.len(), 1); + let joint_claim_eval = poly_evals[0]; + let mut r_joint = challenges; + r_joint.extend(r); + + // decommit the joint polynomial at r_joint + joint_claim_eval.append_to_transcript(b"joint_claim_eval", transcript); + + proof.verify_plain(gens, transcript, &r_joint, &joint_claim_eval, comm) + } + + // verify evaluations of both polynomials at r + pub fn verify( + &self, + r: &[Scalar], + 
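+    // Informal sketch of the n-to-1 reduction used by prove_single/verify_single
+    // above: the n claimed evaluations are laid out as a log(n)-variate multilinear
+    // polynomial, bound variable-by-variable at fresh transcript challenges c, and
+    // the single surviving value is checked with one PolyEvalProof at the extended
+    // point (c, r). This trades n separate openings for one opening plus log(n)
+    // challenges.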
eval_row_ops_val_vec: &[Scalar], + eval_col_ops_val_vec: &[Scalar], + gens: &PolyCommitmentGens, + comm: &DerefsCommitment, + transcript: &mut Transcript, + ) -> Result<(), ProofVerifyError> { + transcript.append_protocol_name(DerefsEvalProof::protocol_name()); + let mut evals = eval_row_ops_val_vec.to_owned(); + evals.extend(eval_col_ops_val_vec); + evals.resize(evals.len().next_power_of_two(), Scalar::zero()); + + DerefsEvalProof::verify_single( + &self.proof_derefs, + &comm.comm_ops_val, + r, + evals, + gens, + transcript, + ) + } +} + +impl AppendToTranscript for DerefsCommitment { + fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) { + transcript.append_message(b"derefs_commitment", b"begin_derefs_commitment"); + self.comm_ops_val.append_to_transcript(label, transcript); + transcript.append_message(b"derefs_commitment", b"end_derefs_commitment"); + } +} + +struct AddrTimestamps { + ops_addr_usize: Vec>, + ops_addr: Vec, + read_ts: Vec, + audit_ts: DensePolynomial, +} + +impl AddrTimestamps { + pub fn new(num_cells: usize, num_ops: usize, ops_addr: Vec>) -> Self { + for item in ops_addr.iter() { + assert_eq!(item.len(), num_ops); + } + + let mut audit_ts = vec![0usize; num_cells]; + let mut ops_addr_vec: Vec = Vec::new(); + let mut read_ts_vec: Vec = Vec::new(); + for ops_addr_inst in ops_addr.iter() { + let mut read_ts = vec![0usize; num_ops]; + + // since read timestamps are trustworthy, we can simply increment the r-ts to obtain a w-ts + // this is sufficient to ensure that the write-set, consisting of (addr, val, ts) tuples, is a set + for i in 0..num_ops { + let addr = ops_addr_inst[i]; + assert!(addr < num_cells); + let r_ts = audit_ts[addr]; + read_ts[i] = r_ts; + + let w_ts = r_ts + 1; + audit_ts[addr] = w_ts; + } + + ops_addr_vec.push(DensePolynomial::from_usize(ops_addr_inst)); + read_ts_vec.push(DensePolynomial::from_usize(&read_ts)); + } + + AddrTimestamps { + ops_addr: ops_addr_vec, + ops_addr_usize: ops_addr, + read_ts: read_ts_vec, + audit_ts: DensePolynomial::from_usize(&audit_ts), + } + } + + fn deref_mem(addr: &[usize], mem_val: &[Scalar]) -> DensePolynomial { + DensePolynomial::new( + (0..addr.len()) + .map(|i| { + let a = addr[i]; + mem_val[a] + }) + .collect::>(), + ) + } + + pub fn deref(&self, mem_val: &[Scalar]) -> Vec { + (0..self.ops_addr.len()) + .map(|i| AddrTimestamps::deref_mem(&self.ops_addr_usize[i], mem_val)) + .collect::>() + } +} + +pub struct MultiSparseMatPolynomialAsDense { + batch_size: usize, + val: Vec, + row: AddrTimestamps, + col: AddrTimestamps, + comb_ops: DensePolynomial, + comb_mem: DensePolynomial, +} + +pub struct SparseMatPolyCommitmentGens { + gens_ops: PolyCommitmentGens, + gens_mem: PolyCommitmentGens, + gens_derefs: PolyCommitmentGens, +} + +impl SparseMatPolyCommitmentGens { + pub fn new( + label: &'static [u8], + num_vars_x: usize, + num_vars_y: usize, + num_nz_entries: usize, + batch_size: usize, + ) -> SparseMatPolyCommitmentGens { + let num_vars_ops = + num_nz_entries.next_power_of_two().log_2() + (batch_size * 5).next_power_of_two().log_2(); + let num_vars_mem = if num_vars_x > num_vars_y { + num_vars_x + } else { + num_vars_y + } + 1; + let num_vars_derefs = + num_nz_entries.next_power_of_two().log_2() + (batch_size * 2).next_power_of_two().log_2(); + + let gens_ops = PolyCommitmentGens::new(num_vars_ops, label); + let gens_mem = PolyCommitmentGens::new(num_vars_mem, label); + let gens_derefs = PolyCommitmentGens::new(num_vars_derefs, label); + SparseMatPolyCommitmentGens { + gens_ops, + 
gens_mem, + gens_derefs, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SparseMatPolyCommitment { + batch_size: usize, + num_ops: usize, + num_mem_cells: usize, + comm_comb_ops: PolyCommitment, + comm_comb_mem: PolyCommitment, +} + +impl AppendToTranscript for SparseMatPolyCommitment { + fn append_to_transcript(&self, _label: &'static [u8], transcript: &mut Transcript) { + transcript.append_u64(b"batch_size", self.batch_size as u64); + transcript.append_u64(b"num_ops", self.num_ops as u64); + transcript.append_u64(b"num_mem_cells", self.num_mem_cells as u64); + self + .comm_comb_ops + .append_to_transcript(b"comm_comb_ops", transcript); + self + .comm_comb_mem + .append_to_transcript(b"comm_comb_mem", transcript); + } +} + +impl SparseMatPolynomial { + pub fn new(num_vars_x: usize, num_vars_y: usize, M: Vec) -> Self { + SparseMatPolynomial { + num_vars_x, + num_vars_y, + M, + } + } + + pub fn get_num_nz_entries(&self) -> usize { + self.M.len().next_power_of_two() + } + + fn sparse_to_dense_vecs(&self, N: usize) -> (Vec, Vec, Vec) { + assert!(N >= self.get_num_nz_entries()); + let mut ops_row: Vec = vec![0; N]; + let mut ops_col: Vec = vec![0; N]; + let mut val: Vec = vec![Scalar::zero(); N]; + + for i in 0..self.M.len() { + ops_row[i] = self.M[i].row; + ops_col[i] = self.M[i].col; + val[i] = self.M[i].val; + } + (ops_row, ops_col, val) + } + + fn multi_sparse_to_dense_rep( + sparse_polys: &[&SparseMatPolynomial], + ) -> MultiSparseMatPolynomialAsDense { + assert!(!sparse_polys.is_empty()); + for i in 1..sparse_polys.len() { + assert_eq!(sparse_polys[i].num_vars_x, sparse_polys[0].num_vars_x); + assert_eq!(sparse_polys[i].num_vars_y, sparse_polys[0].num_vars_y); + } + + let N = (0..sparse_polys.len()) + .map(|i| sparse_polys[i].get_num_nz_entries()) + .max() + .unwrap() + .next_power_of_two(); + + let mut ops_row_vec: Vec> = Vec::new(); + let mut ops_col_vec: Vec> = Vec::new(); + let mut val_vec: Vec = Vec::new(); + for poly in sparse_polys { + let (ops_row, ops_col, val) = poly.sparse_to_dense_vecs(N); + ops_row_vec.push(ops_row); + ops_col_vec.push(ops_col); + val_vec.push(DensePolynomial::new(val)); + } + + let any_poly = &sparse_polys[0]; + + let num_mem_cells = if any_poly.num_vars_x > any_poly.num_vars_y { + any_poly.num_vars_x.pow2() + } else { + any_poly.num_vars_y.pow2() + }; + + let row = AddrTimestamps::new(num_mem_cells, N, ops_row_vec); + let col = AddrTimestamps::new(num_mem_cells, N, ops_col_vec); + + // combine polynomials into a single polynomial for commitment purposes + let comb_ops = DensePolynomial::merge( + row + .ops_addr + .iter() + .chain(row.read_ts.iter()) + .chain(col.ops_addr.iter()) + .chain(col.read_ts.iter()) + .chain(val_vec.iter()), + ); + let mut comb_mem = row.audit_ts.clone(); + comb_mem.extend(&col.audit_ts); + + MultiSparseMatPolynomialAsDense { + batch_size: sparse_polys.len(), + row, + col, + val: val_vec, + comb_ops, + comb_mem, + } + } + + fn evaluate_with_tables(&self, eval_table_rx: &[Scalar], eval_table_ry: &[Scalar]) -> Scalar { + assert_eq!(self.num_vars_x.pow2(), eval_table_rx.len()); + assert_eq!(self.num_vars_y.pow2(), eval_table_ry.len()); + + (0..self.M.len()) + .map(|i| { + let row = self.M[i].row; + let col = self.M[i].col; + let val = &self.M[i].val; + eval_table_rx[row] * eval_table_ry[col] * val + }) + .sum() + } + + pub fn multi_evaluate( + polys: &[&SparseMatPolynomial], + rx: &[Scalar], + ry: &[Scalar], + ) -> Vec { + let eval_table_rx = EqPolynomial::new(rx.to_vec()).evals(); + let eval_table_ry = 
EqPolynomial::new(ry.to_vec()).evals(); + + (0..polys.len()) + .map(|i| polys[i].evaluate_with_tables(&eval_table_rx, &eval_table_ry)) + .collect::>() + } + + pub fn multiply_vec(&self, num_rows: usize, num_cols: usize, z: &[Scalar]) -> Vec { + assert_eq!(z.len(), num_cols); + + (0..self.M.len()) + .map(|i| { + let row = self.M[i].row; + let col = self.M[i].col; + let val = &self.M[i].val; + (row, val * z[col]) + }) + .fold(vec![Scalar::zero(); num_rows], |mut Mz, (r, v)| { + Mz[r] += v; + Mz + }) + } + + pub fn compute_eval_table_sparse( + &self, + rx: &[Scalar], + num_rows: usize, + num_cols: usize, + ) -> Vec { + assert_eq!(rx.len(), num_rows); + + let mut M_evals: Vec = vec![Scalar::zero(); num_cols]; + + for i in 0..self.M.len() { + let entry = &self.M[i]; + M_evals[entry.col] += rx[entry.row] * entry.val; + } + M_evals + } + + pub fn multi_commit( + sparse_polys: &[&SparseMatPolynomial], + gens: &SparseMatPolyCommitmentGens, + ) -> (SparseMatPolyCommitment, MultiSparseMatPolynomialAsDense) { + let batch_size = sparse_polys.len(); + let dense = SparseMatPolynomial::multi_sparse_to_dense_rep(sparse_polys); + + let (comm_comb_ops, _blinds_comb_ops) = dense.comb_ops.commit(&gens.gens_ops, None); + let (comm_comb_mem, _blinds_comb_mem) = dense.comb_mem.commit(&gens.gens_mem, None); + + ( + SparseMatPolyCommitment { + batch_size, + num_mem_cells: dense.row.audit_ts.len(), + num_ops: dense.row.read_ts[0].len(), + comm_comb_ops, + comm_comb_mem, + }, + dense, + ) + } +} + +impl MultiSparseMatPolynomialAsDense { + pub fn deref(&self, row_mem_val: &[Scalar], col_mem_val: &[Scalar]) -> Derefs { + let row_ops_val = self.row.deref(row_mem_val); + let col_ops_val = self.col.deref(col_mem_val); + + Derefs::new(row_ops_val, col_ops_val) + } +} + +#[derive(Debug)] +struct ProductLayer { + init: ProductCircuit, + read_vec: Vec, + write_vec: Vec, + audit: ProductCircuit, +} + +#[derive(Debug)] +struct Layers { + prod_layer: ProductLayer, +} + +impl Layers { + fn build_hash_layer( + eval_table: &[Scalar], + addrs_vec: &[DensePolynomial], + derefs_vec: &[DensePolynomial], + read_ts_vec: &[DensePolynomial], + audit_ts: &DensePolynomial, + r_mem_check: &(Scalar, Scalar), + ) -> ( + DensePolynomial, + Vec, + Vec, + DensePolynomial, + ) { + let (r_hash, r_multiset_check) = r_mem_check; + + //hash(addr, val, ts) = ts * r_hash_sqr + val * r_hash + addr + let r_hash_sqr = r_hash * r_hash; + let hash_func = |addr: &Scalar, val: &Scalar, ts: &Scalar| -> Scalar { + ts * r_hash_sqr + val * r_hash + addr + }; + + // hash init and audit that does not depend on #instances + let num_mem_cells = eval_table.len(); + let poly_init_hashed = DensePolynomial::new( + (0..num_mem_cells) + .map(|i| { + // at init time, addr is given by i, init value is given by eval_table, and ts = 0 + hash_func(&Scalar::from(i as u64), &eval_table[i], &Scalar::zero()) - r_multiset_check + }) + .collect::>(), + ); + let poly_audit_hashed = DensePolynomial::new( + (0..num_mem_cells) + .map(|i| { + // at audit time, addr is given by i, value is given by eval_table, and ts is given by audit_ts + hash_func(&Scalar::from(i as u64), &eval_table[i], &audit_ts[i]) - r_multiset_check + }) + .collect::>(), + ); + + // hash read and write that depends on #instances + let mut poly_read_hashed_vec: Vec = Vec::new(); + let mut poly_write_hashed_vec: Vec = Vec::new(); + for i in 0..addrs_vec.len() { + let (addrs, derefs, read_ts) = (&addrs_vec[i], &derefs_vec[i], &read_ts_vec[i]); + assert_eq!(addrs.len(), derefs.len()); + assert_eq!(addrs.len(), 
read_ts.len()); + let num_ops = addrs.len(); + let poly_read_hashed = DensePolynomial::new( + (0..num_ops) + .map(|i| { + // at read time, addr is given by addrs, value is given by derefs, and ts is given by read_ts + hash_func(&addrs[i], &derefs[i], &read_ts[i]) - r_multiset_check + }) + .collect::>(), + ); + poly_read_hashed_vec.push(poly_read_hashed); + + let poly_write_hashed = DensePolynomial::new( + (0..num_ops) + .map(|i| { + // at write time, addr is given by addrs, value is given by derefs, and ts is given by write_ts = read_ts + 1 + hash_func(&addrs[i], &derefs[i], &(read_ts[i] + Scalar::one())) - r_multiset_check + }) + .collect::>(), + ); + poly_write_hashed_vec.push(poly_write_hashed); + } + + ( + poly_init_hashed, + poly_read_hashed_vec, + poly_write_hashed_vec, + poly_audit_hashed, + ) + } + + pub fn new( + eval_table: &[Scalar], + addr_timestamps: &AddrTimestamps, + poly_ops_val: &[DensePolynomial], + r_mem_check: &(Scalar, Scalar), + ) -> Self { + let (poly_init_hashed, poly_read_hashed_vec, poly_write_hashed_vec, poly_audit_hashed) = + Layers::build_hash_layer( + eval_table, + &addr_timestamps.ops_addr, + poly_ops_val, + &addr_timestamps.read_ts, + &addr_timestamps.audit_ts, + r_mem_check, + ); + + let prod_init = ProductCircuit::new(&poly_init_hashed); + let prod_read_vec = (0..poly_read_hashed_vec.len()) + .map(|i| ProductCircuit::new(&poly_read_hashed_vec[i])) + .collect::>(); + let prod_write_vec = (0..poly_write_hashed_vec.len()) + .map(|i| ProductCircuit::new(&poly_write_hashed_vec[i])) + .collect::>(); + let prod_audit = ProductCircuit::new(&poly_audit_hashed); + + // subset audit check + let hashed_writes: Scalar = (0..prod_write_vec.len()) + .map(|i| prod_write_vec[i].evaluate()) + .product(); + let hashed_write_set: Scalar = prod_init.evaluate() * hashed_writes; + + let hashed_reads: Scalar = (0..prod_read_vec.len()) + .map(|i| prod_read_vec[i].evaluate()) + .product(); + let hashed_read_set: Scalar = hashed_reads * prod_audit.evaluate(); + + //assert_eq!(hashed_read_set, hashed_write_set); + debug_assert_eq!(hashed_read_set, hashed_write_set); + + Layers { + prod_layer: ProductLayer { + init: prod_init, + read_vec: prod_read_vec, + write_vec: prod_write_vec, + audit: prod_audit, + }, + } + } +} + +#[derive(Debug)] +struct PolyEvalNetwork { + row_layers: Layers, + col_layers: Layers, +} + +impl PolyEvalNetwork { + pub fn new( + dense: &MultiSparseMatPolynomialAsDense, + derefs: &Derefs, + mem_rx: &[Scalar], + mem_ry: &[Scalar], + r_mem_check: &(Scalar, Scalar), + ) -> Self { + let row_layers = Layers::new(mem_rx, &dense.row, &derefs.row_ops_val, r_mem_check); + let col_layers = Layers::new(mem_ry, &dense.col, &derefs.col_ops_val, r_mem_check); + + PolyEvalNetwork { + row_layers, + col_layers, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct HashLayerProof { + eval_row: (Vec, Vec, Scalar), + eval_col: (Vec, Vec, Scalar), + eval_val: Vec, + eval_derefs: (Vec, Vec), + proof_ops: PolyEvalProof, + proof_mem: PolyEvalProof, + proof_derefs: DerefsEvalProof, +} + +impl HashLayerProof { + fn protocol_name() -> &'static [u8] { + b"Sparse polynomial hash layer proof" + } + + fn prove_helper( + rand: (&Vec, &Vec), + addr_timestamps: &AddrTimestamps, + ) -> (Vec, Vec, Scalar) { + let (rand_mem, rand_ops) = rand; + + // decommit ops-addr at rand_ops + let mut eval_ops_addr_vec: Vec = Vec::new(); + for i in 0..addr_timestamps.ops_addr.len() { + let eval_ops_addr = addr_timestamps.ops_addr[i].evaluate(rand_ops); + eval_ops_addr_vec.push(eval_ops_addr); + } + + // 
decommit read_ts at rand_ops + let mut eval_read_ts_vec: Vec = Vec::new(); + for i in 0..addr_timestamps.read_ts.len() { + let eval_read_ts = addr_timestamps.read_ts[i].evaluate(rand_ops); + eval_read_ts_vec.push(eval_read_ts); + } + + // decommit audit-ts at rand_mem + let eval_audit_ts = addr_timestamps.audit_ts.evaluate(rand_mem); + + (eval_ops_addr_vec, eval_read_ts_vec, eval_audit_ts) + } + + fn prove( + rand: (&Vec, &Vec), + dense: &MultiSparseMatPolynomialAsDense, + derefs: &Derefs, + gens: &SparseMatPolyCommitmentGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> Self { + transcript.append_protocol_name(HashLayerProof::protocol_name()); + + let (rand_mem, rand_ops) = rand; + + // decommit derefs at rand_ops + let eval_row_ops_val = (0..derefs.row_ops_val.len()) + .map(|i| derefs.row_ops_val[i].evaluate(rand_ops)) + .collect::>(); + let eval_col_ops_val = (0..derefs.col_ops_val.len()) + .map(|i| derefs.col_ops_val[i].evaluate(rand_ops)) + .collect::>(); + let proof_derefs = DerefsEvalProof::prove( + derefs, + &eval_row_ops_val, + &eval_col_ops_val, + rand_ops, + &gens.gens_derefs, + transcript, + random_tape, + ); + let eval_derefs = (eval_row_ops_val, eval_col_ops_val); + + // evaluate row_addr, row_read-ts, col_addr, col_read-ts, val at rand_ops + // evaluate row_audit_ts and col_audit_ts at rand_mem + let (eval_row_addr_vec, eval_row_read_ts_vec, eval_row_audit_ts) = + HashLayerProof::prove_helper((rand_mem, rand_ops), &dense.row); + let (eval_col_addr_vec, eval_col_read_ts_vec, eval_col_audit_ts) = + HashLayerProof::prove_helper((rand_mem, rand_ops), &dense.col); + let eval_val_vec = (0..dense.val.len()) + .map(|i| dense.val[i].evaluate(rand_ops)) + .collect::>(); + + // form a single decommitment using comm_comb_ops + let mut evals_ops: Vec = Vec::new(); + evals_ops.extend(&eval_row_addr_vec); + evals_ops.extend(&eval_row_read_ts_vec); + evals_ops.extend(&eval_col_addr_vec); + evals_ops.extend(&eval_col_read_ts_vec); + evals_ops.extend(&eval_val_vec); + evals_ops.resize(evals_ops.len().next_power_of_two(), Scalar::zero()); + evals_ops.append_to_transcript(b"claim_evals_ops", transcript); + let challenges_ops = + transcript.challenge_vector(b"challenge_combine_n_to_one", evals_ops.len().log_2()); + + let mut poly_evals_ops = DensePolynomial::new(evals_ops); + for i in (0..challenges_ops.len()).rev() { + poly_evals_ops.bound_poly_var_bot(&challenges_ops[i]); + } + assert_eq!(poly_evals_ops.len(), 1); + let joint_claim_eval_ops = poly_evals_ops[0]; + let mut r_joint_ops = challenges_ops; + r_joint_ops.extend(rand_ops); + debug_assert_eq!(dense.comb_ops.evaluate(&r_joint_ops), joint_claim_eval_ops); + joint_claim_eval_ops.append_to_transcript(b"joint_claim_eval_ops", transcript); + let (proof_ops, _comm_ops_eval) = PolyEvalProof::prove( + &dense.comb_ops, + None, + &r_joint_ops, + &joint_claim_eval_ops, + None, + &gens.gens_ops, + transcript, + random_tape, + ); + + // form a single decommitment using comb_comb_mem at rand_mem + let evals_mem: Vec = vec![eval_row_audit_ts, eval_col_audit_ts]; + evals_mem.append_to_transcript(b"claim_evals_mem", transcript); + let challenges_mem = + transcript.challenge_vector(b"challenge_combine_two_to_one", evals_mem.len().log_2()); + + let mut poly_evals_mem = DensePolynomial::new(evals_mem); + for i in (0..challenges_mem.len()).rev() { + poly_evals_mem.bound_poly_var_bot(&challenges_mem[i]); + } + assert_eq!(poly_evals_mem.len(), 1); + let joint_claim_eval_mem = poly_evals_mem[0]; + let mut r_joint_mem = challenges_mem; 
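+    // the joint opening point is (challenges_mem || rand_mem); the extra challenge
+    // coordinate selects between the row and col audit timestamps inside comb_mem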
+    r_joint_mem.extend(rand_mem);
+    debug_assert_eq!(dense.comb_mem.evaluate(&r_joint_mem), joint_claim_eval_mem);
+    joint_claim_eval_mem.append_to_transcript(b"joint_claim_eval_mem", transcript);
+    let (proof_mem, _comm_mem_eval) = PolyEvalProof::prove(
+      &dense.comb_mem,
+      None,
+      &r_joint_mem,
+      &joint_claim_eval_mem,
+      None,
+      &gens.gens_mem,
+      transcript,
+      random_tape,
+    );
+
+    HashLayerProof {
+      eval_row: (eval_row_addr_vec, eval_row_read_ts_vec, eval_row_audit_ts),
+      eval_col: (eval_col_addr_vec, eval_col_read_ts_vec, eval_col_audit_ts),
+      eval_val: eval_val_vec,
+      eval_derefs,
+      proof_ops,
+      proof_mem,
+      proof_derefs,
+    }
+  }
+
+  fn verify_helper(
+    rand: &(&Vec<Scalar>, &Vec<Scalar>),
+    claims: &(Scalar, Vec<Scalar>, Vec<Scalar>, Scalar),
+    eval_ops_val: &[Scalar],
+    eval_ops_addr: &[Scalar],
+    eval_read_ts: &[Scalar],
+    eval_audit_ts: &Scalar,
+    r: &[Scalar],
+    r_hash: &Scalar,
+    r_multiset_check: &Scalar,
+  ) -> Result<(), ProofVerifyError> {
+    let r_hash_sqr = r_hash * r_hash;
+    let hash_func = |addr: &Scalar, val: &Scalar, ts: &Scalar| -> Scalar {
+      ts * r_hash_sqr + val * r_hash + addr
+    };
+
+    let (rand_mem, _rand_ops) = rand;
+    let (claim_init, claim_read, claim_write, claim_audit) = claims;
+
+    // init
+    let eval_init_addr = IdentityPolynomial::new(rand_mem.len()).evaluate(rand_mem);
+    let eval_init_val = EqPolynomial::new(r.to_vec()).evaluate(rand_mem);
+    let hash_init_at_rand_mem =
+      hash_func(&eval_init_addr, &eval_init_val, &Scalar::zero()) - r_multiset_check; // verify the claim_last of init chunk
+    assert_eq!(&hash_init_at_rand_mem, claim_init);
+
+    // read
+    for i in 0..eval_ops_addr.len() {
+      let hash_read_at_rand_ops =
+        hash_func(&eval_ops_addr[i], &eval_ops_val[i], &eval_read_ts[i]) - r_multiset_check; // verify the claim_last of read chunk
+      assert_eq!(&hash_read_at_rand_ops, &claim_read[i]);
+    }
+
+    // write: shares addr, val component; only decommit write_ts
+    for i in 0..eval_ops_addr.len() {
+      let eval_write_ts = eval_read_ts[i] + Scalar::one();
+      let hash_write_at_rand_ops =
+        hash_func(&eval_ops_addr[i], &eval_ops_val[i], &eval_write_ts) - r_multiset_check; // verify the claim_last of write chunk
+      assert_eq!(&hash_write_at_rand_ops, &claim_write[i]);
+    }
+
+    // audit: shares addr and val with init
+    let eval_audit_addr = eval_init_addr;
+    let eval_audit_val = eval_init_val;
+    let hash_audit_at_rand_mem =
+      hash_func(&eval_audit_addr, &eval_audit_val, eval_audit_ts) - r_multiset_check;
+    assert_eq!(&hash_audit_at_rand_mem, claim_audit); // verify the last step of the sum-check for audit
+
+    Ok(())
+  }
+
+  fn verify(
+    &self,
+    rand: (&Vec<Scalar>, &Vec<Scalar>),
+    claims_row: &(Scalar, Vec<Scalar>, Vec<Scalar>, Scalar),
+    claims_col: &(Scalar, Vec<Scalar>, Vec<Scalar>, Scalar),
+    claims_dotp: &[Scalar],
+    comm: &SparseMatPolyCommitment,
+    gens: &SparseMatPolyCommitmentGens,
+    comm_derefs: &DerefsCommitment,
+    rx: &[Scalar],
+    ry: &[Scalar],
+    r_hash: &Scalar,
+    r_multiset_check: &Scalar,
+    transcript: &mut Transcript,
+  ) -> Result<(), ProofVerifyError> {
+    let timer = Timer::new("verify_hash_proof");
+    transcript.append_protocol_name(HashLayerProof::protocol_name());
+
+    let (rand_mem, rand_ops) = rand;
+
+    // verify derefs at rand_ops
+    let (eval_row_ops_val, eval_col_ops_val) = &self.eval_derefs;
+    assert_eq!(eval_row_ops_val.len(), eval_col_ops_val.len());
+    self.proof_derefs.verify(
+      rand_ops,
+      eval_row_ops_val,
+      eval_col_ops_val,
+      &gens.gens_derefs,
+      comm_derefs,
+      transcript,
+    )?;
+
+    // verify the decommitments used in evaluation sum-check
+    let eval_val_vec = &self.eval_val;
+    assert_eq!(claims_dotp.len(), 3 *
eval_row_ops_val.len()); + for i in 0..claims_dotp.len() / 3 { + let claim_row_ops_val = claims_dotp[3 * i]; + let claim_col_ops_val = claims_dotp[3 * i + 1]; + let claim_val = claims_dotp[3 * i + 2]; + + assert_eq!(claim_row_ops_val, eval_row_ops_val[i]); + assert_eq!(claim_col_ops_val, eval_col_ops_val[i]); + assert_eq!(claim_val, eval_val_vec[i]); + } + + // verify addr-timestamps using comm_comb_ops at rand_ops + let (eval_row_addr_vec, eval_row_read_ts_vec, eval_row_audit_ts) = &self.eval_row; + let (eval_col_addr_vec, eval_col_read_ts_vec, eval_col_audit_ts) = &self.eval_col; + + let mut evals_ops: Vec = Vec::new(); + evals_ops.extend(eval_row_addr_vec); + evals_ops.extend(eval_row_read_ts_vec); + evals_ops.extend(eval_col_addr_vec); + evals_ops.extend(eval_col_read_ts_vec); + evals_ops.extend(eval_val_vec); + evals_ops.resize(evals_ops.len().next_power_of_two(), Scalar::zero()); + evals_ops.append_to_transcript(b"claim_evals_ops", transcript); + let challenges_ops = + transcript.challenge_vector(b"challenge_combine_n_to_one", evals_ops.len().log_2()); + + let mut poly_evals_ops = DensePolynomial::new(evals_ops); + for i in (0..challenges_ops.len()).rev() { + poly_evals_ops.bound_poly_var_bot(&challenges_ops[i]); + } + assert_eq!(poly_evals_ops.len(), 1); + let joint_claim_eval_ops = poly_evals_ops[0]; + let mut r_joint_ops = challenges_ops; + r_joint_ops.extend(rand_ops); + joint_claim_eval_ops.append_to_transcript(b"joint_claim_eval_ops", transcript); + self.proof_ops.verify_plain( + &gens.gens_ops, + transcript, + &r_joint_ops, + &joint_claim_eval_ops, + &comm.comm_comb_ops, + )?; + + // verify proof-mem using comm_comb_mem at rand_mem + // form a single decommitment using comb_comb_mem at rand_mem + let evals_mem: Vec = vec![*eval_row_audit_ts, *eval_col_audit_ts]; + evals_mem.append_to_transcript(b"claim_evals_mem", transcript); + let challenges_mem = + transcript.challenge_vector(b"challenge_combine_two_to_one", evals_mem.len().log_2()); + + let mut poly_evals_mem = DensePolynomial::new(evals_mem); + for i in (0..challenges_mem.len()).rev() { + poly_evals_mem.bound_poly_var_bot(&challenges_mem[i]); + } + assert_eq!(poly_evals_mem.len(), 1); + let joint_claim_eval_mem = poly_evals_mem[0]; + let mut r_joint_mem = challenges_mem; + r_joint_mem.extend(rand_mem); + joint_claim_eval_mem.append_to_transcript(b"joint_claim_eval_mem", transcript); + self.proof_mem.verify_plain( + &gens.gens_mem, + transcript, + &r_joint_mem, + &joint_claim_eval_mem, + &comm.comm_comb_mem, + )?; + + // verify the claims from the product layer + let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_row; + HashLayerProof::verify_helper( + &(rand_mem, rand_ops), + claims_row, + eval_row_ops_val, + eval_ops_addr, + eval_read_ts, + eval_audit_ts, + rx, + r_hash, + r_multiset_check, + )?; + + let (eval_ops_addr, eval_read_ts, eval_audit_ts) = &self.eval_col; + HashLayerProof::verify_helper( + &(rand_mem, rand_ops), + claims_col, + eval_col_ops_val, + eval_ops_addr, + eval_read_ts, + eval_audit_ts, + ry, + r_hash, + r_multiset_check, + )?; + + timer.stop(); + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct ProductLayerProof { + eval_row: (Scalar, Vec, Vec, Scalar), + eval_col: (Scalar, Vec, Vec, Scalar), + eval_val: (Vec, Vec), + proof_mem: ProductCircuitEvalProofBatched, + proof_ops: ProductCircuitEvalProofBatched, +} + +impl ProductLayerProof { + fn protocol_name() -> &'static [u8] { + b"Sparse polynomial product layer proof" + } + + pub fn prove( + row_prod_layer: &mut ProductLayer, 
+    col_prod_layer: &mut ProductLayer,
+    dense: &MultiSparseMatPolynomialAsDense,
+    derefs: &Derefs,
+    eval: &[Scalar],
+    transcript: &mut Transcript,
+  ) -> (Self, Vec<Scalar>, Vec<Scalar>) {
+    transcript.append_protocol_name(ProductLayerProof::protocol_name());
+
+    let row_eval_init = row_prod_layer.init.evaluate();
+    let row_eval_audit = row_prod_layer.audit.evaluate();
+    let row_eval_read = (0..row_prod_layer.read_vec.len())
+      .map(|i| row_prod_layer.read_vec[i].evaluate())
+      .collect::<Vec<Scalar>>();
+    let row_eval_write = (0..row_prod_layer.write_vec.len())
+      .map(|i| row_prod_layer.write_vec[i].evaluate())
+      .collect::<Vec<Scalar>>();
+
+    // subset check
+    let ws: Scalar = (0..row_eval_write.len())
+      .map(|i| row_eval_write[i])
+      .product();
+    let rs: Scalar = (0..row_eval_read.len()).map(|i| row_eval_read[i]).product();
+    assert_eq!(row_eval_init * ws, rs * row_eval_audit);
+
+    row_eval_init.append_to_transcript(b"claim_row_eval_init", transcript);
+    row_eval_read.append_to_transcript(b"claim_row_eval_read", transcript);
+    row_eval_write.append_to_transcript(b"claim_row_eval_write", transcript);
+    row_eval_audit.append_to_transcript(b"claim_row_eval_audit", transcript);
+
+    let col_eval_init = col_prod_layer.init.evaluate();
+    let col_eval_audit = col_prod_layer.audit.evaluate();
+    let col_eval_read: Vec<Scalar> = (0..col_prod_layer.read_vec.len())
+      .map(|i| col_prod_layer.read_vec[i].evaluate())
+      .collect();
+    let col_eval_write: Vec<Scalar> = (0..col_prod_layer.write_vec.len())
+      .map(|i| col_prod_layer.write_vec[i].evaluate())
+      .collect();
+
+    // subset check
+    let ws: Scalar = (0..col_eval_write.len())
+      .map(|i| col_eval_write[i])
+      .product();
+    let rs: Scalar = (0..col_eval_read.len()).map(|i| col_eval_read[i]).product();
+    assert_eq!(col_eval_init * ws, rs * col_eval_audit);
+
+    col_eval_init.append_to_transcript(b"claim_col_eval_init", transcript);
+    col_eval_read.append_to_transcript(b"claim_col_eval_read", transcript);
+    col_eval_write.append_to_transcript(b"claim_col_eval_write", transcript);
+    col_eval_audit.append_to_transcript(b"claim_col_eval_audit", transcript);
+
+    // prepare dot-product circuits for batching them with the ops-related product circuits
+    assert_eq!(eval.len(), derefs.row_ops_val.len());
+    assert_eq!(eval.len(), derefs.col_ops_val.len());
+    assert_eq!(eval.len(), dense.val.len());
+    let mut dotp_circuit_left_vec: Vec<DotProductCircuit> = Vec::new();
+    let mut dotp_circuit_right_vec: Vec<DotProductCircuit> = Vec::new();
+    let mut eval_dotp_left_vec: Vec<Scalar> = Vec::new();
+    let mut eval_dotp_right_vec: Vec<Scalar> = Vec::new();
+    for i in 0..derefs.row_ops_val.len() {
+      // express the sparse polynomial evaluation as two dot-product checks
+      let left = derefs.row_ops_val[i].clone();
+      let right = derefs.col_ops_val[i].clone();
+      let weights = dense.val[i].clone();
+
+      // build two dot product circuits to prove evaluation of sparse polynomial
+      let mut dotp_circuit = DotProductCircuit::new(left, right, weights);
+      let (dotp_circuit_left, dotp_circuit_right) = dotp_circuit.split();
+
+      let (eval_dotp_left, eval_dotp_right) =
+        (dotp_circuit_left.evaluate(), dotp_circuit_right.evaluate());
+
+      eval_dotp_left.append_to_transcript(b"claim_eval_dotp_left", transcript);
+      eval_dotp_right.append_to_transcript(b"claim_eval_dotp_right", transcript);
+      assert_eq!(eval_dotp_left + eval_dotp_right, eval[i]);
+      eval_dotp_left_vec.push(eval_dotp_left);
+      eval_dotp_right_vec.push(eval_dotp_right);
+
+      dotp_circuit_left_vec.push(dotp_circuit_left);
+      dotp_circuit_right_vec.push(dotp_circuit_right);
+    }
+
+    // The number of operations into the memory encoded by rx and ry is always
the same (by design) + // So we can produce a batched product proof for all of them at the same time. + // prove the correctness of claim_row_eval_read, claim_row_eval_write, claim_col_eval_read, and claim_col_eval_write + // TODO: we currently only produce proofs for 3 batched sparse polynomial evaluations + assert_eq!(row_prod_layer.read_vec.len(), 3); + let (row_read_A, row_read_B, row_read_C) = { + let (vec_A, vec_BC) = row_prod_layer.read_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (row_write_A, row_write_B, row_write_C) = { + let (vec_A, vec_BC) = row_prod_layer.write_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (col_read_A, col_read_B, col_read_C) = { + let (vec_A, vec_BC) = col_prod_layer.read_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (col_write_A, col_write_B, col_write_C) = { + let (vec_A, vec_BC) = col_prod_layer.write_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (dotp_left_A, dotp_left_B, dotp_left_C) = { + let (vec_A, vec_BC) = dotp_circuit_left_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (dotp_right_A, dotp_right_B, dotp_right_C) = { + let (vec_A, vec_BC) = dotp_circuit_right_vec.split_at_mut(1); + let (vec_B, vec_C) = vec_BC.split_at_mut(1); + (vec_A, vec_B, vec_C) + }; + + let (proof_ops, rand_ops) = ProductCircuitEvalProofBatched::prove( + &mut vec![ + &mut row_read_A[0], + &mut row_read_B[0], + &mut row_read_C[0], + &mut row_write_A[0], + &mut row_write_B[0], + &mut row_write_C[0], + &mut col_read_A[0], + &mut col_read_B[0], + &mut col_read_C[0], + &mut col_write_A[0], + &mut col_write_B[0], + &mut col_write_C[0], + ], + &mut vec![ + &mut dotp_left_A[0], + &mut dotp_right_A[0], + &mut dotp_left_B[0], + &mut dotp_right_B[0], + &mut dotp_left_C[0], + &mut dotp_right_C[0], + ], + transcript, + ); + + // produce a batched proof of memory-related product circuits + let (proof_mem, rand_mem) = ProductCircuitEvalProofBatched::prove( + &mut vec![ + &mut row_prod_layer.init, + &mut row_prod_layer.audit, + &mut col_prod_layer.init, + &mut col_prod_layer.audit, + ], + &mut Vec::new(), + transcript, + ); + + let product_layer_proof = ProductLayerProof { + eval_row: (row_eval_init, row_eval_read, row_eval_write, row_eval_audit), + eval_col: (col_eval_init, col_eval_read, col_eval_write, col_eval_audit), + eval_val: (eval_dotp_left_vec, eval_dotp_right_vec), + proof_mem, + proof_ops, + }; + + let product_layer_proof_encoded: Vec = bincode::serialize(&product_layer_proof).unwrap(); + let msg = format!( + "len_product_layer_proof {:?}", + product_layer_proof_encoded.len() + ); + Timer::print(&msg); + + (product_layer_proof, rand_mem, rand_ops) + } + + pub fn verify( + &self, + num_ops: usize, + num_cells: usize, + eval: &[Scalar], + transcript: &mut Transcript, + ) -> Result< + ( + Vec, + Vec, + Vec, + Vec, + Vec, + ), + ProofVerifyError, + > { + transcript.append_protocol_name(ProductLayerProof::protocol_name()); + + let timer = Timer::new("verify_prod_proof"); + let num_instances = eval.len(); + + // subset check + let (row_eval_init, row_eval_read, row_eval_write, row_eval_audit) = &self.eval_row; + assert_eq!(row_eval_write.len(), num_instances); + assert_eq!(row_eval_read.len(), num_instances); + let ws: Scalar = (0..row_eval_write.len()) + .map(|i| row_eval_write[i]) + .product(); + 
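// multiset equality: init * prod(writes) must equal prod(reads) * audit (asserted below)
+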
let rs: Scalar = (0..row_eval_read.len()).map(|i| row_eval_read[i]).product();
+    assert_eq!(row_eval_init * ws, rs * row_eval_audit);
+
+    row_eval_init.append_to_transcript(b"claim_row_eval_init", transcript);
+    row_eval_read.append_to_transcript(b"claim_row_eval_read", transcript);
+    row_eval_write.append_to_transcript(b"claim_row_eval_write", transcript);
+    row_eval_audit.append_to_transcript(b"claim_row_eval_audit", transcript);
+
+    // subset check
+    let (col_eval_init, col_eval_read, col_eval_write, col_eval_audit) = &self.eval_col;
+    assert_eq!(col_eval_write.len(), num_instances);
+    assert_eq!(col_eval_read.len(), num_instances);
+    let ws: Scalar = (0..col_eval_write.len())
+      .map(|i| col_eval_write[i])
+      .product();
+    let rs: Scalar = (0..col_eval_read.len()).map(|i| col_eval_read[i]).product();
+    assert_eq!(col_eval_init * ws, rs * col_eval_audit);
+
+    col_eval_init.append_to_transcript(b"claim_col_eval_init", transcript);
+    col_eval_read.append_to_transcript(b"claim_col_eval_read", transcript);
+    col_eval_write.append_to_transcript(b"claim_col_eval_write", transcript);
+    col_eval_audit.append_to_transcript(b"claim_col_eval_audit", transcript);
+
+    // verify the evaluation of the sparse polynomial
+    let (eval_dotp_left, eval_dotp_right) = &self.eval_val;
+    assert_eq!(eval_dotp_left.len(), eval_dotp_right.len());
+    assert_eq!(eval_dotp_left.len(), num_instances);
+    let mut claims_dotp_circuit: Vec<Scalar> = Vec::new();
+    for i in 0..num_instances {
+      assert_eq!(eval_dotp_left[i] + eval_dotp_right[i], eval[i]);
+      eval_dotp_left[i].append_to_transcript(b"claim_eval_dotp_left", transcript);
+      eval_dotp_right[i].append_to_transcript(b"claim_eval_dotp_right", transcript);
+
+      claims_dotp_circuit.push(eval_dotp_left[i]);
+      claims_dotp_circuit.push(eval_dotp_right[i]);
+    }
+
+    // verify the correctness of claim_row_eval_read, claim_row_eval_write, claim_col_eval_read, and claim_col_eval_write
+    let mut claims_prod_circuit: Vec<Scalar> = Vec::new();
+    claims_prod_circuit.extend(row_eval_read);
+    claims_prod_circuit.extend(row_eval_write);
+    claims_prod_circuit.extend(col_eval_read);
+    claims_prod_circuit.extend(col_eval_write);
+
+    let (claims_ops, claims_dotp, rand_ops) = self.proof_ops.verify(
+      &claims_prod_circuit,
+      &claims_dotp_circuit,
+      num_ops,
+      transcript,
+    );
+    // verify the correctness of claim_row_eval_init and claim_row_eval_audit
+    let (claims_mem, _claims_mem_dotp, rand_mem) = self.proof_mem.verify(
+      &[
+        *row_eval_init,
+        *row_eval_audit,
+        *col_eval_init,
+        *col_eval_audit,
+      ],
+      &Vec::new(),
+      num_cells,
+      transcript,
+    );
+    timer.stop();
+
+    Ok((claims_mem, rand_mem, claims_ops, claims_dotp, rand_ops))
+  }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct PolyEvalNetworkProof {
+  proof_prod_layer: ProductLayerProof,
+  proof_hash_layer: HashLayerProof,
+}
+
+impl PolyEvalNetworkProof {
+  fn protocol_name() -> &'static [u8] {
+    b"Sparse polynomial evaluation proof"
+  }
+
+  pub fn prove(
+    network: &mut PolyEvalNetwork,
+    dense: &MultiSparseMatPolynomialAsDense,
+    derefs: &Derefs,
+    evals: &[Scalar],
+    gens: &SparseMatPolyCommitmentGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+  ) -> Self {
+    transcript.append_protocol_name(PolyEvalNetworkProof::protocol_name());
+
+    let (proof_prod_layer, rand_mem, rand_ops) = ProductLayerProof::prove(
+      &mut network.row_layers.prod_layer,
+      &mut network.col_layers.prod_layer,
+      dense,
+      derefs,
+      evals,
+      transcript,
+    );
+
+    // proof of hash layer for row and col
+    let proof_hash_layer = HashLayerProof::prove(
+
(&rand_mem, &rand_ops), + dense, + derefs, + gens, + transcript, + random_tape, + ); + + PolyEvalNetworkProof { + proof_prod_layer, + proof_hash_layer, + } + } + + pub fn verify( + &self, + comm: &SparseMatPolyCommitment, + comm_derefs: &DerefsCommitment, + evals: &[Scalar], + gens: &SparseMatPolyCommitmentGens, + rx: &[Scalar], + ry: &[Scalar], + r_mem_check: &(Scalar, Scalar), + nz: usize, + transcript: &mut Transcript, + ) -> Result<(), ProofVerifyError> { + let timer = Timer::new("verify_polyeval_proof"); + transcript.append_protocol_name(PolyEvalNetworkProof::protocol_name()); + + let num_instances = evals.len(); + let (r_hash, r_multiset_check) = r_mem_check; + + let num_ops = nz.next_power_of_two(); + let num_cells = rx.len().pow2(); + assert_eq!(rx.len(), ry.len()); + + let (claims_mem, rand_mem, mut claims_ops, claims_dotp, rand_ops) = self + .proof_prod_layer + .verify(num_ops, num_cells, evals, transcript)?; + assert_eq!(claims_mem.len(), 4); + assert_eq!(claims_ops.len(), 4 * num_instances); + assert_eq!(claims_dotp.len(), 3 * num_instances); + + let (claims_ops_row, claims_ops_col) = claims_ops.split_at_mut(2 * num_instances); + let (claims_ops_row_read, claims_ops_row_write) = claims_ops_row.split_at_mut(num_instances); + let (claims_ops_col_read, claims_ops_col_write) = claims_ops_col.split_at_mut(num_instances); + + // verify the proof of hash layer + self.proof_hash_layer.verify( + (&rand_mem, &rand_ops), + &( + claims_mem[0], + claims_ops_row_read.to_vec(), + claims_ops_row_write.to_vec(), + claims_mem[1], + ), + &( + claims_mem[2], + claims_ops_col_read.to_vec(), + claims_ops_col_write.to_vec(), + claims_mem[3], + ), + &claims_dotp, + comm, + gens, + comm_derefs, + rx, + ry, + r_hash, + r_multiset_check, + transcript, + )?; + timer.stop(); + + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SparseMatPolyEvalProof { + comm_derefs: DerefsCommitment, + poly_eval_network_proof: PolyEvalNetworkProof, +} + +impl SparseMatPolyEvalProof { + fn protocol_name() -> &'static [u8] { + b"Sparse polynomial evaluation proof" + } + + fn equalize(rx: &[Scalar], ry: &[Scalar]) -> (Vec, Vec) { + match rx.len().cmp(&ry.len()) { + Ordering::Less => { + let diff = ry.len() - rx.len(); + let mut rx_ext = vec![Scalar::zero(); diff]; + rx_ext.extend(rx); + (rx_ext, ry.to_vec()) + } + Ordering::Greater => { + let diff = rx.len() - ry.len(); + let mut ry_ext = vec![Scalar::zero(); diff]; + ry_ext.extend(ry); + (rx.to_vec(), ry_ext) + } + Ordering::Equal => (rx.to_vec(), ry.to_vec()), + } + } + + pub fn prove( + dense: &MultiSparseMatPolynomialAsDense, + rx: &[Scalar], // point at which the polynomial is evaluated + ry: &[Scalar], + evals: &[Scalar], // a vector evaluation of \widetilde{M}(r = (rx,ry)) for each M + gens: &SparseMatPolyCommitmentGens, + transcript: &mut Transcript, + random_tape: &mut RandomTape, + ) -> SparseMatPolyEvalProof { + transcript.append_protocol_name(SparseMatPolyEvalProof::protocol_name()); + + // ensure there is one eval for each polynomial in dense + assert_eq!(evals.len(), dense.batch_size); + + let (mem_rx, mem_ry) = { + // equalize the lengths of rx and ry + let (rx_ext, ry_ext) = SparseMatPolyEvalProof::equalize(rx, ry); + let poly_rx = EqPolynomial::new(rx_ext).evals(); + let poly_ry = EqPolynomial::new(ry_ext).evals(); + (poly_rx, poly_ry) + }; + + let derefs = dense.deref(&mem_rx, &mem_ry); + + // commit to non-deterministic choices of the prover + let timer_commit = Timer::new("commit_nondet_witness"); + let comm_derefs = { + let comm = 
derefs.commit(&gens.gens_derefs); + comm.append_to_transcript(b"comm_poly_row_col_ops_val", transcript); + comm + }; + timer_commit.stop(); + + let poly_eval_network_proof = { + // produce a random element from the transcript for hash function + let r_mem_check = transcript.challenge_vector(b"challenge_r_hash", 2); + + // build a network to evaluate the sparse polynomial + let timer_build_network = Timer::new("build_layered_network"); + let mut net = PolyEvalNetwork::new( + dense, + &derefs, + &mem_rx, + &mem_ry, + &(r_mem_check[0], r_mem_check[1]), + ); + timer_build_network.stop(); + + let timer_eval_network = Timer::new("evalproof_layered_network"); + let poly_eval_network_proof = PolyEvalNetworkProof::prove( + &mut net, + dense, + &derefs, + evals, + gens, + transcript, + random_tape, + ); + timer_eval_network.stop(); + + poly_eval_network_proof + }; + + SparseMatPolyEvalProof { + comm_derefs, + poly_eval_network_proof, + } + } + + pub fn verify( + &self, + comm: &SparseMatPolyCommitment, + rx: &[Scalar], // point at which the polynomial is evaluated + ry: &[Scalar], + evals: &[Scalar], // evaluation of \widetilde{M}(r = (rx,ry)) + gens: &SparseMatPolyCommitmentGens, + transcript: &mut Transcript, + ) -> Result<(), ProofVerifyError> { + transcript.append_protocol_name(SparseMatPolyEvalProof::protocol_name()); + + // equalize the lengths of rx and ry + let (rx_ext, ry_ext) = SparseMatPolyEvalProof::equalize(rx, ry); + + let (nz, num_mem_cells) = (comm.num_ops, comm.num_mem_cells); + assert_eq!(rx_ext.len().pow2(), num_mem_cells); + + // add claims to transcript and obtain challenges for randomized mem-check circuit + self + .comm_derefs + .append_to_transcript(b"comm_poly_row_col_ops_val", transcript); + + // produce a random element from the transcript for hash function + let r_mem_check = transcript.challenge_vector(b"challenge_r_hash", 2); + + self.poly_eval_network_proof.verify( + comm, + &self.comm_derefs, + evals, + gens, + &rx_ext, + &ry_ext, + &(r_mem_check[0], r_mem_check[1]), + nz, + transcript, + ) + } +} + +pub struct SparsePolyEntry { + idx: usize, + val: Scalar, +} + +impl SparsePolyEntry { + pub fn new(idx: usize, val: Scalar) -> Self { + SparsePolyEntry { idx, val } + } +} + +pub struct SparsePolynomial { + num_vars: usize, + Z: Vec, +} + +impl SparsePolynomial { + pub fn new(num_vars: usize, Z: Vec) -> Self { + SparsePolynomial { num_vars, Z } + } + + fn compute_chi(a: &[bool], r: &[Scalar]) -> Scalar { + assert_eq!(a.len(), r.len()); + let mut chi_i = Scalar::one(); + for j in 0..r.len() { + if a[j] { + chi_i *= r[j]; + } else { + chi_i *= Scalar::one() - r[j]; + } + } + chi_i + } + + // Takes O(n log n). 
TODO: do this in O(n) where n is the number of entries in Z + pub fn evaluate(&self, r: &[Scalar]) -> Scalar { + assert_eq!(self.num_vars, r.len()); + + (0..self.Z.len()) + .map(|i| { + let bits = self.Z[i].idx.get_bits(r.len()); + SparsePolynomial::compute_chi(&bits, r) * self.Z[i].val + }) + .sum() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand_core::{RngCore, OsRng}; + #[test] + fn check_sparse_polyeval_proof() { + let mut csprng: OsRng = OsRng; + + let num_nz_entries: usize = 256; + let num_rows: usize = 256; + let num_cols: usize = 256; + let num_vars_x: usize = num_rows.log_2(); + let num_vars_y: usize = num_cols.log_2(); + + let mut M: Vec = Vec::new(); + + for _i in 0..num_nz_entries { + M.push(SparseMatEntry::new( + (csprng.next_u64() % (num_rows as u64)) as usize, + (csprng.next_u64() % (num_cols as u64)) as usize, + Scalar::random(&mut csprng), + )); + } + + let poly_M = SparseMatPolynomial::new(num_vars_x, num_vars_y, M); + let gens = SparseMatPolyCommitmentGens::new( + b"gens_sparse_poly", + num_vars_x, + num_vars_y, + num_nz_entries, + 3, + ); + + // commitment + let (poly_comm, dense) = SparseMatPolynomial::multi_commit(&[&poly_M, &poly_M, &poly_M], &gens); + + // evaluation + let rx: Vec = (0..num_vars_x) + .map(|_i| Scalar::random(&mut csprng)) + .collect::>(); + let ry: Vec = (0..num_vars_y) + .map(|_i| Scalar::random(&mut csprng)) + .collect::>(); + let eval = SparseMatPolynomial::multi_evaluate(&[&poly_M], &rx, &ry); + let evals = vec![eval[0], eval[0], eval[0]]; + + let mut random_tape = RandomTape::new(b"proof"); + let mut prover_transcript = Transcript::new(b"example"); + let proof = SparseMatPolyEvalProof::prove( + &dense, + &rx, + &ry, + &evals, + &gens, + &mut prover_transcript, + &mut random_tape, + ); + + let mut verifier_transcript = Transcript::new(b"example"); + assert!(proof + .verify( + &poly_comm, + &rx, + &ry, + &evals, + &gens, + &mut verifier_transcript, + ) + .is_ok()); + } +} diff --git a/packages/Spartan-secq/src/sumcheck.rs b/packages/Spartan-secq/src/sumcheck.rs new file mode 100644 index 0000000..418e3a8 --- /dev/null +++ b/packages/Spartan-secq/src/sumcheck.rs @@ -0,0 +1,778 @@ +#![allow(clippy::too_many_arguments)] +#![allow(clippy::type_complexity)] +use super::commitments::{Commitments, MultiCommitGens}; +use super::dense_mlpoly::DensePolynomial; +use super::errors::ProofVerifyError; +use super::group::{CompressedGroup, GroupElement, VartimeMultiscalarMul}; +use super::nizk::DotProductProof; +use super::random::RandomTape; +use super::scalar::Scalar; +use super::transcript::{AppendToTranscript, ProofTranscript}; +use super::unipoly::{CompressedUniPoly, UniPoly}; +use crate::group::DecompressEncodedPoint; +use core::iter; +use itertools::izip; +use merlin::Transcript; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct SumcheckInstanceProof { + compressed_polys: Vec, +} + +impl SumcheckInstanceProof { + pub fn new(compressed_polys: Vec) -> SumcheckInstanceProof { + SumcheckInstanceProof { compressed_polys } + } + + pub fn verify( + &self, + claim: Scalar, + num_rounds: usize, + degree_bound: usize, + transcript: &mut Transcript, + ) -> Result<(Scalar, Vec), ProofVerifyError> { + let mut e = claim; + let mut r: Vec = Vec::new(); + + // verify that there is a univariate polynomial for each round + assert_eq!(self.compressed_polys.len(), num_rounds); + for i in 0..self.compressed_polys.len() { + let poly = self.compressed_polys[i].decompress(&e); + + // verify degree bound + 
assert_eq!(poly.degree(), degree_bound); + + // check if G_k(0) + G_k(1) = e + assert_eq!(poly.eval_at_zero() + poly.eval_at_one(), e); + + // append the prover's message to the transcript + poly.append_to_transcript(b"poly", transcript); + + //derive the verifier's challenge for the next round + let r_i = transcript.challenge_scalar(b"challenge_nextround"); + + r.push(r_i); + + // evaluate the claimed degree-ell polynomial at r_i + e = poly.evaluate(&r_i); + } + + Ok((e, r)) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ZKSumcheckInstanceProof { + comm_polys: Vec, + comm_evals: Vec, + proofs: Vec, +} + +impl ZKSumcheckInstanceProof { + pub fn new( + comm_polys: Vec, + comm_evals: Vec, + proofs: Vec, + ) -> Self { + ZKSumcheckInstanceProof { + comm_polys, + comm_evals, + proofs, + } + } + + pub fn verify( + &self, + comm_claim: &CompressedGroup, + num_rounds: usize, + degree_bound: usize, + gens_1: &MultiCommitGens, + gens_n: &MultiCommitGens, + transcript: &mut Transcript, + ) -> Result<(CompressedGroup, Vec), ProofVerifyError> { + // verify degree bound + assert_eq!(gens_n.n, degree_bound + 1); + + // verify that there is a univariate polynomial for each round + assert_eq!(self.comm_polys.len(), num_rounds); + assert_eq!(self.comm_evals.len(), num_rounds); + + let mut r: Vec = Vec::new(); + for i in 0..self.comm_polys.len() { + let comm_poly = &self.comm_polys[i]; + + // append the prover's polynomial to the transcript + comm_poly.append_to_transcript(b"comm_poly", transcript); + + //derive the verifier's challenge for the next round + let r_i = transcript.challenge_scalar(b"challenge_nextround"); + + // verify the proof of sum-check and evals + let res = { + let comm_claim_per_round = if i == 0 { + comm_claim + } else { + &self.comm_evals[i - 1] + }; + let comm_eval = &self.comm_evals[i]; + + // add two claims to transcript + comm_claim_per_round.append_to_transcript(b"comm_claim_per_round", transcript); + comm_eval.append_to_transcript(b"comm_eval", transcript); + + // produce two weights + let w = transcript.challenge_vector(b"combine_two_claims_to_one", 2); + + // compute a weighted sum of the RHS + let comm_target = GroupElement::vartime_multiscalar_mul( + w.clone(), + iter::once(&comm_claim_per_round) + .chain(iter::once(&comm_eval)) + .map(|pt| pt.decompress().unwrap()) + .collect(), + ) + .compress(); + + let a = { + // the vector to use to decommit for sum-check test + let a_sc = { + let mut a = vec![Scalar::one(); degree_bound + 1]; + a[0] += Scalar::one(); + a + }; + + // the vector to use to decommit for evaluation + let a_eval = { + let mut a = vec![Scalar::one(); degree_bound + 1]; + for j in 1..a.len() { + a[j] = a[j - 1] * r_i; + } + a + }; + + // take weighted sum of the two vectors using w + assert_eq!(a_sc.len(), a_eval.len()); + (0..a_sc.len()) + .map(|i| w[0] * a_sc[i] + w[1] * a_eval[i]) + .collect::>() + }; + + self.proofs[i] + .verify( + gens_1, + gens_n, + transcript, + &a, + &self.comm_polys[i], + &comm_target, + ) + .is_ok() + }; + if !res { + return Err(ProofVerifyError::InternalError); + } + + r.push(r_i); + } + + Ok((self.comm_evals[self.comm_evals.len() - 1], r)) + } +} + +impl SumcheckInstanceProof { + pub fn prove_cubic( + claim: &Scalar, + num_rounds: usize, + poly_A: &mut DensePolynomial, + poly_B: &mut DensePolynomial, + poly_C: &mut DensePolynomial, + comb_func: F, + transcript: &mut Transcript, + ) -> (Self, Vec, Vec) + where + F: Fn(&Scalar, &Scalar, &Scalar) -> Scalar, + { + let mut e = *claim; + let mut r: Vec = Vec::new(); + let 
mut cubic_polys: Vec<CompressedUniPoly> = Vec::new();
+    for _j in 0..num_rounds {
+      let mut eval_point_0 = Scalar::zero();
+      let mut eval_point_2 = Scalar::zero();
+      let mut eval_point_3 = Scalar::zero();
+
+      let len = poly_A.len() / 2;
+      for i in 0..len {
+        // eval 0: bound_func is A(low)
+        eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
+
+        // eval 2: bound_func is -A(low) + 2*A(high)
+        let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
+        let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
+        let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
+        eval_point_2 += comb_func(
+          &poly_A_bound_point,
+          &poly_B_bound_point,
+          &poly_C_bound_point,
+        );
+
+        // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
+        let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
+        let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
+        let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
+
+        eval_point_3 += comb_func(
+          &poly_A_bound_point,
+          &poly_B_bound_point,
+          &poly_C_bound_point,
+        );
+      }
+
+      let evals = vec![eval_point_0, e - eval_point_0, eval_point_2, eval_point_3];
+      let poly = UniPoly::from_evals(&evals);
+
+      // append the prover's message to the transcript
+      poly.append_to_transcript(b"poly", transcript);
+
+      //derive the verifier's challenge for the next round
+      let r_j = transcript.challenge_scalar(b"challenge_nextround");
+      r.push(r_j);
+      // bound all tables to the verifier's challenge
+      poly_A.bound_poly_var_top(&r_j);
+      poly_B.bound_poly_var_top(&r_j);
+      poly_C.bound_poly_var_top(&r_j);
+      e = poly.evaluate(&r_j);
+      cubic_polys.push(poly.compress());
+    }
+
+    (
+      SumcheckInstanceProof::new(cubic_polys),
+      r,
+      vec![poly_A[0], poly_B[0], poly_C[0]],
+    )
+  }
+
+  pub fn prove_cubic_batched(
+    claim: &Scalar,
+    num_rounds: usize,
+    poly_vec_par: (
+      &mut Vec<&mut DensePolynomial>,
+      &mut Vec<&mut DensePolynomial>,
+      &mut DensePolynomial,
+    ),
+    poly_vec_seq: (
+      &mut Vec<&mut DensePolynomial>,
+      &mut Vec<&mut DensePolynomial>,
+      &mut Vec<&mut DensePolynomial>,
+    ),
+    coeffs: &[Scalar],
+    comb_func: F,
+    transcript: &mut Transcript,
+  ) -> (
+    Self,
+    Vec<Scalar>,
+    (Vec<Scalar>, Vec<Scalar>, Scalar),
+    (Vec<Scalar>, Vec<Scalar>, Vec<Scalar>),
+  )
+  where
+    F: Fn(&Scalar, &Scalar, &Scalar) -> Scalar,
+  {
+    let (poly_A_vec_par, poly_B_vec_par, poly_C_par) = poly_vec_par;
+    let (poly_A_vec_seq, poly_B_vec_seq, poly_C_vec_seq) = poly_vec_seq;
+
+    let mut e = *claim;
+    let mut r: Vec<Scalar> = Vec::new();
+    let mut cubic_polys: Vec<CompressedUniPoly> = Vec::new();
+
+    for _j in 0..num_rounds {
+      let mut evals: Vec<(Scalar, Scalar, Scalar)> = Vec::new();
+
+      for (poly_A, poly_B) in poly_A_vec_par.iter().zip(poly_B_vec_par.iter()) {
+        let mut eval_point_0 = Scalar::zero();
+        let mut eval_point_2 = Scalar::zero();
+        let mut eval_point_3 = Scalar::zero();
+
+        let len = poly_A.len() / 2;
+        for i in 0..len {
+          // eval 0: bound_func is A(low)
+          eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C_par[i]);
+
+          // eval 2: bound_func is -A(low) + 2*A(high)
+          let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C_par[len + i] + poly_C_par[len + i] - poly_C_par[i];
+          eval_point_2 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+          );
+
+          // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
+          let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C_bound_point + poly_C_par[len + i] - poly_C_par[i];
+
+          eval_point_3 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+          );
+        }
+
+        evals.push((eval_point_0, eval_point_2, eval_point_3));
+      }
+
+      for (poly_A, poly_B, poly_C) in izip!(
+        poly_A_vec_seq.iter(),
+        poly_B_vec_seq.iter(),
+        poly_C_vec_seq.iter()
+      ) {
+        let mut eval_point_0 = Scalar::zero();
+        let mut eval_point_2 = Scalar::zero();
+        let mut eval_point_3 = Scalar::zero();
+        let len = poly_A.len() / 2;
+        for i in 0..len {
+          // eval 0: bound_func is A(low)
+          eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i]);
+          // eval 2: bound_func is -A(low) + 2*A(high)
+          let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
+          eval_point_2 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+          );
+          // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
+          let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
+          eval_point_3 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+          );
+        }
+        evals.push((eval_point_0, eval_point_2, eval_point_3));
+      }
+
+      let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum();
+      let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * coeffs[i]).sum();
+      let evals_combined_3 = (0..evals.len()).map(|i| evals[i].2 * coeffs[i]).sum();
+
+      let evals = vec![
+        evals_combined_0,
+        e - evals_combined_0,
+        evals_combined_2,
+        evals_combined_3,
+      ];
+      let poly = UniPoly::from_evals(&evals);
+
+      // append the prover's message to the transcript
+      poly.append_to_transcript(b"poly", transcript);
+
+      //derive the verifier's challenge for the next round
+      let r_j = transcript.challenge_scalar(b"challenge_nextround");
+      r.push(r_j);
+
+      // bound all tables to the verifier's challenge
+      for (poly_A, poly_B) in poly_A_vec_par.iter_mut().zip(poly_B_vec_par.iter_mut()) {
+        poly_A.bound_poly_var_top(&r_j);
+        poly_B.bound_poly_var_top(&r_j);
+      }
+      poly_C_par.bound_poly_var_top(&r_j);
+
+      for (poly_A, poly_B, poly_C) in izip!(
+        poly_A_vec_seq.iter_mut(),
+        poly_B_vec_seq.iter_mut(),
+        poly_C_vec_seq.iter_mut()
+      ) {
+        poly_A.bound_poly_var_top(&r_j);
+        poly_B.bound_poly_var_top(&r_j);
+        poly_C.bound_poly_var_top(&r_j);
+      }
+
+      e = poly.evaluate(&r_j);
+      cubic_polys.push(poly.compress());
+    }
+
+    let poly_A_par_final = (0..poly_A_vec_par.len())
+      .map(|i| poly_A_vec_par[i][0])
+      .collect();
+    let poly_B_par_final = (0..poly_B_vec_par.len())
+      .map(|i| poly_B_vec_par[i][0])
+      .collect();
+    let claims_prod = (poly_A_par_final, poly_B_par_final, poly_C_par[0]);
+
+    let poly_A_seq_final = (0..poly_A_vec_seq.len())
+      .map(|i| poly_A_vec_seq[i][0])
+      .collect();
+    let poly_B_seq_final = (0..poly_B_vec_seq.len())
+      .map(|i| poly_B_vec_seq[i][0])
+      .collect();
+    let poly_C_seq_final = (0..poly_C_vec_seq.len())
+      .map(|i| poly_C_vec_seq[i][0])
+      .collect();
+    let claims_dotp = (poly_A_seq_final, poly_B_seq_final, poly_C_seq_final);
+
+    (
+      SumcheckInstanceProof::new(cubic_polys),
+      r,
+      claims_prod,
+      claims_dotp,
+    )
+  }
+}
+
+impl ZKSumcheckInstanceProof {
+  pub fn prove_quad(
+    claim: &Scalar,
+    blind_claim: &Scalar,
+    num_rounds: usize,
+    poly_A: &mut DensePolynomial,
+    poly_B: &mut DensePolynomial,
+    comb_func: F,
+    gens_1: &MultiCommitGens,
+    gens_n: &MultiCommitGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+  ) -> (Self, Vec<Scalar>, Vec<Scalar>, Scalar)
+  where
+    F: Fn(&Scalar, &Scalar) -> Scalar,
+  {
+    let (blinds_poly, blinds_evals) = (
+      random_tape.random_vector(b"blinds_poly", num_rounds),
+      random_tape.random_vector(b"blinds_evals", num_rounds),
+    );
+    let mut claim_per_round = *claim;
+    let mut comm_claim_per_round = claim_per_round.commit(blind_claim, gens_1).compress();
+
+    let mut r: Vec<Scalar> = Vec::new();
+    let mut comm_polys: Vec<CompressedGroup> = Vec::new();
+    let mut comm_evals: Vec<CompressedGroup> = Vec::new();
+    let mut proofs: Vec<DotProductProof> = Vec::new();
+
+    for j in 0..num_rounds {
+      let (poly, comm_poly) = {
+        let mut eval_point_0 = Scalar::zero();
+        let mut eval_point_2 = Scalar::zero();
+
+        let len = poly_A.len() / 2;
+        for i in 0..len {
+          // eval 0: bound_func is A(low)
+          eval_point_0 += comb_func(&poly_A[i], &poly_B[i]);
+
+          // eval 2: bound_func is -A(low) + 2*A(high)
+          let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
+          eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point);
+        }
+
+        let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2];
+        let poly = UniPoly::from_evals(&evals);
+        let comm_poly = poly.commit(gens_n, &blinds_poly[j]).compress();
+        (poly, comm_poly)
+      };
+
+      // append the prover's message to the transcript
+      comm_poly.append_to_transcript(b"comm_poly", transcript);
+      comm_polys.push(comm_poly);
+
+      //derive the verifier's challenge for the next round
+      let r_j = transcript.challenge_scalar(b"challenge_nextround");
+
+      // bound all tables to the verifier's challenge
+      poly_A.bound_poly_var_top(&r_j);
+      poly_B.bound_poly_var_top(&r_j);
+
+      // produce a proof of sum-check and of evaluation
+      let (proof, claim_next_round, comm_claim_next_round) = {
+        let eval = poly.evaluate(&r_j);
+        let comm_eval = eval.commit(&blinds_evals[j], gens_1).compress();
+
+        // we need to prove the following under homomorphic commitments:
+        // (1) poly(0) + poly(1) = claim_per_round
+        // (2) poly(r_j) = eval
+
+        // Our technique is to leverage dot product proofs:
+        // (1) we can prove: <poly in coefficient form, (2, 1, 1, 1)> = claim_per_round
+        // (2) we can prove: <poly in coefficient form, (1, r_j, r_j^2, ...)> = eval
+
+        // add the two claims to the transcript
+        comm_claim_per_round.append_to_transcript(b"comm_claim_per_round", transcript);
+        comm_eval.append_to_transcript(b"comm_eval", transcript);
+
+        // produce two weights
+        let w = transcript.challenge_vector(b"combine_two_claims_to_one", 2);
+
+        // compute a weighted sum of the two claims
+        let target = w[0] * claim_per_round + w[1] * eval;
+        let comm_target = GroupElement::vartime_multiscalar_mul(
+          w.clone(),
+          iter::once(&comm_claim_per_round)
+            .chain(iter::once(&comm_eval))
+            .map(|pt| pt.decompress().unwrap())
+            .collect(),
+        )
+        .compress();
+
+        let blind = {
+          let blind_sc = if j == 0 {
+            blind_claim
+          } else {
+            &blinds_evals[j - 1]
+          };
+
+          let blind_eval = &blinds_evals[j];
+
+          w[0] * blind_sc + w[1] * blind_eval
+        };
+        assert_eq!(target.commit(&blind, gens_1).compress(), comm_target);
+
+        let a = {
+          // the vector to use to decommit for sum-check test
+          let a_sc = {
+            let mut a = vec![Scalar::one(); poly.degree() + 1];
+            a[0] += Scalar::one();
+            a
+          };
+
+          // the vector to use to decommit for evaluation
+          let a_eval = {
+            let mut a = vec![Scalar::one(); poly.degree() + 1];
+            for j in 1..a.len() {
+              a[j] = a[j - 1] * r_j;
+            }
+            a
+          };
+
+          // take a weighted sum of the two vectors using w
+          assert_eq!(a_sc.len(), a_eval.len());
+          (0..a_sc.len())
+            .map(|i| w[0] * a_sc[i] + w[1] * a_eval[i])
+            .collect::<Vec<Scalar>>()
+        };
+
+        let (proof, _comm_poly, _comm_sc_eval) = DotProductProof::prove(
+          gens_1,
+          gens_n,
+          transcript,
+          random_tape,
+          &poly.as_vec(),
+          &blinds_poly[j],
+          &a,
+          &target,
+          &blind,
+        );
+
+        (proof, eval, comm_eval)
+      };
+
+      claim_per_round = claim_next_round;
+      comm_claim_per_round = comm_claim_next_round;
+
+      proofs.push(proof);
+      r.push(r_j);
+      comm_evals.push(comm_claim_per_round);
+    }
+
+    (
+      ZKSumcheckInstanceProof::new(comm_polys, comm_evals, proofs),
+      r,
+      vec![poly_A[0], poly_B[0]],
+      blinds_evals[num_rounds - 1],
+    )
+  }
+
+  pub fn prove_cubic_with_additive_term(
+    claim: &Scalar,
+    blind_claim: &Scalar,
+    num_rounds: usize,
+    poly_A: &mut DensePolynomial,
+    poly_B: &mut DensePolynomial,
+    poly_C: &mut DensePolynomial,
+    poly_D: &mut DensePolynomial,
+    comb_func: F,
+    gens_1: &MultiCommitGens,
+    gens_n: &MultiCommitGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+  ) -> (Self, Vec<Scalar>, Vec<Scalar>, Scalar)
+  where
+  pub fn prove_cubic_with_additive_term<F>(
+    claim: &Scalar,
+    blind_claim: &Scalar,
+    num_rounds: usize,
+    poly_A: &mut DensePolynomial,
+    poly_B: &mut DensePolynomial,
+    poly_C: &mut DensePolynomial,
+    poly_D: &mut DensePolynomial,
+    comb_func: F,
+    gens_1: &MultiCommitGens,
+    gens_n: &MultiCommitGens,
+    transcript: &mut Transcript,
+    random_tape: &mut RandomTape,
+  ) -> (Self, Vec<Scalar>, Vec<Scalar>, Scalar)
+  where
+    F: Fn(&Scalar, &Scalar, &Scalar, &Scalar) -> Scalar,
+  {
+    let (blinds_poly, blinds_evals) = (
+      random_tape.random_vector(b"blinds_poly", num_rounds),
+      random_tape.random_vector(b"blinds_evals", num_rounds),
+    );
+
+    let mut claim_per_round = *claim;
+    let mut comm_claim_per_round = claim_per_round.commit(blind_claim, gens_1).compress();
+
+    let mut r: Vec<Scalar> = Vec::new();
+    let mut comm_polys: Vec<CompressedGroup> = Vec::new();
+    let mut comm_evals: Vec<CompressedGroup> = Vec::new();
+    let mut proofs: Vec<DotProductProof> = Vec::new();
+
+    for j in 0..num_rounds {
+      let (poly, comm_poly) = {
+        let mut eval_point_0 = Scalar::zero();
+        let mut eval_point_2 = Scalar::zero();
+        let mut eval_point_3 = Scalar::zero();
+
+        let len = poly_A.len() / 2;
+        for i in 0..len {
+          // eval 0: bound_func is A(low)
+          eval_point_0 += comb_func(&poly_A[i], &poly_B[i], &poly_C[i], &poly_D[i]);
+
+          // eval 2: bound_func is -A(low) + 2*A(high)
+          let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
+          let poly_D_bound_point = poly_D[len + i] + poly_D[len + i] - poly_D[i];
+          eval_point_2 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+            &poly_D_bound_point,
+          );
+
+          // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
+          let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
+          let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
+          let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
+          let poly_D_bound_point = poly_D_bound_point + poly_D[len + i] - poly_D[i];
+          eval_point_3 += comb_func(
+            &poly_A_bound_point,
+            &poly_B_bound_point,
+            &poly_C_bound_point,
+            &poly_D_bound_point,
+          );
+        }
+
+        let evals = vec![
+          eval_point_0,
+          claim_per_round - eval_point_0,
+          eval_point_2,
+          eval_point_3,
+        ];
+        let poly = UniPoly::from_evals(&evals);
+        let comm_poly = poly.commit(gens_n, &blinds_poly[j]).compress();
+        (poly, comm_poly)
+      };
+
+      // append the prover's message to the transcript
+      comm_poly.append_to_transcript(b"comm_poly", transcript);
+      comm_polys.push(comm_poly);
+
+      // derive the verifier's challenge for the next round
+      let r_j = transcript.challenge_scalar(b"challenge_nextround");
+
+      // bound all tables to the verifier's challenge
+      poly_A.bound_poly_var_top(&r_j);
+      poly_B.bound_poly_var_top(&r_j);
+      poly_C.bound_poly_var_top(&r_j);
+      poly_D.bound_poly_var_top(&r_j);
+
+      // produce a proof of sum-check and of evaluation
+      let (proof, claim_next_round, comm_claim_next_round) = {
+        let eval = poly.evaluate(&r_j);
+        let comm_eval = eval.commit(&blinds_evals[j], gens_1).compress();
+
+        // we need to prove the following under homomorphic commitments:
+        // (1) poly(0) + poly(1) = claim_per_round
+        // (2) poly(r_j) = eval
+
+        // Our technique is to leverage dot product proofs:
+        // (1) we can prove: <poly_in_coeffs_form, (2, 1, 1, 1)> = claim_per_round
+        // (2) we can prove: <poly_in_coeffs_form, (1, r_j, r_j^2, r_j^3)> = eval
+        // for efficiency we batch them using random weights
+
+        // add two claims to transcript
+        comm_claim_per_round.append_to_transcript(b"comm_claim_per_round", transcript);
+        comm_eval.append_to_transcript(b"comm_eval", transcript);
+
+        // produce two weights
+        let w = transcript.challenge_vector(b"combine_two_claims_to_one", 2);
+
+        // compute a weighted sum of the RHS
+        let target = w[0] * claim_per_round + w[1] * eval;
+        let comm_target = GroupElement::vartime_multiscalar_mul(
+          w.iter(),
+          iter::once(&comm_claim_per_round)
+            .chain(iter::once(&comm_eval))
+            .map(|pt| pt.decompress().unwrap())
+            .collect::<Vec<GroupElement>>(),
+        )
+        .compress();
+
+        let blind = {
+          let blind_sc = if j == 0 {
+            blind_claim
+          } else {
+            &blinds_evals[j - 1]
+          };
+
+          let blind_eval = &blinds_evals[j];
+
+          w[0] * blind_sc + w[1] * blind_eval
+        };
+
+        assert_eq!(target.commit(&blind, gens_1).compress(), comm_target);
+
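+        // poly(0) + poly(1) = 2*c_0 + c_1 + ... + c_d for coefficients c_i, so
+        // decommitting against (2, 1, ..., 1) checks the sum-check relation,
+        // while poly(r_j) = c_0 + c_1*r_j + ... + c_d*r_j^d, so decommitting
+        // against (1, r_j, ..., r_j^d) checks the claimed evaluation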
+        let a = {
+          // the vector to use to decommit for sum-check test
+          let a_sc = {
+            let mut a = vec![Scalar::one(); poly.degree() + 1];
+            a[0] += Scalar::one();
+            a
+          };
+
+          // the vector to use to decommit for evaluation
+          let a_eval = {
+            let mut a = vec![Scalar::one(); poly.degree() + 1];
+            for j in 1..a.len() {
+              a[j] = a[j - 1] * r_j;
+            }
+            a
+          };
+
+          // take weighted sum of the two vectors using w
+          assert_eq!(a_sc.len(), a_eval.len());
+          (0..a_sc.len())
+            .map(|i| w[0] * a_sc[i] + w[1] * a_eval[i])
+            .collect::<Vec<Scalar>>()
+        };
+
+        let (proof, _comm_poly, _comm_sc_eval) = DotProductProof::prove(
+          gens_1,
+          gens_n,
+          transcript,
+          random_tape,
+          &poly.as_vec(),
+          &blinds_poly[j],
+          &a,
+          &target,
+          &blind,
+        );
+
+        (proof, eval, comm_eval)
+      };
+
+      proofs.push(proof);
+      claim_per_round = claim_next_round;
+      comm_claim_per_round = comm_claim_next_round;
+      r.push(r_j);
+      comm_evals.push(comm_claim_per_round);
+    }
+
+    (
+      ZKSumcheckInstanceProof::new(comm_polys, comm_evals, proofs),
+      r,
+      vec![poly_A[0], poly_B[0], poly_C[0], poly_D[0]],
+      blinds_evals[num_rounds - 1],
+    )
+  }
+}
diff --git a/packages/Spartan-secq/src/timer.rs b/packages/Spartan-secq/src/timer.rs
new file mode 100644
index 0000000..8356a35
--- /dev/null
+++ b/packages/Spartan-secq/src/timer.rs
@@ -0,0 +1,88 @@
+#[cfg(feature = "profile")]
+use colored::Colorize;
+#[cfg(feature = "profile")]
+use core::sync::atomic::AtomicUsize;
+#[cfg(feature = "profile")]
+use core::sync::atomic::Ordering;
+#[cfg(feature = "profile")]
+use std::time::Instant;
+
+#[cfg(feature = "profile")]
+pub static CALL_DEPTH: AtomicUsize = AtomicUsize::new(0);
+
+#[cfg(feature = "profile")]
+pub struct Timer {
+  label: String,
+  timer: Instant,
+}
+
+#[cfg(feature = "profile")]
+impl Timer {
+  #[inline(always)]
+  pub fn new(label: &str) -> Self {
+    let timer = Instant::now();
+    CALL_DEPTH.fetch_add(1, Ordering::Relaxed);
+    let star = "* ";
+    println!(
+      "{:indent$}{}{}",
+      "",
+      star,
+      label.yellow().bold(),
+      indent = 2 * CALL_DEPTH.fetch_add(0, Ordering::Relaxed)
+    );
+    Self {
+      label: label.to_string(),
+      timer,
+    }
+  }
+
+  #[inline(always)]
+  pub fn stop(&self) {
+    let duration = self.timer.elapsed();
+    let star = "* ";
+    println!(
+      "{:indent$}{}{} {:?}",
+      "",
+      star,
+      self.label.blue().bold(),
+      duration,
+      indent = 2 * CALL_DEPTH.fetch_add(0, Ordering::Relaxed)
+    );
+    CALL_DEPTH.fetch_sub(1, Ordering::Relaxed);
+  }
+
+  #[inline(always)]
+  pub fn print(msg: &str) {
+    CALL_DEPTH.fetch_add(1, Ordering::Relaxed);
+    let star = "* ";
+    println!(
+      "{:indent$}{}{}",
+      "",
+      star,
+      msg.to_string().green().bold(),
+      indent = 2 * CALL_DEPTH.fetch_add(0, Ordering::Relaxed)
+    );
+    CALL_DEPTH.fetch_sub(1, Ordering::Relaxed);
+  }
+}
+
+#[cfg(not(feature = "profile"))]
+pub struct Timer {
+  _label: String,
+}
+
+#[cfg(not(feature = "profile"))]
+impl Timer {
+  #[inline(always)]
+  pub fn new(label: &str) -> Self {
+    Self {
+      _label: label.to_string(),
+    }
+  }
+
+  #[inline(always)]
+  pub fn stop(&self) {}
+
+  #[inline(always)]
+  pub fn print(_msg: &str) {}
+}
diff --git a/packages/Spartan-secq/src/transcript.rs b/packages/Spartan-secq/src/transcript.rs
new file mode 100644
index 0000000..a57f150
--- /dev/null
+++ b/packages/Spartan-secq/src/transcript.rs
@@ -0,0 +1,63 @@
+use super::group::CompressedGroup;
+use super::scalar::Scalar;
+use merlin::Transcript;
+
+pub trait ProofTranscript {
+  fn append_protocol_name(&mut self, protocol_name: &'static [u8]);
+  fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar);
+  fn append_point(&mut self, label: &'static [u8], point: &CompressedGroup);
+  fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar;
+  fn challenge_vector(&mut self, label: &'static [u8], len: usize) -> Vec<Scalar>;
+}
+
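+// challenges are derived via Fiat-Shamir: 64 uniform bytes are squeezed from
+// the Merlin transcript and reduced modulo the field order (from_bytes_wide),
+// which keeps the reduction bias negligible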
+impl ProofTranscript for Transcript {
+  fn append_protocol_name(&mut self, protocol_name: &'static [u8]) {
+    self.append_message(b"protocol-name", protocol_name);
+  }
+
+  fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar) {
+    self.append_message(label, &scalar.to_bytes());
+  }
+
+  fn append_point(&mut self, label: &'static [u8], point: &CompressedGroup) {
+    self.append_message(label, point.as_bytes());
+  }
+
+  fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar {
+    let mut buf = [0u8; 64];
+    self.challenge_bytes(label, &mut buf);
+    Scalar::from_bytes_wide(&buf)
+  }
+
+  fn challenge_vector(&mut self, label: &'static [u8], len: usize) -> Vec<Scalar> {
+    (0..len)
+      .map(|_i| self.challenge_scalar(label))
+      .collect::<Vec<Scalar>>()
+  }
+}
+
+pub trait AppendToTranscript {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript);
+}
+
+impl AppendToTranscript for Scalar {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) {
+    transcript.append_scalar(label, self);
+  }
+}
+
+impl AppendToTranscript for [Scalar] {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) {
+    transcript.append_message(label, b"begin_append_vector");
+    for item in self {
+      transcript.append_scalar(label, item);
+    }
+    transcript.append_message(label, b"end_append_vector");
+  }
+}
+
+impl AppendToTranscript for CompressedGroup {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) {
+    transcript.append_point(label, self);
+  }
+}
diff --git a/packages/Spartan-secq/src/unipoly.rs b/packages/Spartan-secq/src/unipoly.rs
new file mode 100644
index 0000000..dcc3918
--- /dev/null
+++ b/packages/Spartan-secq/src/unipoly.rs
@@ -0,0 +1,182 @@
+use super::commitments::{Commitments, MultiCommitGens};
+use super::group::GroupElement;
+use super::scalar::{Scalar, ScalarFromPrimitives};
+use super::transcript::{AppendToTranscript, ProofTranscript};
+use merlin::Transcript;
+use serde::{Deserialize, Serialize};
+
+// ax^2 + bx + c stored as vec![c,b,a]
+// ax^3 + bx^2 + cx + d stored as vec![d,c,b,a]
+#[derive(Debug)]
+pub struct UniPoly {
+  coeffs: Vec<Scalar>,
+}
+
+// ax^2 + bx + c stored as vec![c,a]
+// ax^3 + bx^2 + cx + d stored as vec![d,b,a]
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CompressedUniPoly {
+  coeffs_except_linear_term: Vec<Scalar>,
+}
+
+impl UniPoly {
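+  // from_evals interpolates the unique polynomial through evaluations at
+  // x = 0, 1, 2 (degree 2) or x = 0, 1, 2, 3 (degree 3); e.g. in the cubic
+  // case, a = (e3 - 3*e2 + 3*e1 - e0) / 6 is the third finite difference of
+  // the samples divided by 3!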
+  pub fn from_evals(evals: &[Scalar]) -> Self {
+    // we only support degree-2 or degree-3 univariate polynomials
+    assert!(evals.len() == 3 || evals.len() == 4);
+    let coeffs = if evals.len() == 3 {
+      // ax^2 + bx + c
+      let two_inv = (2_usize).to_scalar().invert().unwrap();
+
+      let c = evals[0];
+      let a = two_inv * (evals[2] - evals[1] - evals[1] + c);
+      let b = evals[1] - c - a;
+      vec![c, b, a]
+    } else {
+      // ax^3 + bx^2 + cx + d
+      let two_inv = (2_usize).to_scalar().invert().unwrap();
+      let six_inv = (6_usize).to_scalar().invert().unwrap();
+
+      let d = evals[0];
+      let a = six_inv
+        * (evals[3] - evals[2] - evals[2] - evals[2] + evals[1] + evals[1] + evals[1] - evals[0]);
+      let b = two_inv
+        * (evals[0] + evals[0] - evals[1] - evals[1] - evals[1] - evals[1] - evals[1]
+          + evals[2]
+          + evals[2]
+          + evals[2]
+          + evals[2]
+          - evals[3]);
+      let c = evals[1] - d - a - b;
+      vec![d, c, b, a]
+    };
+
+    UniPoly { coeffs }
+  }
+
+  pub fn degree(&self) -> usize {
+    self.coeffs.len() - 1
+  }
+
+  pub fn as_vec(&self) -> Vec<Scalar> {
+    self.coeffs.clone()
+  }
+
+  pub fn eval_at_zero(&self) -> Scalar {
+    self.coeffs[0]
+  }
+
+  pub fn eval_at_one(&self) -> Scalar {
+    (0..self.coeffs.len()).map(|i| self.coeffs[i]).sum()
+  }
+
+  pub fn evaluate(&self, r: &Scalar) -> Scalar {
+    let mut eval = self.coeffs[0];
+    let mut power = *r;
+    for i in 1..self.coeffs.len() {
+      eval += power * self.coeffs[i];
+      power *= r;
+    }
+    eval
+  }
+
+  pub fn compress(&self) -> CompressedUniPoly {
+    let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat();
+    assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
+    CompressedUniPoly {
+      coeffs_except_linear_term,
+    }
+  }
+
+  pub fn commit(&self, gens: &MultiCommitGens, blind: &Scalar) -> GroupElement {
+    self.coeffs.commit(blind, gens)
+  }
+}
+
+impl CompressedUniPoly {
+  // we require eval(0) + eval(1) = hint, so we can solve for the linear term as:
+  // linear_term = hint - 2 * constant_term - deg2 term - deg3 term
+  pub fn decompress(&self, hint: &Scalar) -> UniPoly {
+    let mut linear_term =
+      hint - self.coeffs_except_linear_term[0] - self.coeffs_except_linear_term[0];
+    for i in 1..self.coeffs_except_linear_term.len() {
+      linear_term -= self.coeffs_except_linear_term[i];
+    }
+
+    let mut coeffs = vec![self.coeffs_except_linear_term[0], linear_term];
+    coeffs.extend(&self.coeffs_except_linear_term[1..]);
+    assert_eq!(self.coeffs_except_linear_term.len() + 1, coeffs.len());
+    UniPoly { coeffs }
+  }
+}
+
+impl AppendToTranscript for UniPoly {
+  fn append_to_transcript(&self, label: &'static [u8], transcript: &mut Transcript) {
+    transcript.append_message(label, b"UniPoly_begin");
+    for i in 0..self.coeffs.len() {
+      transcript.append_scalar(b"coeff", &self.coeffs[i]);
+    }
+    transcript.append_message(label, b"UniPoly_end");
+  }
+}
+
+#[cfg(test)]
+mod tests {
+
+  use super::*;
+
+  #[test]
+  fn test_from_evals_quad() {
+    // polynomial is 2x^2 + 3x + 1
+    let e0 = Scalar::one();
+    let e1 = (6_usize).to_scalar();
+    let e2 = (15_usize).to_scalar();
+    let evals = vec![e0, e1, e2];
+    let poly = UniPoly::from_evals(&evals);
+
+    assert_eq!(poly.eval_at_zero(), e0);
+    assert_eq!(poly.eval_at_one(), e1);
+    assert_eq!(poly.coeffs.len(), 3);
+    assert_eq!(poly.coeffs[0], Scalar::one());
+    assert_eq!(poly.coeffs[1], (3_usize).to_scalar());
+    assert_eq!(poly.coeffs[2], (2_usize).to_scalar());
+
+    let hint = e0 + e1;
+    let compressed_poly = poly.compress();
+    let decompressed_poly = compressed_poly.decompress(&hint);
+    for i in 0..decompressed_poly.coeffs.len() {
+      assert_eq!(decompressed_poly.coeffs[i], poly.coeffs[i]);
+    }
+
+    let e3 = (28_usize).to_scalar();
+    assert_eq!(poly.evaluate(&(3_usize).to_scalar()), e3);
+  }
+
+  #[test]
+  fn test_from_evals_cubic() {
+    // polynomial is x^3 + 2x^2 + 3x + 1
+    let e0 = Scalar::one();
+    let e1 = (7_usize).to_scalar();
+    let e2 = (23_usize).to_scalar();
+    let e3 = (55_usize).to_scalar();
+    let evals = vec![e0, e1, e2, e3];
+    let poly = UniPoly::from_evals(&evals);
+
+    assert_eq!(poly.eval_at_zero(), e0);
+    assert_eq!(poly.eval_at_one(), e1);
+    assert_eq!(poly.coeffs.len(), 4);
+    assert_eq!(poly.coeffs[0], Scalar::one());
+    assert_eq!(poly.coeffs[1], (3_usize).to_scalar());
+    assert_eq!(poly.coeffs[2], (2_usize).to_scalar());
+    assert_eq!(poly.coeffs[3], (1_usize).to_scalar());
+
+    let hint = e0 + e1;
+    let compressed_poly = poly.compress();
+    let decompressed_poly = compressed_poly.decompress(&hint);
+    for i in 0..decompressed_poly.coeffs.len() {
+      assert_eq!(decompressed_poly.coeffs[i], poly.coeffs[i]);
+    }
+
+    let e4 = (109_usize).to_scalar();
+    assert_eq!(poly.evaluate(&(4_usize).to_scalar()), e4);
+  }
+}
diff --git a/packages/circuit_reader/Cargo.toml b/packages/circuit_reader/Cargo.toml
new file mode 100644
index 0000000..aad7a96
--- /dev/null
+++ b/packages/circuit_reader/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "circuit_reader"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+bincode = "1.3.3"
+secq256k1 = { path = "../secq256k1" }
+spartan = { path = "../Spartan-secq" }
+ff = "0.12.0"
+byteorder = "1.4.3"
+group = "0.12.0"
+itertools = "0.9.0"
+
+[[bin]]
+name = "gen_spartan_inst"
+path = "src/bin/gen_spartan_inst.rs"
diff --git a/packages/circuit_reader/src/bin/gen_spartan_inst.rs b/packages/circuit_reader/src/bin/gen_spartan_inst.rs
new file mode 100644
index 0000000..3edb39f
--- /dev/null
+++ b/packages/circuit_reader/src/bin/gen_spartan_inst.rs
@@ -0,0 +1,24 @@
+#![allow(non_snake_case)]
+use bincode;
+use circuit_reader::load_as_spartan_inst;
+use std::env::{args, current_dir};
+use std::fs::File;
+use std::io::Write;
+
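+// usage: gen_spartan_inst <circom_r1cs_path> <output_path> <num_pub_inputs>
+// reads a circom-generated .r1cs file and writes a bincode-serialized Spartan
+// instance to the given output path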
+fn main() {
+  let circom_r1cs_path = args().nth(1).unwrap();
+  let output_path = args().nth(2).unwrap();
+  let num_pub_inputs = args().nth(3).unwrap().parse::<usize>().unwrap();
+
+  let root = current_dir().unwrap();
+  let circom_r1cs_path = root.join(circom_r1cs_path);
+  let spartan_inst = load_as_spartan_inst(circom_r1cs_path, num_pub_inputs);
+  let spartan_inst_bytes = bincode::serialize(&spartan_inst).unwrap();
+
+  File::create(root.join(output_path.clone()))
+    .unwrap()
+    .write_all(spartan_inst_bytes.as_slice())
+    .unwrap();
+
+  println!("Wrote Spartan circuit to {}", output_path);
+}
diff --git a/packages/spartan_wasm/src/circom_reader.rs b/packages/circuit_reader/src/circom_reader.rs
similarity index 100%
rename from packages/spartan_wasm/src/circom_reader.rs
rename to packages/circuit_reader/src/circom_reader.rs
diff --git a/packages/spartan_wasm/src/bin/gen_spartan_inst.rs b/packages/circuit_reader/src/lib.rs
similarity index 61%
rename from packages/spartan_wasm/src/bin/gen_spartan_inst.rs
rename to packages/circuit_reader/src/lib.rs
index 56590b8..1efd3ff 100644
--- a/packages/spartan_wasm/src/bin/gen_spartan_inst.rs
+++ b/packages/circuit_reader/src/lib.rs
@@ -1,41 +1,15 @@
-#![allow(non_snake_case)]
-use bincode;
+mod circom_reader;
+
+use circom_reader::{load_r1cs_from_bin_file, R1CS};
 use ff::PrimeField;
 use libspartan::Instance;
 use secq256k1::AffinePoint;
 use secq256k1::FieldBytes;
-use spartan_wasm::circom_reader::{load_r1cs_from_bin_file, R1CS};
-use std::env::{args, current_dir};
-use std::fs::File;
-use std::io::Write;
 use std::path::PathBuf;
 
-fn main() {
-  let circuit_path = args().nth(1).unwrap();
-  let output_path = args().nth(2).unwrap();
-  let num_pub_inputs = args().nth(3).unwrap().parse::<usize>().unwrap();
-
-  let root = current_dir().unwrap();
-  let circuit_path = root.join(circuit_path);
-  let spartan_inst = load_as_spartan_inst(circuit_path, num_pub_inputs);
-  let sparta_inst_bytes = bincode::serialize(&spartan_inst).unwrap();
-
-  File::create(root.join(output_path.clone()))
-    .unwrap()
-    .write_all(sparta_inst_bytes.as_slice())
-    .unwrap();
-
-  println!("Written Spartan circuit to {}", output_path);
-}
-
 pub fn load_as_spartan_inst(circuit_file: PathBuf, num_pub_inputs: usize) -> Instance {
-  let root = current_dir().unwrap();
-
-  let circuit_file = root.join(circuit_file);
   let (r1cs, _) = load_r1cs_from_bin_file::<AffinePoint>(&circuit_file);
-
   let spartan_inst = convert_to_spartan_r1cs(&r1cs, num_pub_inputs);
-
   spartan_inst
 }
diff --git a/packages/circuits/package.json b/packages/circuits/package.json
index 56810cf..882c2d8 100644
--- a/packages/circuits/package.json
+++ b/packages/circuits/package.json
@@ -1,6 +1,6 @@
 {
-  "name": "circuits",
-  "version": "1.0.0",
+  "name": "@personaelabs/spartan-ecdsa-circuits",
+  "version": "0.1.0",
   "main": "index.js",
   "license": "MIT",
   "dependencies": {
@@ -18,4 +18,4 @@
     "ts-jest": "^29.0.3",
     "typescript": "^4.9.4"
   }
-}
\ No newline at end of file
+}
diff --git a/packages/circuits/poseidon/poseidon_constants.circom b/packages/circuits/poseidon/poseidon_constants.circom
index 9012740..e6d5bde 100644
--- a/packages/circuits/poseidon/poseidon_constants.circom
+++ b/packages/circuits/poseidon/poseidon_constants.circom
@@ -1,3 +1,5 @@
+pragma circom 2.1.2;
+
 function ROUND_KEYS() {
     return [
         15180568604901803243989155929934437997245952775071395385994322939386074967328,
@@ -213,4 +215,4 @@ function MDS_MATRIX() {
         70274477372358662369456035572054501601454406272695978931839980644925236550307
     ]
   ];
-}
\ No newline at end of file
+}
diff --git a/packages/lib/package.json b/packages/lib/package.json
index 201dffa..5ad2e0c 100644
--- a/packages/lib/package.json
+++ b/packages/lib/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@personaelabs/spartan-ecdsa",
-  "version": "1.0.2",
+  "version": "2.0.0",
   "main": "./build/lib.js",
   "types": "./build/lib.d.ts",
   "license": "MIT",
diff --git a/packages/lib/src/helpers/public_input.ts b/packages/lib/src/helpers/public_input.ts
index 6484e5e..c701aee 100644
--- a/packages/lib/src/helpers/public_input.ts
+++ b/packages/lib/src/helpers/public_input.ts
@@ -109,7 +109,7 @@ export class PublicInput {
 
 /**
  * Compute the group elements T and U for efficient ecdsa
- * http://localhost:1313/posts/efficient-ecdsa-1/
+ * https://personaelabs.org/posts/efficient-ecdsa-1/
  */
 export const computeEffEcdsaPubInput = (
   r: bigint,
diff --git a/packages/lib/src/helpers/tree.ts b/packages/lib/src/helpers/tree.ts
index 67c2f96..e9de3e6 100644
--- a/packages/lib/src/helpers/tree.ts
+++ b/packages/lib/src/helpers/tree.ts
@@ -20,6 +20,14 @@ export class Tree {
     this.treeInner.insert(leaf);
   }
 
+  delete(index: number) {
+    this.treeInner.delete(index);
+  }
+
+  leaves(): bigint[] {
+    return this.treeInner.leaves;
+  }
+
   root(): bigint {
     return this.treeInner.root;
   }
diff --git a/packages/lib/src/wasm/wasm.js b/packages/lib/src/wasm/wasm.js
index 92b7302..aaec619 100644
--- a/packages/lib/src/wasm/wasm.js
+++ b/packages/lib/src/wasm/wasm.js
@@ -1,7 +1,6 @@
-
 let wasm;
 
-const heap = new Array(32).fill(undefined);
+const heap = new Array(128).fill(undefined);
 
 heap.push(undefined, null, true, false);
 
@@ -10,7 +9,7 @@ function getObject(idx) { return heap[idx]; }
 let heap_next = heap.length;
 
 function dropObject(idx) {
-    if (idx < 36) return;
+    if (idx < 132) return;
     heap[idx] = heap_next;
     heap_next = idx;
 }
@@ -25,10 +24,10 @@ const cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: tru
 
 cachedTextDecoder.decode();
 
-let cachedUint8Memory0 = new Uint8Array();
+let cachedUint8Memory0 = null;
 
 function getUint8Memory0() {
-    if (cachedUint8Memory0.buffer !== wasm.memory.buffer) {
+    if (cachedUint8Memory0 === null || cachedUint8Memory0.buffer !== wasm.memory.buffer) {
         cachedUint8Memory0 = new Uint8Array(wasm.memory.buffer);
     }
     return cachedUint8Memory0;
@@ -61,10 +60,10 @@ function passArray8ToWasm0(arg, malloc) {
     return ptr;
 }
 
-let cachedInt32Memory0 = new Int32Array();
+let cachedInt32Memory0 = null;
 
 function getInt32Memory0() {
-    if (cachedInt32Memory0.buffer !== wasm.memory.buffer) {
+    if (cachedInt32Memory0 === null || cachedInt32Memory0.buffer !== wasm.memory.buffer) {
         cachedInt32Memory0 = new Int32Array(wasm.memory.buffer);
     }
     return cachedInt32Memory0;
@@ -248,6 +247,15 @@ async function load(module, imports) {
 function getImports() {
     const imports = {};
     imports.wbg = {};
+    imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
+        getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
+    }, arguments) };
+    imports.wbg.__wbindgen_object_drop_ref = function(arg0) {
+        takeObject(arg0);
+    };
+    imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
+        getObject(arg0).getRandomValues(getObject(arg1));
+    }, arguments) };
     imports.wbg.__wbg_crypto_e1d53a1d73fb10b8 = function(arg0) {
         const ret = getObject(arg0).crypto;
         return addHeapObject(ret);
@@ -273,9 +281,6 @@ function getImports() {
         const ret = typeof(getObject(arg0)) === 'string';
         return ret;
     };
-    imports.wbg.__wbindgen_object_drop_ref = function(arg0) {
-        takeObject(arg0);
-    };
     imports.wbg.__wbg_msCrypto_6e7d3e1f92610cbb = function(arg0) {
         const ret = getObject(arg0).msCrypto;
         return addHeapObject(ret);
@@ -292,17 +297,11 @@ function getImports() {
         const ret = getStringFromWasm0(arg0, arg1);
         return addHeapObject(ret);
     };
-    imports.wbg.__wbg_getRandomValues_805f1c3d65988a5a = function() { return handleError(function (arg0, arg1) {
-        getObject(arg0).getRandomValues(getObject(arg1));
-    }, arguments) };
-    imports.wbg.__wbg_randomFillSync_6894564c2c334c42 = function() { return handleError(function (arg0, arg1, arg2) {
-        getObject(arg0).randomFillSync(getArrayU8FromWasm0(arg1, arg2));
-    }, arguments) };
-    imports.wbg.__wbg_newnoargs_b5b063fc6c2f0376 = function(arg0, arg1) {
+    imports.wbg.__wbg_newnoargs_2b8b6bd7753c76ba = function(arg0, arg1) {
        const ret = new Function(getStringFromWasm0(arg0, arg1));
        return addHeapObject(ret);
    };
-    imports.wbg.__wbg_call_97ae9d8645dc388b = function() { return handleError(function (arg0, arg1) {
+    imports.wbg.__wbg_call_95d1ea488d03e4e8 = function() { return handleError(function (arg0, arg1) {
        const ret = getObject(arg0).call(getObject(arg1));
        return addHeapObject(ret);
    }, arguments) };
@@ -310,19 +309,19 @@ function getImports() {
         const ret = getObject(arg0);
         return addHeapObject(ret);
     };
-    imports.wbg.__wbg_self_6d479506f72c6a71 = function() { return handleError(function () {
+    imports.wbg.__wbg_self_e7c1f827057f6584 = function() { return handleError(function () {
        const ret = self.self;
        return addHeapObject(ret);
    }, arguments) };
-    imports.wbg.__wbg_window_f2557cc78490aceb = function() { return handleError(function () {
+    imports.wbg.__wbg_window_a09ec664e14b1b81 = function() { return handleError(function () {
        const ret = window.window;
        return addHeapObject(ret);
    }, arguments) };
-    imports.wbg.__wbg_globalThis_7f206bda628d5286 = function() { return handleError(function () {
+    imports.wbg.__wbg_globalThis_87cbb8506fecf3a9 = function() { return handleError(function () {
        const ret = globalThis.globalThis;
        return addHeapObject(ret);
    }, arguments) };
-    imports.wbg.__wbg_global_ba75c50d1cf384f4 = function() { return handleError(function () {
+    imports.wbg.__wbg_global_c85a9259e621f3db = function() { return handleError(function () {
        const ret = global.global;
        return addHeapObject(ret);
    }, arguments) };
@@ -330,30 +329,30 @@ function getImports() {
         const ret = getObject(arg0) === undefined;
         return ret;
     };
-    imports.wbg.__wbg_call_168da88779e35f61 = function() { return handleError(function (arg0, arg1, arg2) {
+    imports.wbg.__wbg_call_9495de66fdbe016b = function() { return handleError(function (arg0, arg1, arg2) {
        const ret = getObject(arg0).call(getObject(arg1), getObject(arg2));
        return addHeapObject(ret);
    }, arguments) };
-    imports.wbg.__wbg_buffer_3f3d764d4747d564 = function(arg0) {
+    imports.wbg.__wbg_buffer_cf65c07de34b9a08 = function(arg0) {
        const ret = getObject(arg0).buffer;
        return addHeapObject(ret);
    };
-    imports.wbg.__wbg_new_8c3f0052272a457a = function(arg0) {
+    imports.wbg.__wbg_new_537b7341ce90bb31 = function(arg0) {
        const ret = new Uint8Array(getObject(arg0));
        return addHeapObject(ret);
    };
-    imports.wbg.__wbg_set_83db9690f9353e79 = function(arg0, arg1, arg2) {
+    imports.wbg.__wbg_set_17499e8aa4003ebd = function(arg0, arg1, arg2) {
        getObject(arg0).set(getObject(arg1), arg2 >>> 0);
    };
-    imports.wbg.__wbg_length_9e1ae1900cb0fbd5 = function(arg0) {
+    imports.wbg.__wbg_length_27a2afe8ab42b09f = function(arg0) {
        const ret = getObject(arg0).length;
        return ret;
    };
-    imports.wbg.__wbg_newwithlength_f5933855e4f48a19 = function(arg0) {
+    imports.wbg.__wbg_newwithlength_b56c882b57805732 = function(arg0) {
        const ret = new Uint8Array(arg0 >>> 0);
        return addHeapObject(ret);
    };
-    imports.wbg.__wbg_subarray_58ad4efbb5bcb886 = function(arg0, arg1, arg2) {
+    imports.wbg.__wbg_subarray_7526649b91a252a6 = function(arg0, arg1, arg2) {
        const ret = getObject(arg0).subarray(arg1 >>> 0, arg2 >>> 0);
        return addHeapObject(ret);
    };
@@ -393,8 +392,8 @@ function initMemory(imports, maybe_memory) {
 function finalizeInit(instance, module) {
     wasm = instance.exports;
     init.__wbindgen_wasm_module = module;
-    cachedInt32Memory0 = new Int32Array();
-    cachedUint8Memory0 = new Uint8Array();
+    cachedInt32Memory0 = null;
+    cachedUint8Memory0 = null;
 
     wasm.__wbindgen_start();
     return wasm;
diff --git a/packages/spartan_wasm/Cargo.toml b/packages/spartan_wasm/Cargo.toml
index d04a352..b362881 100644
--- a/packages/spartan_wasm/Cargo.toml
+++ b/packages/spartan_wasm/Cargo.toml
@@ -32,14 +32,3 @@ poseidon = { path = "../poseidon" }
 itertools = "0.9.0"
 group = "0.12.0"
 
-# Do not compile these dependencies when targeting wasm
-#[target.'cfg(not(target_family = "wasm"))'.dependencies]
-#nova-scotia = { git = "https://github.com/DanTehrani/Nova-Scotia.git" }
-#nova-snark = "0.9.0"
-#ff = "0.12.1"
-#ark-std = { version = "0.3.0", features = ["print-trace"] }
-
-
-[[bin]]
-name = "gen_spartan_inst"
-path = "src/bin/gen_spartan_inst.rs"
diff --git a/packages/spartan_wasm/src/lib.rs b/packages/spartan_wasm/src/lib.rs
index 406e69d..ce1d9f8 100644
--- a/packages/spartan_wasm/src/lib.rs
+++ b/packages/spartan_wasm/src/lib.rs
@@ -1,4 +1 @@
 pub mod wasm;
-
-#[cfg(not(target_family = "wasm"))]
-pub mod circom_reader;
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000..c524a79
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1,2 @@
+cargo test --release &&
+yarn lerna run test
\ No newline at end of file