Mirror of https://github.com/zkonduit/ezkl.git, synced 2026-01-13 08:17:57 -05:00

Compare commits: v8.4.1...example-re (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5cb303b149 | |
| | 9fb78c36e0 | |
| | 074db5d229 | |
.github/workflows/large-tests.yml (vendored, 2 changes)
@@ -11,7 +11,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: nanoGPT Mock
.github/workflows/release.yml (vendored, 2 changes)
@@ -45,7 +45,7 @@ jobs:
     steps:
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: Checkout repo
.github/workflows/rust.yml (vendored, 65 changes)
@@ -26,7 +26,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: Build
@@ -38,7 +38,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: Docs
@@ -50,7 +50,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
@@ -198,8 +198,10 @@ jobs:
           # chromedriver-version: "115.0.5790.102"
       - name: Install wasm32-unknown-unknown
         run: rustup target add wasm32-unknown-unknown
+
       - name: Install wasm runner
         run: cargo install wasm-server-runner
+
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
       - name: Run wasm verifier tests
         # on mac:
         # AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +214,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -229,7 +231,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -284,7 +286,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -310,7 +312,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
       - name: KZG prove and verify tests (EVM + VK rendered seperately)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg all)
@@ -343,15 +345,18 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
       - name: Add wasm32-unknown-unknown target
         run: rustup target add wasm32-unknown-unknown
+
       - name: Install wasm-server-runner
         run: cargo install wasm-server-runner
+
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
       - uses: actions/checkout@v3
       - name: Use pnpm 8
         uses: pnpm/action-setup@v2
@@ -411,11 +416,11 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
       - uses: actions/checkout@v3
       - uses: baptiste0928/cargo-install@v1
         with:
@@ -445,7 +450,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -455,7 +460,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
       - name: fuzz tests (EVM)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
       # - name: fuzz tests
@@ -468,7 +473,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -486,7 +491,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -503,7 +508,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -520,7 +525,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -530,7 +535,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
       - name: KZG prove and verify aggr tests
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -541,7 +546,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -563,7 +568,7 @@ jobs:
           python-version: "3.7"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - name: Install solc
@@ -571,7 +576,7 @@ jobs:
       - name: Setup Virtual Env and Install python dependencies
         run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
       - name: Build python ezkl
         run: source .env/bin/activate; maturin develop --features python-bindings --release
       - name: Run pytest
@@ -587,7 +592,7 @@ jobs:
           python-version: "3.7"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -618,7 +623,7 @@ jobs:
           python-version: "3.9"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -628,7 +633,7 @@ jobs:
       - name: Install solc
         run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
       - name: Setup Virtual Env and Install python dependencies
         run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
       - name: Build python ezkl
.github/workflows/wasm.yml (vendored, 7 changes)
@@ -22,15 +22,18 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-01-04
+          toolchain: nightly-2023-08-24
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
       - name: Add wasm32-unknown-unknown target
         run: rustup target add wasm32-unknown-unknown
+
       - name: Install wasm-server-runner
         run: cargo install wasm-server-runner
+
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
       - name: Install binaryen
         run: |
           set -e
Cargo.lock (generated, 558 changes): diff suppressed because it is too large.
Cargo.toml
@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
 [dependencies]
 halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
 halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
-halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
+halo2curves = { version = "0.6.0", features = ["derive_serde"] }
 rand = { version = "0.8", default_features = false }
 itertools = { version = "0.10.3", default_features = false }
 clap = { version = "4.3.3", features = ["derive"]}
@@ -34,13 +34,10 @@ bincode = { version = "1.3.3", default_features = false }
 ark-std = { version = "^0.3.0", default-features = false }
 unzip-n = "0.1.2"
 num = "0.4.1"
-portable-atomic = "1.6.0"
-tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
-
 
 # evm related deps
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
-ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
+ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
 indicatif = {version = "0.17.5", features = ["rayon"]}
 gag = { version = "1.0.0", default_features = false}
 instant = { version = "0.1" }
@@ -161,7 +158,6 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidi
 det-prove = []
 icicle = ["halo2_proofs/icicle_gpu"]
 empty-cmd = []
 no-banner = []
 
 # icicle patch to 0.1.0 if feature icicle is enabled
 [patch.'https://github.com/ingonyama-zk/icicle']
@@ -6,7 +6,6 @@ use ezkl::fieldutils;
 use ezkl::fieldutils::i32_to_felt;
 use ezkl::tensor::*;
 use halo2_proofs::dev::MockProver;
-use halo2_proofs::poly::commitment::Params;
 use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
 use halo2_proofs::{
     circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -490,7 +489,6 @@ pub fn runconv() {
         strategy,
         pi_for_real_prover,
         &mut transcript,
-        params.n(),
     );
     assert!(verify.is_ok());
(File diff suppressed because it is too large.)

(Image file changed; Size: 109 KiB before and after.)
@@ -11,8 +11,8 @@ use ezkl::execute::run;
 #[cfg(not(target_arch = "wasm32"))]
 use ezkl::logger::init_logger;
-#[cfg(not(target_arch = "wasm32"))]
-use log::{debug, error, info};
+#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
+use log::{error, info};
 #[cfg(not(target_arch = "wasm32"))]
 use rand::prelude::SliceRandom;
 #[cfg(not(target_arch = "wasm32"))]
 #[cfg(feature = "icicle")]
@@ -25,7 +25,6 @@ use std::error::Error;
 pub async fn main() -> Result<(), Box<dyn Error>> {
     let args = Cli::parse();
     init_logger();
-    #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
     banner();
     #[cfg(feature = "icicle")]
     if env::var("ENABLE_ICICLE_GPU").is_ok() {
@@ -33,7 +32,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
     } else {
         info!("Running with CPU");
     }
-    debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
+    info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
     let res = run(args.command).await;
     match &res {
         Ok(_) => info!("succeeded"),
@@ -45,7 +44,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
 #[cfg(target_arch = "wasm32")]
 pub fn main() {}
 
-#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
+#[cfg(not(target_arch = "wasm32"))]
 fn banner() {
     let ell: Vec<&str> = vec![
         "for Neural Networks",
@@ -41,7 +41,7 @@ pub struct KZGChip {
 }
 
 impl KZGChip {
-    /// Commit to the message using the KZG commitment scheme
+    /// Returns the number of inputs to the hash function
     pub fn commit(
         message: Vec<Fp>,
         degree: u32,
@@ -15,7 +15,7 @@ use halo2_proofs::{
         Instance, Selector, TableColumn,
     },
 };
-use log::{debug, trace};
+use log::{trace, warn};
 
 /// A simple [`FloorPlanner`] that performs minimal optimizations.
 #[derive(Debug)]
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
             Error::Synthesis
         })?;
         if !self.regions.contains_key(&index) {
-            debug!("spawning module {}", index)
+            warn!("spawning module {}", index)
         };
         self.current_module = index;
     }
@@ -125,8 +125,8 @@ impl BaseOp {
             BaseOp::Sum => 1,
             BaseOp::SumInit => 1,
             BaseOp::Range { .. } => 1,
-            BaseOp::IsZero => 0,
-            BaseOp::IsBoolean => 0,
+            BaseOp::IsZero => 1,
+            BaseOp::IsBoolean => 1,
         }
     }
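The 0 to 1 flip above changes which relative cell these identity-style gates act on. A minimal standalone sketch of the idea, assuming (as the later `expected_output[base_op.constraint_idx()]` usage in this diff suggests) that this match lives in `constraint_idx` and indexes the window of queried rotations; the enum here is hypothetical, not the real ezkl type:

```rust
#[derive(Clone, Copy)]
enum BaseOp {
    Sum,
    IsZero,
    IsBoolean,
}

impl BaseOp {
    // index into the window of queried cells that the gate constrains
    fn constraint_idx(self) -> usize {
        match self {
            BaseOp::Sum => 1,
            BaseOp::IsZero => 1,    // previously 0
            BaseOp::IsBoolean => 1, // previously 0
        }
    }
}

fn main() {
    let window = ["rotation 0", "rotation 1"];
    println!("Sum constrains {}", window[BaseOp::Sum.constraint_idx()]);
    // after this change, IsZero / IsBoolean read the same offset as Sum
    println!("IsZero constrains {}", window[BaseOp::IsZero.constraint_idx()]);
    println!("IsBoolean constrains {}", window[BaseOp::IsBoolean.constraint_idx()]);
}
```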
@@ -16,7 +16,6 @@ use pyo3::{
     types::PyString,
 };
 use serde::{Deserialize, Serialize};
-use tosubcommand::ToFlags;
 
 use crate::{
     circuit::ops::base::BaseOp,
@@ -62,22 +61,6 @@ pub enum CheckMode {
     UNSAFE,
 }
 
-impl std::fmt::Display for CheckMode {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            CheckMode::SAFE => write!(f, "safe"),
-            CheckMode::UNSAFE => write!(f, "unsafe"),
-        }
-    }
-}
-
-impl ToFlags for CheckMode {
-    /// Convert the struct to a subcommand string
-    fn to_flags(&self) -> Vec<String> {
-        vec![format!("{}", self)]
-    }
-}
-
 impl From<String> for CheckMode {
     fn from(value: String) -> Self {
         match value.to_lowercase().as_str() {
@@ -100,19 +83,6 @@ pub struct Tolerance {
     pub scale: utils::F32,
 }
 
-impl std::fmt::Display for Tolerance {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{:.2}", self.val)
-    }
-}
-
-impl ToFlags for Tolerance {
-    /// Convert the struct to a subcommand string
-    fn to_flags(&self) -> Vec<String> {
-        vec![format!("{}", self)]
-    }
-}
-
 impl FromStr for Tolerance {
     type Err = String;
@@ -306,20 +276,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
 
         let constraints = match base_op {
             BaseOp::IsBoolean => {
-                let expected_output: Tensor<Expression<F>> = output
-                    .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
-                    .expect("non accum: output query failed");
-
-                let output = expected_output[base_op.constraint_idx()].clone();
-
-                vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
-            }
-            BaseOp::IsZero => {
-                let expected_output: Tensor<Expression<F>> = output
-                    .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
-                    .expect("non accum: output query failed");
-                vec![expected_output[base_op.constraint_idx()].clone()]
+                vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
             }
+            BaseOp::IsZero => vec![qis[1].clone()],
             _ => {
                 let expected_output: Tensor<Expression<F>> = output
                     .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
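Both versions of the gate enforce the same algebra; only where the expression is queried from changes. A plain-integer restatement of the two identities (i128 standing in for field elements):

```rust
// IsBoolean gate: y * (y - 1) == 0 holds iff y is 0 or 1
fn is_boolean_ok(y: i128) -> bool {
    y * (y - 1) == 0
}

// IsZero gate: the queried expression itself must equal 0
fn is_zero_ok(y: i128) -> bool {
    y == 0
}

fn main() {
    assert!(is_boolean_ok(0) && is_boolean_ok(1) && !is_boolean_ok(2));
    assert!(is_zero_ok(0) && !is_zero_ok(5));
}
```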
@@ -553,15 +512,14 @@
 
         // we borrow mutably twice so we need to do this dance
 
-        let range_check =
-            if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
-                // as all tables have the same input we see if there's another table who's input we can reuse
-                let range_check = RangeCheck::<F>::configure(cs, range);
-                e.insert(range_check.clone());
-                range_check
-            } else {
-                return Ok(());
-            };
+        let range_check = if !self.range_checks.contains_key(&range) {
+            // as all tables have the same input we see if there's another table who's input we can reuse
+            let range_check = RangeCheck::<F>::configure(cs, range);
+            self.range_checks.insert(range, range_check.clone());
+            range_check
+        } else {
+            return Ok(());
+        };
 
         for x in 0..input.num_blocks() {
             for y in 0..input.num_inner_cols() {
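The two variants above differ only in map-insertion style. A self-contained illustration on plain std types (nothing ezkl-specific): the Entry API avoids the second lookup that `contains_key` followed by `insert` performs.

```rust
use std::collections::BTreeMap;

// stand-in for "configure a range check once per distinct range"
fn get_or_configure(map: &mut BTreeMap<(i64, i64), String>, range: (i64, i64)) -> Option<String> {
    if let std::collections::btree_map::Entry::Vacant(e) = map.entry(range) {
        let check = format!("range check for {:?}", range); // stand-in for RangeCheck::configure
        e.insert(check.clone());
        Some(check)
    } else {
        None // already configured, nothing to do
    }
}

fn main() {
    let mut m = BTreeMap::new();
    assert!(get_or_configure(&mut m, (-8, 8)).is_some());
    assert!(get_or_configure(&mut m, (-8, 8)).is_none());
}
```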
@@ -6,6 +6,7 @@ use crate::{
     tensor::{self, Tensor, TensorError, TensorType, ValTensor},
 };
 use halo2curves::ff::PrimeField;
+use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 // import run args from model
@@ -68,6 +69,14 @@ pub enum HybridOp {
         dim: usize,
         num_classes: usize,
     },
+    GatherElements {
+        dim: usize,
+        constant_idx: Option<Tensor<usize>>,
+    },
+    ScatterElements {
+        dim: usize,
+        constant_idx: Option<Tensor<usize>>,
+    },
 }
 
 impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
@@ -75,6 +84,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
     fn requires_homogenous_input_scales(&self) -> Vec<usize> {
         match self {
             HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
+            HybridOp::ScatterElements { .. } => vec![0, 2],
             HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
             _ => vec![],
         }
@@ -88,42 +98,176 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
     fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
         let x = inputs[0].clone().map(|x| felt_to_i128(x));
 
-        let res = match &self {
-            HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
-            HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
-            HybridOp::Div { denom, .. } => {
-                crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
-            }
+        let (res, intermediate_lookups) = match &self {
+            HybridOp::ReduceMax { axes, .. } => {
+                let res = tensor::ops::max_axes(&x, axes)?;
+                let max_minus_one =
+                    Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
+                let unit = Tensor::from(vec![1].into_iter());
+                // relu(x - max(x - 1)
+                let inter_1 = (x.clone() - max_minus_one)?;
+                // relu(1 - sum(relu(inter_1)))
+                let inter_2 = (unit
+                    - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
+
+                (res.clone(), vec![inter_1, inter_2])
+            }
+            HybridOp::ReduceMin { axes, .. } => {
+                let res = tensor::ops::min_axes(&x, axes)?;
+                let min_plus_one =
+                    Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
+                let unit = Tensor::from(vec![1].into_iter());
+                // relu(min(x + 1) - x)
+                let inter_1 = (min_plus_one - x.clone())?;
+                // relu(1 - sum(relu(inter_1)))
+                let inter_2 = (unit
+                    - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
+                (res.clone(), vec![inter_1, inter_2])
+            }
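A standalone sketch of the max identity that these intermediate lookups witness (plain i128 in place of field elements, assumed semantics; in the real circuit the booleanity of `inter_1` is enforced by constraints rather than asserted). With m = max(x), every relu(x_i - (m - 1)) is 0 or 1 and at least one entry is 1, so relu(1 - sum) vanishes exactly when m is really the max; ReduceMin mirrors this with relu((m + 1) - x_i).

```rust
fn relu(v: i128) -> i128 {
    v.max(0)
}

fn max_claim_ok(x: &[i128], claimed_max: i128) -> bool {
    // inter_1 = relu(x - (max - 1)); must be all zeros and ones
    let inter_1: Vec<i128> = x.iter().map(|&v| relu(v - (claimed_max - 1))).collect();
    let booleans_ok = inter_1.iter().all(|&v| v == 0 || v == 1);
    // inter_2 = relu(1 - sum(inter_1)); must be zero
    let inter_2 = relu(1 - inter_1.iter().sum::<i128>());
    booleans_ok && inter_2 == 0
}

fn main() {
    assert!(max_claim_ok(&[3, 7, 2], 7));
    assert!(!max_claim_ok(&[3, 7, 2], 9)); // too large: no entry reaches 1
    assert!(!max_claim_ok(&[3, 7, 2], 6)); // too small: inter_1 is not boolean
}
```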
+            HybridOp::Div {
+                denom,
+                use_range_check_for_int,
+                ..
+            } => {
+                let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
+                // if denom is a round number and use_range_check_for_int is true, use range check check
+                if denom.0.fract() == 0.0 && *use_range_check_for_int {
+                    let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
+                    (res, vec![-divisor.clone(), divisor])
+                } else {
+                    (res, vec![x])
+                }
+            }
             HybridOp::Recip {
                 input_scale,
                 output_scale,
-                ..
-            } => crate::tensor::ops::nonlinearities::recip(
-                &x,
-                input_scale.0 as f64,
-                output_scale.0 as f64,
-            ),
-            HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
-            HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
+                use_range_check_for_int,
+            } => {
+                let res = crate::tensor::ops::nonlinearities::recip(
+                    &x,
+                    input_scale.0 as f64,
+                    output_scale.0 as f64,
+                );
+                // if scale is a round number and use_range_check_for_int is true, use range check check
+                if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
+                    let err_tol = Tensor::from(
+                        vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
+                    );
+                    (res, vec![-err_tol.clone(), err_tol])
+                } else {
+                    (res, vec![x])
+                }
+            }
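The `use_range_check_for_int` branches trade a lookup table for a range check on the rounding error. A plain-integer sketch of the acceptance rule, under the assumed convention that a claimed rounded quotient is valid when its remainder sits inside the half-bracket:

```rust
// a claimed quotient q for x / denom is accepted iff the remainder
// x - q * denom lies in [-denom/2, denom/2]
fn div_claim_ok(x: i128, denom: i128, claimed_q: i128) -> bool {
    let err = x - claimed_q * denom;
    let bracket = denom / 2;
    -bracket <= err && err <= bracket
}

fn main() {
    // 7 / 2 rounds to 3 or 4 depending on convention; both leave |err| <= 1
    assert!(div_claim_ok(7, 2, 4));
    assert!(div_claim_ok(7, 2, 3));
    assert!(!div_claim_ok(7, 2, 5));
}
```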
+            HybridOp::ReduceArgMax { dim } => {
+                let res = tensor::ops::argmax_axes(&x, *dim)?;
+                let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                let inter =
+                    Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
+                inter_equals.extend(inter);
+
+                (res.clone(), inter_equals)
+            }
+            HybridOp::ReduceArgMin { dim } => {
+                let res = tensor::ops::argmin_axes(&x, *dim)?;
+                let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                let inter =
+                    Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
+                inter_equals.extend(inter);
+
+                (res.clone(), inter_equals)
+            }
             HybridOp::Gather { dim, constant_idx } => {
                 if let Some(idx) = constant_idx {
-                    tensor::ops::gather(&x, idx, *dim)?
+                    log::debug!("idx: {}", idx.show());
+                    let res = tensor::ops::gather(&x, idx, *dim)?;
+                    (res.clone(), vec![])
                 } else {
                     let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                    tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
+                    let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                    let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                    let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
+                    (res.clone(), inter_equals)
                 }
             }
             HybridOp::OneHot { dim, num_classes } => {
-                tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
+                let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
+                (res.clone(), inter_equals)
             }
-            HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
+            HybridOp::TopK { dim, k, largest } => {
+                let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
+
+                let mut inter_equals = x
+                    .clone()
+                    .into_iter()
+                    .flat_map(|elem| {
+                        tensor::ops::equals(&res, &vec![elem].into_iter().into())
+                            .unwrap()
+                            .1
+                    })
+                    .collect::<Vec<_>>();
+
+                // sort in descending order and take pairwise differences
+                inter_equals.push(
+                    x.into_iter()
+                        .sorted()
+                        .tuple_windows()
+                        .map(|(a, b)| b - a)
+                        .into(),
+                );
+
+                (res.clone(), inter_equals)
+            }
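A std-only stand-in for the itertools `sorted().tuple_windows()` chain in the TopK arm above, showing the two kinds of hints being collected (membership equalities for each claimed element, and the non-negative pairwise differences of the sorted input that the circuit later checks):

```rust
fn topk_hints(x: &[i128], res: &[i128]) -> (Vec<i128>, bool) {
    let mut sorted = x.to_vec();
    sorted.sort();
    // w[1] - w[0] for consecutive sorted pairs: all >= 0 by construction
    let diffs: Vec<i128> = sorted.windows(2).map(|w| w[1] - w[0]).collect();
    // each claimed top-k element must equal some input element
    let membership = res.iter().all(|r| x.contains(r));
    (diffs, membership)
}

fn main() {
    let (diffs, ok) = topk_hints(&[5, 1, 9], &[9, 5]);
    assert_eq!(diffs, vec![4, 4]);
    assert!(ok);
}
```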
+            HybridOp::GatherElements { dim, constant_idx } => {
+                if let Some(idx) = constant_idx {
+                    log::debug!("idx: {}", idx.show());
+                    let res = tensor::ops::gather_elements(&x, idx, *dim)?;
+                    (res.clone(), vec![])
+                } else {
+                    let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                    let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                    let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                    let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
+                    (res.clone(), inter_equals)
+                }
+            }
+            HybridOp::ScatterElements { dim, constant_idx } => {
+                if let Some(idx) = constant_idx {
+                    log::debug!("idx: {}", idx.show());
+                    let src = inputs[1].clone().map(|x| felt_to_i128(x));
+                    let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
+                    (res.clone(), vec![])
+                } else {
+                    let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
+                    let src = inputs[2].clone().map(|x| felt_to_i128(x));
+                    let indices = Tensor::from(0..x.dims()[*dim] as i128);
+                    let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
+                    let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
+                    (res.clone(), inter_equals)
+                }
+            }
             HybridOp::MaxPool2d {
                 padding,
                 stride,
                 pool_dims,
                 ..
-            } => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
+            } => {
+                let max_minus_one =
+                    Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
+                let unit = Tensor::from(vec![1].into_iter());
+                // relu(x - max(x - 1)
+                let inter_1 = (x.clone() - max_minus_one)?;
+                // relu(1 - sum(relu(inter_1)))
+                let inter_2 = (unit
+                    - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
+                (
+                    tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
+                    vec![inter_1, inter_2],
+                )
+            }
             HybridOp::SumPool {
                 padding,
                 stride,
@@ -135,7 +279,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
             }
             HybridOp::RangeCheck(tol) => {
                 let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
+                (
+                    tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
+                    vec![],
+                )
             }
             HybridOp::Greater => {
                 let y = inputs[1].clone().map(|x| felt_to_i128(x));
@@ -162,7 +309,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
         // convert back to felt
         let output = res.map(|x| i128_to_felt(x));
 
-        Ok(ForwardResult { output })
+        Ok(ForwardResult {
+            output,
+            intermediate_lookups,
+        })
     }
 
     fn as_string(&self) -> String {
@@ -216,6 +366,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
             HybridOp::TopK { k, dim, largest } => {
                 format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
             }
+            HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
+            HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
             HybridOp::OneHot { dim, num_classes } => {
                 format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
             }
@@ -288,7 +440,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                     config,
                     region,
                     values.try_into()?,
-                    &LookupOp::Div { denom: *denom },
+                    &LookupOp::Div {
+                        denom: denom.clone(),
+                    },
                 )?
             }
         }
@@ -299,7 +453,26 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                     layouts::gather(config, region, values[..].try_into()?, *dim)?
                 }
             }
+            HybridOp::GatherElements { dim, constant_idx } => {
+                if let Some(idx) = constant_idx {
+                    tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
+                } else {
+                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
+                }
+            }
+            HybridOp::ScatterElements { dim, constant_idx } => {
+                if let Some(idx) = constant_idx {
+                    tensor::ops::scatter(
+                        values[0].get_inner_tensor()?,
+                        idx,
+                        values[1].get_inner_tensor()?,
+                        *dim,
+                    )?
+                    .into()
+                } else {
+                    layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
+                }
+            }
             HybridOp::MaxPool2d {
                 padding,
                 stride,
@@ -128,7 +128,13 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
 
     let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
 
-    let is_assigned = !input.any_unknowns()?;
+    let mut scaled_unit =
+        Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
+    scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
+    let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
+    region.increment(scaled_unit.len());
+
+    let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
 
     let mut claimed_output: ValTensor<F> = if is_assigned {
         let input_evals = input.get_int_evals()?;
@@ -160,14 +166,27 @@
 
     log::debug!("product: {:?}", product.get_int_evals()?);
 
+    // this is now of scale 2 * scale hence why we rescaled the unit scale
+    let diff_with_input = pairwise(
+        config,
+        region,
+        &[product.clone(), scaled_unit.clone()],
+        BaseOp::Sub,
+    )?;
+
+    log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
+
+    // debug print the diff
+    log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
+
     log::debug!("range_check_bracket: {:?}", range_check_bracket);
 
+    // at most the error should be in the original unit scale's range
     range_check(
         config,
         region,
-        &[product],
-        &(range_check_bracket, 3 * range_check_bracket),
+        &[diff_with_input],
+        &(-range_check_bracket, range_check_bracket),
     )?;
 
     Ok(claimed_output)
@@ -214,7 +233,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
 
     let mut assigned_len = 0;
     for (i, input) in values.iter_mut().enumerate() {
-        input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
+        input.pad_to_zero_rem(block_width)?;
         let inp = {
             let (res, len) = region.assign_with_duplication(
                 &config.inputs[i],
@@ -770,7 +789,14 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
     let assigned_input = region.assign(&config.inputs[0], &input)?;
 
     // now assert all elems are 0 or 1
-    let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
+    let assigned_output = region.assign(&config.inputs[1], &output)?;
+    if !region.is_dummy() {
+        for i in 0..assigned_output.len() {
+            let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
+            let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
+            region.enable(selector, z)?;
+        }
+    }
     region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
 
     let sum = sum(config, region, &[assigned_output.clone()])?;
@@ -1142,7 +1168,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
     let assigned_len: usize;
     let input = {
         let mut input = values[0].clone();
-        input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
+        input.pad_to_zero_rem(block_width)?;
         let (res, len) =
             region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
         assigned_len = len;
@@ -1211,7 +1237,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
     let assigned_len: usize;
     let input = {
         let mut input = values[0].clone();
-        input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
+        input.pad_to_zero_rem(block_width)?;
         let (res, len) =
             region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
         assigned_len = len;
@@ -1675,28 +1701,10 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
     values: &[ValTensor<F>; 2],
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
     let diff = pairwise(config, region, values, BaseOp::Sub)?;
-    let diff_inverse = diff.inverse()?;
-    let product_diff_and_invert =
-        pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
-
-    // constant of 1
-    let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
-    ones.set_visibility(&crate::graph::Visibility::Fixed);
-
-    // subtract
-    let output = pairwise(
-        config,
-        region,
-        &[ones.into(), product_diff_and_invert],
-        BaseOp::Sub,
-    )?;
-
-    // take the product of diff and output
-    let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
-
-    is_zero_identity(config, region, &[prod_check], false)?;
-
-    Ok(output)
+    let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
+
+    Ok(res)
 }
 
 /// Xor boolean operation
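Both equality layouts compute the same truth table. A toy prime-field sketch (small modulus, assumed semantics) of the removed inverse-product path next to the KroneckerDelta lookup that replaces it:

```rust
const P: u64 = 7919; // small prime stand-in for the real field modulus

fn pow_mod(mut b: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    while e > 0 {
        if e & 1 == 1 {
            acc = acc * b % P;
        }
        b = b * b % P;
        e >>= 1;
    }
    acc
}

// removed path: eq = 1 - diff * diff^{-1}, with diff * eq == 0 as a side constraint
fn equals_via_inverse(a: u64, b: u64) -> u64 {
    let diff = (a + P - b) % P;
    // in the circuit diff_inverse is a witness; 0 maps to 0 here
    let diff_inverse = if diff == 0 { 0 } else { pow_mod(diff, P - 2) };
    let eq = (1 + P - diff * diff_inverse % P) % P;
    assert_eq!(diff * eq % P, 0); // the is-zero side constraint
    eq
}

// new path: what the KroneckerDelta lookup tabulates
fn kronecker_delta(a: u64, b: u64) -> u64 {
    u64::from(a == b)
}

fn main() {
    assert_eq!(equals_via_inverse(5, 5), kronecker_delta(5, 5));
    assert_eq!(equals_via_inverse(5, 6), kronecker_delta(5, 6));
}
```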
@@ -1760,7 +1768,21 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
         .into();
 
     // make sure mask is boolean
-    let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
+    let assigned_mask = region.assign(&config.inputs[1], mask)?;
+
+    // Enable the selectors
+    if !region.is_dummy() {
+        (0..assigned_mask.len())
+            .map(|i| {
+                let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
+                let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
+                region.enable(selector, z)?;
+                Ok(())
+            })
+            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
+    }
+
+    region.increment(assigned_mask.len());
 
     let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
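A plain-integer sketch of the select that `iff` wires up once the mask is known to be boolean (the assertion here stands in for the IsBoolean constraint the layout enables):

```rust
fn iff(mask: i128, a: i128, b: i128) -> i128 {
    assert!(mask == 0 || mask == 1, "mask must satisfy the IsBoolean gate");
    // mask * a + (1 - mask) * b
    mask * a + (1 - mask) * b
}

fn main() {
    assert_eq!(iff(1, 10, 20), 10);
    assert_eq!(iff(0, 10, 20), 20);
}
```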
@@ -1854,15 +1876,17 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
     let shape = &res[0].dims()[2..];
     let mut last_elem = res[1..]
         .iter()
-        .try_fold(res[0].clone(), |acc, elem| acc.concat(elem.clone()))?;
+        .fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?;
     last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
 
     if normalized {
-        last_elem = div(
+        last_elem = nonlinearity(
             config,
             region,
             &[last_elem],
-            F::from((kernel_shape.0 * kernel_shape.1) as u64),
+            &LookupOp::Div {
+                denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
+            },
         )?;
     }
     Ok(last_elem)
@@ -2359,60 +2383,18 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
     Ok(output)
 }
 
-/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
-pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
-    config: &BaseConfig<F>,
-    region: &mut RegionCtx<F>,
-    values: &[ValTensor<F>; 1],
-    assign: bool,
-) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let output = if assign || !values[0].get_const_indices()?.is_empty() {
-        let output = region.assign(&config.output, &values[0])?;
-        region.increment(output.len());
-        output
-    } else {
-        values[0].clone()
-    };
-    // Enable the selectors
-    if !region.is_dummy() {
-        (0..output.len())
-            .map(|j| {
-                let index = region.linear_coord() - j - 1;
-
-                let (x, y, z) = config.output.cartesian_coord(index);
-                let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
-
-                region.enable(selector, z)?;
-                Ok(())
-            })
-            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
-    }
-
-    Ok(output)
-}
-
 /// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
 pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
     config: &BaseConfig<F>,
     region: &mut RegionCtx<F>,
     values: &[ValTensor<F>; 1],
-    assign: bool,
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let output = if assign || !values[0].get_const_indices()?.is_empty() {
-        // get zero constants indices
-        let output = region.assign(&config.output, &values[0])?;
-        region.increment(output.len());
-        output
-    } else {
-        values[0].clone()
-    };
+    let output = region.assign(&config.inputs[1], &values[0])?;
     // Enable the selectors
     if !region.is_dummy() {
         (0..output.len())
             .map(|j| {
-                let index = region.linear_coord() - j - 1;
-
-                let (x, y, z) = config.output.cartesian_coord(index);
+                let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
                 let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
 
                 region.enable(selector, z)?;
@@ -2420,6 +2402,7 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
             })
             .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
     }
+    region.increment(output.len());
 
     Ok(output)
 }
@@ -2469,7 +2452,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
     values: &[ValTensor<F>; 1],
     range: &crate::circuit::table::Range,
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
-    region.add_used_range_check(*range)?;
+    region.add_used_range_check(*range);
 
     // time the entire operation
     let timer = instant::Instant::now();
@@ -2488,7 +2471,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
                 let (x, y, z) = config
                     .lookup_input
                     .cartesian_coord(region.linear_coord() + i);
-                let selector = config.range_check_selectors.get(&(*range, x, y));
+                let selector = config.range_check_selectors.get(&(range.clone(), x, y));
                 region.enable(selector, z)?;
                 Ok(())
             })
@@ -2515,7 +2498,7 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
     values: &[ValTensor<F>; 1],
     nl: &LookupOp,
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
-    region.add_used_lookup(nl.clone(), values)?;
+    region.add_used_lookup(nl.clone());
 
     // time the entire operation
     let timer = instant::Instant::now();
@@ -2598,6 +2581,22 @@
     Ok(output)
 }
 
+/// mean function layout
+pub fn mean<F: PrimeField + TensorType + PartialOrd>(
+    config: &BaseConfig<F>,
+    region: &mut RegionCtx<F>,
+    values: &[ValTensor<F>; 1],
+    scale: usize,
+) -> Result<ValTensor<F>, Box<dyn Error>> {
+    let x = &values[0];
+
+    let sum_x = sum(config, region, &[x.clone()])?;
+    let nl = LookupOp::Div {
+        denom: utils::F32((scale * x.len()) as f32),
+    };
+    nonlinearity(config, region, &[sum_x], &nl)
+}
+
 /// Argmax
 pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
     config: &BaseConfig<F>,
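A fixed-point sketch of the mean layout added above (assumed rounding behaviour; the real Div lookup tabulates the rounded quotient): the sum is divided by the rescaling factor times the element count in a single lookup.

```rust
fn fixed_point_mean(x: &[i128], scale: usize) -> i128 {
    let sum: i128 = x.iter().sum();
    let denom = (scale * x.len()) as i128;
    // rounded division, standing in for the Div lookup
    (sum + denom / 2).div_euclid(denom)
}

fn main() {
    // sum([4, 6, 8]) = 18, divided by scale * len = 2 * 3 = 6, gives 3
    assert_eq!(fixed_point_mean(&[4, 6, 8], 2), 3);
}
```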
@@ -2710,8 +2709,24 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
     )?;
     // relu(x - max(x - 1))
     let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
-    // constraining relu(x - max(x - 1)) = 0/1
-    boolean_identity(config, region, &[relu.clone()], false)?;
+
+    let len = relu.dims().iter().product();
+
+    // y_i*(1 - y_i) =0 // assert the values are either 0 or 1
+    region.assign(&config.inputs[1], &relu)?;
+
+    if !region.is_dummy() {
+        (0..len)
+            .map(|i| {
+                let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
+                let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
+                region.enable(selector, z)?;
+                Ok(())
+            })
+            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
+    }
+
+    region.increment(len);
 
     // sum(relu(x - max(x - 1)))
     let sum_relu = sum(config, region, &[relu])?;
@@ -2722,7 +2737,13 @@
         nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
 
     // constraining 1 - sum(relu(x - max(x - 1))) = 0
-    is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
+    region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
+
+    let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
+    let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
+    region.enable(selector, z)?;
+
+    region.increment(relu_one_minus_sum_relu.len());
 
     Ok(assigned_max_val)
 }
@@ -2767,8 +2788,23 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
 
     // relu(min(x + 1) - x)
     let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
-    // constraining relu(min(x + 1) - x) = 0/1
-    boolean_identity(config, region, &[relu.clone()], false)?;
+
+    let len = relu.dims().iter().product();
+
+    region.assign(&config.inputs[1], &relu)?;
+    // y_i*(1 - y_i) =0 // assert the values are either 0 or 1
+    if !region.is_dummy() {
+        (0..len)
+            .map(|i| {
+                let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
+                let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
+                region.enable(selector, z)?;
+                Ok(())
+            })
+            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
+    }
+
+    region.increment(len);
 
     // sum(relu(min(x + 1) - x))
     let sum_relu = sum(config, region, &[relu])?;
@@ -2779,8 +2815,14 @@
     let relu_one_minus_sum_relu =
         nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
 
+    region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
+
     // constraining product to 0
-    is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
+    let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
+    let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
+    region.enable(selector, z)?;
+
+    region.increment(relu_one_minus_sum_relu.len());
 
     Ok(assigned_min_val)
 }
@@ -2999,8 +3041,15 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
     // Add the lower_bound and upper_bound
     let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
 
+    // Assign the sum tensor to the inputs
+    region.assign(&config.inputs[1], &sum)?;
+
     // Constrain the sum to be all zeros
-    is_zero_identity(config, region, &[sum.clone()], false)?;
+    let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
+    let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
+    region.enable(selector, z)?;
+
+    region.increment(sum.len());
 
     Ok(sum)
 }
@@ -227,7 +227,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
 
         let output = res.map(|x| i128_to_felt(x));
 
-        Ok(ForwardResult { output })
+        Ok(ForwardResult {
+            output,
+            intermediate_lookups: vec![],
+        })
     }
 
     /// Returns the name of the operation
@@ -29,6 +29,7 @@ pub mod region;
 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
 pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
     pub(crate) output: Tensor<F>,
+    pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
 }
 
 /// A trait representing operations that can be represented as constraints in a circuit.
@@ -177,6 +178,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
     fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
         Ok(ForwardResult {
             output: x[0].clone(),
+            intermediate_lookups: vec![],
         })
     }
@@ -199,7 +201,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
                 config,
                 region,
                 values[..].try_into()?,
-                true,
             )?))
         }
         _ => Ok(Some(super::layouts::identity(
@@ -302,7 +303,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
     fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
         let output = self.quantized_values.clone();
 
-        Ok(ForwardResult { output })
+        Ok(ForwardResult {
+            output,
+            intermediate_lookups: vec![],
+        })
     }
 
     fn as_string(&self) -> String {
@@ -1,6 +1,5 @@
 use crate::{
     circuit::layouts,
-    fieldutils::felt_to_i128,
     tensor::{self, Tensor, TensorError},
 };
 
@@ -10,14 +9,6 @@ use super::{base::BaseOp, *};
 /// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub enum PolyOp {
-    GatherElements {
-        dim: usize,
-        constant_idx: Option<Tensor<usize>>,
-    },
-    ScatterElements {
-        dim: usize,
-        constant_idx: Option<Tensor<usize>>,
-    },
     MultiBroadcastTo {
         shape: Vec<usize>,
     },
@@ -90,8 +81,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
 
     fn as_string(&self) -> String {
         match &self {
-            PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
-            PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
             PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
             PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
             PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -214,36 +203,14 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                 if 1 != inputs.len() {
                     return Err(TensorError::DimMismatch("slice inputs".to_string()));
                 }
-                tensor::ops::slice(&inputs[0], axis, start, end)
-            }
-            PolyOp::GatherElements { dim, constant_idx } => {
-                let x = inputs[0].clone();
-                let y = if let Some(idx) = constant_idx {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-                tensor::ops::gather_elements(&x, &y, *dim)
-            }
-            PolyOp::ScatterElements { dim, constant_idx } => {
-                let x = inputs[0].clone();
-
-                let idx = if let Some(idx) = constant_idx {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-
-                let src = if constant_idx.is_some() {
-                    inputs[1].clone()
-                } else {
-                    inputs[2].clone()
-                };
-                tensor::ops::scatter(&x, &idx, &src, *dim)
+                Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
             }
         }?;
 
-        Ok(ForwardResult { output: res })
+        Ok(ForwardResult {
+            output: res,
+            intermediate_lookups: vec![],
+        })
     }
 
     fn layout(
@@ -284,26 +251,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
             PolyOp::Conv { padding, stride } => {
                 layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
             }
-            PolyOp::GatherElements { dim, constant_idx } => {
-                if let Some(idx) = constant_idx {
-                    tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
-                } else {
-                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
-                }
-            }
-            PolyOp::ScatterElements { dim, constant_idx } => {
-                if let Some(idx) = constant_idx {
-                    tensor::ops::scatter(
-                        values[0].get_inner_tensor()?,
-                        idx,
-                        values[1].get_inner_tensor()?,
-                        *dim,
-                    )?
-                    .into()
-                } else {
-                    layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
-                }
-            }
             PolyOp::DeConv {
                 padding,
                 output_padding,
@@ -405,8 +352,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
             vec![1, 2]
         } else if matches!(self, PolyOp::Concat { .. }) {
             (0..100).collect()
-        } else if matches!(self, PolyOp::ScatterElements { .. }) {
-            vec![0, 2]
         } else {
             vec![]
         }
@@ -16,8 +16,6 @@ use std::{
     },
 };
 
-use portable_atomic::AtomicI128 as AtomicInt;
-
 use super::lookup::LookupOp;
 
 /// Region error
@@ -68,8 +66,6 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
     total_constants: usize,
     used_lookups: HashSet<LookupOp>,
     used_range_checks: HashSet<Range>,
-    max_lookup_inputs: i128,
-    min_lookup_inputs: i128,
 }
 
 impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -91,8 +87,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
             total_constants: 0,
             used_lookups: HashSet::new(),
             used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
         }
     }
     /// Create a new region context from a wrapped region
@@ -110,8 +104,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
             total_constants: 0,
             used_lookups: HashSet::new(),
             used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
         }
    }
@@ -128,8 +120,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
             total_constants: 0,
             used_lookups: HashSet::new(),
             used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
         }
     }
@@ -151,8 +141,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
             total_constants,
             used_lookups,
             used_range_checks,
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
         }
     }
@@ -192,7 +180,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
 
     /// Create a new region context per loop iteration
     /// hacky but it works
-
     pub fn dummy_loop<T: TensorType + Send + Sync>(
         &mut self,
         output: &mut Tensor<T>,
@@ -203,8 +190,6 @@
         let row = AtomicUsize::new(self.row());
         let linear_coord = AtomicUsize::new(self.linear_coord());
         let constants = AtomicUsize::new(self.total_constants());
-        let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
-        let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
         let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
         let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
 
@@ -236,9 +221,6 @@
                     local_reg.total_constants() - starting_constants,
                     Ordering::SeqCst,
                 );
-
-                max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
-                min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
                 // update the lookups
                 let mut lookups = lookups.lock().unwrap();
                 lookups.extend(local_reg.used_lookups());
@@ -252,11 +234,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
})?;
|
||||
self.total_constants = constants.into_inner();
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner();
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner();
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
|
||||
@@ -276,54 +253,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_inputs(
|
||||
&mut self,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(
|
||||
&mut self,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if range.0 > range.1 {
|
||||
return Err("update_max_min_lookup_range: invalid range".into());
|
||||
}
|
||||
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if the region is dummy
|
||||
pub fn is_dummy(&self) -> bool {
|
||||
self.region.is_none()
|
||||
}
|
||||
|
||||
/// add used lookup
|
||||
pub fn add_used_lookup(
|
||||
&mut self,
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
|
||||
self.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn add_used_range_check(&mut self, range: Range) {
|
||||
self.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
/// Get the offset
|
||||
@@ -351,16 +293,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> i128 {
|
||||
self.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> i128 {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
|
||||
self.total_constants += 1;
|
||||
|
||||
@@ -246,7 +246,7 @@ mod matmul_col_overflow {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::commitment::ParamsProver;

use super::*;

@@ -349,13 +349,8 @@ mod matmul_col_ultra_overflow_double_col {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);

assert!(result.is_ok());

@@ -366,7 +361,7 @@ mod matmul_col_ultra_overflow_double_col {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::commitment::ParamsProver;

use super::*;

@@ -468,13 +463,8 @@ mod matmul_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);

assert!(result.is_ok());

@@ -1150,7 +1140,7 @@ mod conv {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::commitment::ParamsProver;

use super::*;

@@ -1272,13 +1262,8 @@ mod conv_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);

assert!(result.is_ok());

@@ -1290,7 +1275,7 @@ mod conv_col_ultra_overflow {
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::commitment::ParamsProver;

use super::*;

@@ -1427,13 +1412,8 @@ mod conv_relu_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);

assert!(result.is_ok());

@@ -2363,7 +2343,7 @@ mod lookup_ultra_overflow {
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
poly::commitment::{Params, ParamsProver},
poly::commitment::ParamsProver,
};

#[derive(Clone)]
@@ -2467,13 +2447,8 @@ mod lookup_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);

assert!(result.is_ok());

128 src/commands.rs
@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand};
use clap::{Parser, Subcommand, ValueEnum};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
#[cfg(feature = "python-bindings")]
@@ -9,9 +9,8 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use tosubcommand::{ToFlags, ToSubcommand};

use crate::{pfsys::ProofType, RunArgs};

@@ -86,11 +85,15 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";

impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl IntoPy<PyObject> for TranscriptType {
@@ -135,27 +138,17 @@ impl Default for CalibrationTarget {
}
}

impl std::fmt::Display for CalibrationTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
impl ToString for CalibrationTarget {
fn to_string(&self) -> String {
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
}
)
}
}

impl ToFlags for CalibrationTarget {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
}
}

@@ -176,36 +169,6 @@ impl From<&str> for CalibrationTarget {
}
}

#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
inner: H160,
}

#[cfg(not(target_arch = "wasm32"))]
impl From<H160Flag> for H160 {
fn from(val: H160Flag) -> H160 {
val.inner
}
}

#[cfg(not(target_arch = "wasm32"))]
impl ToFlags for H160Flag {
fn to_flags(&self) -> Vec<String> {
vec![format!("{:#x}", self.inner)]
}
}

#[cfg(not(target_arch = "wasm32"))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
inner: H160::from_str(s).unwrap(),
}
}
}

#[cfg(feature = "python-bindings")]
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
impl IntoPy<PyObject> for CalibrationTarget {
@@ -238,7 +201,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
}
}
}
// not wasm

use lazy_static::lazy_static;

// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
@@ -279,7 +242,7 @@ impl Cli {
}

#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
pub enum Commands {
#[cfg(feature = "empty-cmd")]
/// Creates an empty buffer
@@ -373,9 +336,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
},

/// Generates a dummy SRS
@@ -546,7 +509,7 @@ pub enum Commands {
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
SetupTestEVMData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
@@ -574,7 +537,7 @@ pub enum Commands {
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
addr: H160Flag,
addr: H160,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
@@ -624,9 +587,9 @@ pub enum Commands {
check_mode: CheckMode,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
/// Creates an EVM verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
CreateEVMVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -649,9 +612,9 @@ pub enum Commands {
render_vk_seperately: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
/// Creates an EVM verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEvmVK {
CreateEVMVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -669,9 +632,9 @@ pub enum Commands {
abi_path: PathBuf,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
CreateEVMDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
@@ -691,9 +654,9 @@ pub enum Commands {
},

#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for an aggregate proof
/// Creates an EVM verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
CreateEVMVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -732,9 +695,6 @@ pub enum Commands {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
@@ -816,23 +776,31 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Verifies a proof using a local Evm executor, returning accept or reject
/// Verifies a proof using a local EVM executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEvm {
VerifyEVM {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
/// The path to the verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
addr_verifier: H160Flag,
addr_verifier: H160,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
/// does the verifier use data attestation?
#[arg(long)]
addr_da: Option<H160Flag>,
addr_da: Option<H160>,
// is the vk rendered seperately, if so specify an address
#[arg(long)]
addr_vk: Option<H160Flag>,
addr_vk: Option<H160>,
},

/// Print the proof in hexadecimal
#[command(name = "print-proof-hex")]
PrintProofHex {
/// The path to the proof file
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
},
}

109 src/execute.rs
@@ -3,8 +3,6 @@ use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::commands::Commands;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::H160Flag;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
#[allow(unused_imports)]
@@ -23,6 +21,8 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -178,7 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
scales,
scale_rebase_multiplier,
max_logrows,
only_range_check_rebase,
div_rebasing,
} => calibrate(
model,
data,
@@ -187,7 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
div_rebasing,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -202,7 +202,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmVerifier {
Commands::CreateEVMVerifier {
vk_path,
srs_path,
settings_path,
@@ -217,7 +217,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
render_vk_seperately,
),
Commands::CreateEvmVK {
Commands::CreateEVMVK {
vk_path,
srs_path,
settings_path,
@@ -225,14 +225,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmDataAttestation {
Commands::CreateEVMDataAttestation {
settings_path,
sol_code_path,
abi_path,
data,
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmVerifierAggr {
Commands::CreateEVMVerifierAggr {
vk_path,
srs_path,
sol_code_path,
@@ -270,7 +270,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compress_selectors,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
Commands::SetupTestEVMData {
data,
compiled_circuit,
test_data,
@@ -366,8 +366,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path,
vk_path,
srs_path,
reduced_srs,
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
} => verify(proof_path, settings_path, vk_path, srs_path)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
proof_path,
@@ -434,13 +433,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::VerifyEvm {
Commands::VerifyEVM {
proof_path,
addr_verifier,
rpc_url,
addr_da,
addr_vk,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
}
}

@@ -628,7 +628,7 @@ pub(crate) async fn gen_witness(
);

if let Some(output_path) = output {
witness.save(output_path)?;
serde_json::to_writer(&File::create(output_path)?, &witness)?;
}

// print the witness in debug
@@ -726,11 +726,11 @@ impl AccuracyResults {
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());

errors.extend(error);
abs_errors.extend(abs_error);
squared_errors.extend(squared_error);
percentage_errors.extend(percentage_error);
abs_percentage_errors.extend(abs_percentage_error);
errors.extend(error.into_iter());
abs_errors.extend(abs_error.into_iter());
squared_errors.extend(squared_error.into_iter());
percentage_errors.extend(percentage_error.into_iter());
abs_percentage_errors.extend(abs_percentage_error.into_iter());
}

let mean_percent_error =
@@ -780,7 +780,6 @@ impl AccuracyResults {
/// Calibrate the circuit parameters to a given dataset
#[cfg(not(target_arch = "wasm32"))]
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn calibrate(
model_path: PathBuf,
data: PathBuf,
@@ -789,7 +788,7 @@ pub(crate) fn calibrate(
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
div_rebasing: Option<bool>,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -827,11 +826,14 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
(10..14).collect::<Vec<crate::Scale>>()
match target {
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
}
};

let div_rebasing = if only_range_check_rebase {
vec![false]
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
} else {
vec![true, false]
};
@@ -1170,6 +1172,16 @@ pub(crate) fn mock(
Ok(String::new())
}

pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<String, Box<dyn Error>> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
for instance in proof.instances {
println!("{:?}", instance);
}
let hex_str = hex::encode(proof.proof);
info!("0x{}", hex_str);
Ok(format!("0x{}", hex_str))
}

#[cfg(feature = "render")]
pub(crate) fn render(
model: PathBuf,
@@ -1388,10 +1400,10 @@ pub(crate) async fn deploy_evm(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160Flag,
addr_verifier: H160,
rpc_url: Option<String>,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
addr_da: Option<H160>,
addr_vk: Option<H160>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::verify_proof_with_data_attestation;
check_solc_requirement();
@@ -1401,20 +1413,14 @@ pub(crate) async fn verify_evm(
let result = if let Some(addr_da) = addr_da {
verify_proof_with_data_attestation(
proof.clone(),
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
addr_verifier,
addr_da,
addr_vk,
rpc_url.as_deref(),
)
.await?
} else {
verify_proof_via_solidity(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
verify_proof_via_solidity(proof.clone(), addr_verifier, addr_vk, rpc_url.as_deref()).await?
};

info!("Solidity verification result: {}", result);
@@ -1570,14 +1576,14 @@ pub(crate) async fn setup_test_evm_witness(
use crate::pfsys::ProofType;
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn test_update_account_calls(
addr: H160Flag,
addr: H160,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::update_account_calls;

check_solc_requirement();
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
update_account_calls(addr, data, rpc_url.as_deref()).await?;

Ok(String::new())
}
@@ -1722,7 +1728,6 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1753,7 +1758,6 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1790,7 +1794,6 @@ pub(crate) fn fuzz(
proof.clone(),
bad_vk,
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1822,7 +1825,6 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1858,7 +1860,6 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -2044,29 +2045,15 @@ pub(crate) fn verify(
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
let circuit_settings = GraphSettings::load(&settings_path)?;

let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};

let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;

let strategy = KZGSingleStrategy::new(params.verifier_params());
let vk =
load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(
params.verifier_params(),
proof,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
let elapsed = now.elapsed();
info!(
"verify took {}.{}",
@@ -2090,7 +2077,7 @@ pub(crate) fn verify_aggr(
let strategy = AccumulatorStrategy::new(params.verifier_params());
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(vk_path, ())?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy, 1 << logrows);
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy);

let elapsed = now.elapsed();
info!(

@@ -4,7 +4,6 @@ use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
@@ -16,8 +15,6 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
@@ -493,20 +490,16 @@ impl GraphData {

/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let reader = std::fs::File::open(path)?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to open input at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
}

/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, self)?;
Ok(())
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
}

///

217 src/graph/mod.rs
@@ -16,7 +16,6 @@ use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;

#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
@@ -29,17 +28,14 @@ use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};

use crate::RunArgs;
use halo2_proofs::{
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
use log::{debug, error, info, trace, warn};
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
@@ -50,6 +46,7 @@ use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
@@ -64,20 +61,6 @@ pub const RANGE_MULTIPLIER: i128 = 2;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);

#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
/// Max circuit area
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
if let Ok(max_circuit_area) = std::env::var("EZKL_MAX_CIRCUIT_AREA") {
Some(max_circuit_area.parse().unwrap_or(0))
} else {
None
};
}

#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;

/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
@@ -208,41 +191,42 @@ impl GraphWitness {
output_scales: Vec<crate::Scale>,
visibility: VarVisibility,
) {
let mut pretty_elements = PrettyElements {
rescaled_inputs: self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
inputs: self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
rescaled_outputs: self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
outputs: self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
..Default::default()
};
let mut pretty_elements = PrettyElements::default();
pretty_elements.rescaled_inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();

pretty_elements.inputs = self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();

pretty_elements.rescaled_outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();

pretty_elements.outputs = self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();

if let Some(processed_inputs) = self.processed_inputs.clone() {
pretty_elements.processed_inputs = processed_inputs
@@ -308,20 +292,16 @@ impl GraphWitness {

/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path.clone())
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load model at {}", path.display()))?;

let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
}

/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);

serde_json::to_writer(writer, &self).map_err(|e| e.into())
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
}

///
@@ -476,33 +456,22 @@ impl GraphSettings {
instances
}

/// calculate the log2 of the total number of instances
pub fn log2_total_instances(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>();

// max between 1 and the log2 of the sums
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}

/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
// buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
let encoded = serde_json::to_string(&self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(encoded.as_bytes())
}
/// load params from file
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
let mut file = std::fs::File::open(path).map_err(|e| {
error!("failed to open settings file at {}", e);
e
})?;
let mut data = String::new();
file.read_to_string(&mut data)?;
let res = serde_json::from_str(&data)?;
Ok(res)
}

/// Export the ezkl configuration as json
@@ -561,7 +530,6 @@ impl GraphSettings {
pub struct GraphConfig {
model_config: ModelConfig,
module_configs: ModuleConfigs,
circuit_size: CircuitSize,
}

/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
@@ -598,7 +566,7 @@ impl GraphCircuit {
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
let writer = std::io::BufWriter::new(f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
@@ -606,10 +574,11 @@ impl GraphCircuit {
///
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
// read bytes from file
let f = std::fs::File::open(path)?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;

let mut f = std::fs::File::open(&path)?;
let metadata = std::fs::metadata(&path)?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
let result = bincode::deserialize(&buffer)?;
Ok(result)
}
}
@@ -624,17 +593,6 @@ pub enum TestDataSource {
OnChain,
}

impl std::fmt::Display for TestDataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestDataSource::File => write!(f, "file"),
TestDataSource::OnChain => write!(f, "on-chain"),
}
}
}

impl ToFlags for TestDataSource {}

impl From<String> for TestDataSource {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -851,7 +809,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
info!("input scales: {:?}", scales);

match &data.input_data {
DataSource::File(file_data) => {
@@ -870,7 +828,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
info!("input scales: {:?}", scales);

self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
@@ -1070,7 +1028,7 @@ impl GraphCircuit {
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
debug!("{}", err_string);
error!("{}", err_string);
return Err(err_string.into());
}

@@ -1090,7 +1048,7 @@ impl GraphCircuit {
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
debug!("{}", err_string);
error!("{}", err_string);
return Err(err_string.into());
}

@@ -1156,7 +1114,7 @@ impl GraphCircuit {

settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);

debug!(
info!(
"setting lookup_range to: {:?}, setting logrows to: {}",
self.settings().run_args.lookup_range,
self.settings().run_args.logrows
@@ -1408,6 +1366,7 @@ impl GraphCircuit {
}
}

#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
num_instances: usize,
@@ -1415,22 +1374,20 @@ struct CircuitSize {
num_fixed: usize,
num_challenges: usize,
num_selectors: usize,
logrows: u32,
}

#[cfg(not(target_arch = "wasm32"))]
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
pub fn from_cs(cs: &ConstraintSystem<Fp>) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
num_fixed: cs.num_fixed_columns(),
num_challenges: cs.num_challenges(),
num_selectors: cs.num_selectors(),
logrows,
}
}

#[cfg(not(target_arch = "wasm32"))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
@@ -1441,25 +1398,6 @@ impl CircuitSize {
};
Ok(serialized)
}

/// number of columns
pub fn num_columns(&self) -> usize {
self.num_instances + self.num_advice_columns + self.num_fixed
}

/// area of the circuit
pub fn area(&self) -> usize {
self.num_columns() * (1 << self.logrows)
}

/// area less than max
pub fn area_less_than_max(&self) -> bool {
if EZKL_MAX_CIRCUIT_AREA.is_some() {
self.area() < EZKL_MAX_CIRCUIT_AREA.unwrap()
} else {
true
}
}
}

impl Circuit<Fp> for GraphCircuit {
@@ -1534,12 +1472,10 @@ impl Circuit<Fp> for GraphCircuit {
(cs.degree() as f32).log2().ceil()
);

let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);

#[cfg(not(target_arch = "wasm32"))]
debug!(
info!(
"circuit size: \n {}",
circuit_size
CircuitSize::from_cs(cs)
.as_json()
.unwrap()
.to_colored_json_auto()
@@ -1549,7 +1485,6 @@ impl Circuit<Fp> for GraphCircuit {
GraphConfig {
model_config,
module_configs,
circuit_size,
}
}

@@ -1562,16 +1497,6 @@ impl Circuit<Fp> for GraphCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fp>,
) -> Result<(), PlonkError> {
// check if the circuit area is less than the max
if !config.circuit_size.area_less_than_max() {
error!(
"circuit area {} is larger than the max allowed area {}",
config.circuit_size.area(),
EZKL_MAX_CIRCUIT_AREA.unwrap()
);
return Err(PlonkError::Synthesis);
}

trace!("Setting input in synthesize");
let input_vis = &self.settings().run_args.input_visibility;
let output_vis = &self.settings().run_args.output_visibility;

@@ -10,6 +10,7 @@ use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -56,8 +57,6 @@ use unzip_n::unzip_n;

unzip_n!(pub 3);

#[cfg(not(target_arch = "wasm32"))]
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
/// The result of a forward pass.
#[derive(Clone, Debug)]
pub struct ForwardResult {
@@ -69,16 +68,6 @@ pub struct ForwardResult {
pub min_lookup_inputs: i128,
}

impl From<DummyPassRes> for ForwardResult {
fn from(res: DummyPassRes) -> Self {
Self {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
}
}
}

/// A circuit configuration for the entirety of a model loaded from an Onnx file.
#[derive(Clone, Debug)]
pub struct ModelConfig {
@@ -104,12 +93,6 @@ pub struct DummyPassRes {
pub lookup_ops: HashSet<LookupOp>,
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}

/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
@@ -495,7 +478,7 @@ impl Model {
) -> Result<GraphSettings, Box<dyn Error>> {
let instance_shapes = self.instance_shapes()?;
#[cfg(not(target_arch = "wasm32"))]
debug!(
info!(
"{} {} {}",
"model has".blue(),
instance_shapes.len().to_string().blue(),
@@ -503,25 +486,7 @@ impl Model {
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};

let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;

let res = self.dummy_layout(run_args, &inputs)?;
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;

// if we're using percentage tolerance, we need to add the necessary range check ops for it.

@@ -557,12 +522,206 @@ impl Model {
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
Ok(res.into())
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
let mut max_lookup_inputs = 0;
let mut min_lookup_inputs = 0;

let input_shapes = self.graph.input_shapes()?;

for (i, input_idx) in self.graph.inputs.iter().enumerate() {
let mut input = model_inputs[i].clone();
input.reshape(&input_shapes[i])?;
results.insert(input_idx, vec![input]);
}

for (idx, n) in self.graph.nodes.iter() {
let mut inputs = vec![];
if n.is_input() {
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
inputs.push(t);
} else {
for (idx, outlet) in n.inputs().iter() {
match results.get(&idx) {
Some(value) => inputs.push(value[*outlet].clone()),
None => return Err(Box::new(GraphError::MissingNode(*idx))),
}
}
};

debug!("executing {}: {}", idx, n.as_str());
debug!("dims: {:?}", n.out_dims());
debug!(
"input_dims: {:?}",
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);

debug!("input nodes: {:?}", n.inputs());

if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
max = max.max(
i.iter()
.map(|x| felt_to_i128(*x))
.max()
.ok_or("missing max")?,
);
min = min.min(
i.iter()
.map(|x| felt_to_i128(*x))
.min()
.ok_or("missing min")?,
);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("max lookup inputs: {}", max);
debug!("min lookup inputs: {}", min);
}

match n {
NodeType::Node(n) => {
// execute the op
let start = instant::Instant::now();
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
res.output.reshape(&n.out_dims)?;
let elapsed = start.elapsed();
trace!("op took: {:?}", elapsed);
// see if any of the intermediate lookup calcs are the max
if !res.intermediate_lookups.is_empty() {
let (mut min, mut max) = (0, 0);
for i in &res.intermediate_lookups {
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("intermediate max lookup inputs: {}", max);
debug!("intermediate min lookup inputs: {}", min);
}
debug!(
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
idx,
res.output.map(crate::fieldutils::felt_to_i32).show(),
res.output
.map(|x| crate::fieldutils::felt_to_f64(x)
/ scale_to_multiplier(n.out_scale))
.show(),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
n.out_scale
);
results.insert(idx, vec![res.output]);
}
NodeType::SubGraph {
model,
output_mappings,
input_mappings,
inputs: input_tuple,
..
} => {
let orig_inputs = inputs.clone();
let input_mappings = input_mappings.clone();

let input_dims = inputs.iter().map(|inp| inp.dims());
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());

debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, input_tuple, model.graph.inputs
);

debug!("input_mappings: {:?}", input_mappings);

let mut full_results: Vec<Tensor<Fp>> = vec![];

for i in 0..num_iter {
// replace the Stacked input with the current chunk iter
for ((mapping, inp), og_input) in
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
{
if let InputMapping::Stacked { axis, chunk } = mapping {
let start = i * chunk;
let end = (i + 1) * chunk;
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
*inp = t;
}
}

let res = model.forward(&inputs)?;
// recursively get the max lookup inputs for subgraphs
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);

let mut outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
for mapping in mappings {
match mapping {
OutputMapping::Single { outlet, .. } => {
outlets.insert(outlet, outlet_res.clone());
}
OutputMapping::Stacked { outlet, axis, .. } => {
if !full_results.is_empty() {
let stacked_res = crate::tensor::ops::concat(
&[&full_results[*outlet], &outlet_res],
*axis,
)?;

outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
}

full_results = outlets.into_values().collect_vec();

let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);

assert_eq!(input_states.len(), output_states.len());

for (input_idx, output_idx) in input_states.iter().zip(output_states) {
inputs[*input_idx] = full_results[output_idx].clone();
}
}

trace!(
"------------ output subgraph node {}: {:?}",
idx,
full_results
.iter()
.map(|x|
// convert to tensor i32
x.map(crate::fieldutils::felt_to_i32).show())
.collect_vec()
);

results.insert(idx, full_results);
}
}
}

let output_nodes = self.graph.outputs.iter();
debug!(
"model outputs are nodes: {:?}",
output_nodes.clone().collect_vec()
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;

let res = ForwardResult {
outputs,
max_lookup_inputs,
min_lookup_inputs,
};

Ok(res)
}

/// Loads an Onnx model from a specified path.
@@ -574,7 +733,7 @@ impl Model {
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<TractResult, Box<dyn Error>> {
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
@@ -613,7 +772,7 @@ impl Model {
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
info!("set {} to {}", symbol, value);
}

// Note: do not optimize the model, as the layout will depend on underlying hardware
@@ -1014,8 +1173,8 @@ impl Model {
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
let output = &vars.advices[2];
let index = &vars.advices[1];
let output = &vars.advices[1];
let index = &vars.advices[2];
for op in required_lookups {
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
}
@@ -1134,7 +1293,7 @@ impl Model {

// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
info!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
@@ -1176,29 +1335,18 @@ impl Model {
};

debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
region.total_constants()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);

match &node {
NodeType::Node(n) => {
@@ -1341,13 +1489,28 @@ impl Model {
|
||||
pub fn dummy_layout(
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
input_shapes: &[Vec<usize>],
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
info!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let inputs: Vec<ValTensor<Fp>> = input_shapes
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
results.insert(*input_idx, vec![inputs[i].clone()]);
|
||||
@@ -1382,12 +1545,12 @@ impl Model {
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let _ = outputs
|
||||
.iter()
|
||||
.into_iter()
|
||||
.zip(comparator)
|
||||
.map(|(o, c)| {
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[o.clone(), c],
|
||||
&[o, c],
|
||||
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
|
||||
)
|
||||
})
|
||||
@@ -1403,7 +1566,7 @@ impl Model {
|
||||
|
||||
// The number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
info!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
@@ -1412,23 +1575,12 @@ impl Model {
region.total_constants().to_string().red()
);

let outputs = outputs
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();

let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
outputs,
};

Ok(res)

@@ -148,7 +148,7 @@ impl RebaseScale {
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier,
multiplier: multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
@@ -219,6 +219,8 @@ impl Op<Fp> for RebaseScale {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
res.intermediate_lookups
.extend(rebase_res.intermediate_lookups);

Ok(res)
}
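The RebaseScale hunk above composes the wrapped op with a rebasing division and folds the division's intermediate lookups into the combined result. A minimal, self-contained sketch of that composition pattern, with a hypothetical `OpResult` struct standing in for the real forward-result type:

```rust
// Hypothetical stand-in for the forward-result type: run an inner op, feed
// its output through a rebase op, and accumulate every intermediate tensor
// that will later need a lookup/range check.
struct OpResult {
    output: Vec<i64>,
    intermediate_lookups: Vec<Vec<i64>>,
}

fn compose(
    inner: impl Fn(&[i64]) -> OpResult,
    rebase: impl Fn(&[i64]) -> OpResult,
    x: &[i64],
) -> OpResult {
    let mut res = inner(x);
    let rebase_res = rebase(&res.output);
    res.output = rebase_res.output;
    // keep the rebase op's intermediates too, as the hunk does
    res.intermediate_lookups.extend(rebase_res.intermediate_lookups);
    res
}

fn main() {
    let inner = |x: &[i64]| OpResult {
        output: x.iter().map(|v| v * 3).collect(),
        intermediate_lookups: vec![],
    };
    let rebase = |x: &[i64]| OpResult {
        output: x.iter().map(|v| v / 2).collect(),
        intermediate_lookups: vec![x.to_vec()],
    };
    let res = compose(inner, rebase, &[4, 8]);
    assert_eq!(res.output, vec![6, 12]);
    assert_eq!(res.intermediate_lookups.len(), 1); // rebase intermediates retained
}
```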
@@ -497,7 +499,6 @@ impl Node {
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(not(target_arch = "wasm32"))]
#[allow(clippy::too_many_arguments)]
pub fn new(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,

@@ -439,16 +439,17 @@ pub fn new_op_from_onnx(
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;

let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: None,
});
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: None,
});

// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -477,16 +478,17 @@ pub fn new_op_from_onnx(
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;

let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: None,
});
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: None,
});

// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})

@@ -1,5 +1,4 @@
use std::error::Error;
use std::fmt::Display;

use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
@@ -15,7 +14,6 @@ use pyo3::{
};

use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;

use super::*;

@@ -42,33 +40,6 @@ pub enum Visibility {
Fixed,
}

impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed {
hash_is_public,
outlets,
} => {
if *hash_is_public {
write!(f, "hashed/public")
} else {
write!(f, "hashed/private/{}", outlets.iter().join(","))
}
}
}
}
}

impl ToFlags for Visibility {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}

impl<'a> From<&'a str> for Visibility {
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
@@ -231,6 +202,17 @@ impl Visibility {
vec![]
}
}
impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed { .. } => write!(f, "hashed"),
}
}
}
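One side of this hunk prints hashed visibility with its public/private split and outlet list, while the other collapses every hashed variant to plain "hashed". A minimal sketch of the richer behavior, with the enum trimmed down to two variants for illustration:

```rust
use std::fmt;

// Trimmed-down stand-in for the Visibility enum in the hunk above.
enum Visibility {
    Public,
    Hashed { hash_is_public: bool, outlets: Vec<usize> },
}

impl fmt::Display for Visibility {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Visibility::Public => write!(f, "public"),
            // richer arms: "hashed/public" or "hashed/private/<outlet list>"
            Visibility::Hashed { hash_is_public: true, .. } => write!(f, "hashed/public"),
            Visibility::Hashed { outlets, .. } => {
                let list: Vec<String> = outlets.iter().map(|o| o.to_string()).collect();
                write!(f, "hashed/private/{}", list.join(","))
            }
        }
    }
}

fn main() {
    let v = Visibility::Hashed { hash_is_public: false, outlets: vec![0, 1] };
    assert_eq!(v.to_string(), "hashed/private/0,1");
}
```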

/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
@@ -428,7 +410,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
info!("number of blinding factors: {}", cs.blinding_factors());

let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))

64
src/lib.rs
@@ -7,6 +7,7 @@
overflowing_literals,
path_statements,
patterns_in_fns_without_body,
private_in_public,
unconditional_recursion,
unused,
unused_allocation,
@@ -32,7 +33,6 @@ use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use graph::Visibility;
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;

/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
pub mod circuit;
@@ -71,34 +71,11 @@ pub mod tensor;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;

#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;

/// The denominator in the fixed point representation used when quantizing inputs
pub type Scale = i32;

#[cfg(not(target_arch = "wasm32"))]
// Buf writer capacity
lazy_static! {
/// The capacity of the buffer used for writing to disk
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
.unwrap_or("8000".to_string())
.parse()
.unwrap();

/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}

#[cfg(target_arch = "wasm32")]
const EZKL_KEY_FORMAT: &str = "raw-bytes";

#[cfg(target_arch = "wasm32")]
const EZKL_BUF_CAPACITY: &usize = &8000;

/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
@@ -113,7 +90,7 @@ pub struct RunArgs {
#[arg(long, default_value = "1")]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
@@ -122,7 +99,7 @@ pub struct RunArgs {
#[arg(short = 'N', long, default_value = "2")]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, hashed
#[arg(long, default_value = "private")]
@@ -204,15 +181,34 @@ fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr + std::fmt::Debug,
T: std::str::FromStr,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr + std::fmt::Debug,
U: std::str::FromStr,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}

/// Parse a tuple
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr + Clone,
T::Err: std::error::Error + Send + Sync + 'static,
{
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');

let res = res
.map(|x| {
// remove blank space
let x = x.trim();
x.parse::<T>()
})
.collect::<Result<Vec<_>, _>>()?;
if res.len() != 2 {
return Err("invalid tuple".into());
}
Ok((res[0].clone(), res[1].clone()))
}
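As the hunks above show, the two sides use different flag grammars: one accepts `KEY->VALUE` pairs and an `a->b` lookup range, the other `KEY=VALUE` pairs and an `(a,b)` tuple. A standalone sketch of the tuple parser's behavior (simplified error type, no clap wiring):

```rust
use std::str::FromStr;

// Simplified stand-in for the parse_tuple shown above: strips the
// parentheses, splits on ',', trims whitespace, and expects exactly two items.
fn parse_tuple<T: FromStr + Clone>(s: &str) -> Result<(T, T), String>
where
    T::Err: std::fmt::Display,
{
    let items: Result<Vec<T>, _> = s
        .trim_matches(|p| p == '(' || p == ')')
        .split(',')
        .map(|x| x.trim().parse::<T>())
        .collect();
    let items = items.map_err(|e| e.to_string())?;
    if items.len() != 2 {
        return Err("invalid tuple".into());
    }
    Ok((items[0].clone(), items[1].clone()))
}

fn main() {
    // mirrors the new default: --lookup-range=(-32768,32768)
    assert_eq!(parse_tuple::<i128>("(-32768, 32768)").unwrap(), (-32768, 32768));
}
```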

@@ -8,7 +8,6 @@ use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation::PoseidonTranscript;
use crate::tensor::TensorType;
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
@@ -40,19 +39,9 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
use tosubcommand::ToFlags;

use halo2curves::bn256::{Bn256, Fr, G1Affine};

fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
_ => panic!("invalid serde format"),
}
}
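On the side that keeps `serde_format_from_str`, the key serialization format is driven by the `EZKL_KEY_FORMAT` environment variable rather than a hard-coded `RawBytes`. A small sketch of that dispatch, with a hypothetical local enum standing in for `halo2_proofs::SerdeFormat` so it is self-contained:

```rust
// Hypothetical stand-in for halo2_proofs::SerdeFormat; the real code
// matches on the same three strings.
#[derive(Debug)]
enum SerdeFormat {
    Processed,
    RawBytesUnchecked,
    RawBytes,
}

fn serde_format_from_str(s: &str) -> SerdeFormat {
    match s {
        "processed" => SerdeFormat::Processed,
        "raw-bytes-unchecked" => SerdeFormat::RawBytesUnchecked,
        "raw-bytes" => SerdeFormat::RawBytes,
        _ => panic!("invalid serde format"),
    }
}

fn main() {
    // mirrors the lazy_static default: EZKL_KEY_FORMAT falls back to "raw-bytes"
    let fmt = std::env::var("EZKL_KEY_FORMAT").unwrap_or("raw-bytes".to_string());
    println!("using key format {:?}", serde_format_from_str(&fmt));
}
```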

#[allow(missing_docs)]
#[derive(
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
@@ -63,25 +52,6 @@ pub enum ProofType {
ForAggr,
}

impl std::fmt::Display for ProofType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ProofType::Single => "single",
ProofType::ForAggr => "for-aggr",
}
)
}
}

impl ToFlags for ProofType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}

impl From<ProofType> for TranscriptType {
fn from(val: ProofType) -> Self {
match val {
@@ -184,21 +154,6 @@ pub enum TranscriptType {
EVM,
}

impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
}
)
}
}

impl ToFlags for TranscriptType {}

#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {
fn to_object(&self, py: Python) -> PyObject {
@@ -360,7 +315,7 @@ where
/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let file = std::fs::File::create(proof_path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
let mut writer = BufWriter::new(file);
serde_json::to_writer(&mut writer, &self)?;
Ok(())
}
@@ -373,10 +328,8 @@ where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let file = std::fs::File::open(proof_path)?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
Ok(proof)
let data = std::fs::read_to_string(proof_path)?;
serde_json::from_str(&data).map_err(|e| e.into())
}
}

@@ -602,7 +555,6 @@ where
verifier_params,
pk.get_vk(),
strategy,
verifier_params.n(),
)?;
}
let elapsed = now.elapsed();
@@ -690,7 +642,6 @@ pub fn verify_proof_circuit<
params: &'params Scheme::ParamsVerifier,
vk: &VerifyingKey<Scheme::Curve>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
Scheme::Scalar: SerdeObject
@@ -711,7 +662,7 @@ where
trace!("instances {:?}", instances);

let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
}

/// Loads a [VerifyingKey] at `path`.
@@ -727,10 +678,10 @@ where
info!("loading verification key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let mut reader = BufReader::new(f);
VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
halo2_proofs::SerdeFormat::RawBytes,
params,
)
.map_err(Box::<dyn Error>::from)
@@ -749,10 +700,10 @@ where
info!("loading proving key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let mut reader = BufReader::new(f);
ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
halo2_proofs::SerdeFormat::RawBytes,
params,
)
.map_err(Box::<dyn Error>::from)
@@ -761,7 +712,7 @@ where
/// Saves a [ProvingKey] to `path`.
pub fn save_pk<Scheme: CommitmentScheme>(
path: &PathBuf,
pk: &ProvingKey<Scheme::Curve>,
vk: &ProvingKey<Scheme::Curve>,
) -> Result<(), io::Error>
where
Scheme::Curve: SerdeObject + CurveAffine,
@@ -769,8 +720,8 @@ where
{
info!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
writer.flush()?;
Ok(())
}
@@ -786,8 +737,8 @@ where
{
info!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
writer.flush()?;
Ok(())
}
@@ -799,7 +750,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
) -> Result<(), io::Error> {
info!("saving parameters 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
let mut writer = BufWriter::new(f);
params.write(&mut writer)?;
writer.flush()?;
Ok(())
@@ -889,7 +840,6 @@ pub(crate) fn verify_proof_circuit_kzg<
proof: Snark<Fr, G1Affine>,
vk: &VerifyingKey<G1Affine>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
@@ -899,7 +849,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, params, vk, strategy, orig_n),
>(&proof, params, vk, strategy),
TranscriptType::Poseidon => verify_proof_circuit::<
Fr,
VerifierSHPLONK<'_, Bn256>,
@@ -907,7 +857,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, params, vk, strategy, orig_n),
>(&proof, params, vk, strategy),
}
}


@@ -15,9 +15,10 @@ use crate::graph::{
use crate::pfsys::evm::aggregation::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
TranscriptType,
Snark, TranscriptType,
};
use crate::RunArgs;
use ethers::types::H160;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -25,6 +26,7 @@ use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use snark_verifier::util::arithmetic::PrimeField;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
use tokio::runtime::Runtime;

@@ -523,7 +525,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
div_rebasing = None,
))]
fn calibrate_settings(
data: PathBuf,
@@ -534,7 +536,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
div_rebasing: Option<bool>,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -544,7 +546,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
div_rebasing,
max_logrows,
)
.map_err(|e| {
@@ -687,23 +689,14 @@ fn prove(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_path=PathBuf::from(DEFAULT_VK),
srs_path=None,
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
non_reduced_srs: bool,
) -> Result<bool, PyErr> {
crate::execute::verify(
proof_path,
settings_path,
vk_path,
srs_path,
non_reduced_srs,
)
.map_err(|e| {
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
let err_str = format!("Failed to run verify: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -1029,15 +1022,24 @@ fn verify_evm(
addr_da: Option<&str>,
addr_vk: Option<&str>,
) -> Result<bool, PyErr> {
let addr_verifier = H160Flag::from(addr_verifier);
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_da = if let Some(addr_da) = addr_da {
let addr_da = H160Flag::from(addr_da);
let addr_da = H160::from_str(addr_da).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Some(addr_da)
} else {
None
};
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160Flag::from(addr_vk);
let addr_vk = H160::from_str(addr_vk).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Some(addr_vk)
} else {
None
@@ -1095,6 +1097,16 @@ fn create_evm_verifier_aggr(
Ok(true)
}

/// print hex representation of a proof
#[pyfunction(signature = (proof_path))]
fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;

let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}

// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -1133,6 +1145,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;

@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
///
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
/// ```
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner = self.inner.clone();
let remainder = self.len() % n;
if remainder != 0 {
inner.resize(self.len() + n - remainder, pad);
inner.resize(self.len() + n - remainder, T::zero().unwrap());
}
Tensor::new(Some(&inner), &[inner.len()])
}
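The signature change above swaps an explicit `pad` argument for an implicit zero fill. A standalone sketch of the padding rule itself, with the generic tensor machinery replaced by a plain `Vec` for illustration:

```rust
// Pads a flat buffer so its length is a multiple of n, filling with `pad`;
// the zero-fill variant in the hunk is this with pad fixed to T::zero().
fn pad_to_zero_rem<T: Clone>(data: &[T], n: usize, pad: T) -> Vec<T> {
    let mut inner = data.to_vec();
    let remainder = data.len() % n;
    if remainder != 0 {
        inner.resize(data.len() + n - remainder, pad);
    }
    inner
}

fn main() {
    // mirrors the doctest above: 6 elements padded up to a multiple of 4
    assert_eq!(
        pad_to_zero_rem(&[1, 2, 3, 4, 5, 6], 4, 0),
        vec![1, 2, 3, 4, 5, 6, 0, 0]
    );
}
```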

@@ -243,7 +243,7 @@ pub fn and<
/// Some(&[1, 0, 1, 0, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = equals(&a, &b).unwrap();
/// let result = equals(&a, &b).unwrap().0;
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -260,7 +260,7 @@ pub fn equals<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let a = a.clone();
let b = b.clone();

@@ -268,7 +268,7 @@ pub fn equals<

let result = nonlinearities::kronecker_delta(&diff);

Ok(result)
Ok((result, vec![diff]))
}
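This hunk, and the matching ones for `greater`, `greater_equal`, `less`, and `less_equal` below, change the comparison ops from returning just the mask to returning `(mask, intermediates)`, so the intermediate difference tensor can be registered for lookup/range checks. A minimal sketch of the pattern over plain integer slices:

```rust
// Element-wise equals that also returns its intermediate difference values,
// mirroring the (output, Vec<intermediates>) shape introduced in the hunk.
fn equals(a: &[i128], b: &[i128]) -> (Vec<i128>, Vec<Vec<i128>>) {
    let diff: Vec<i128> = a.iter().zip(b).map(|(x, y)| x - y).collect();
    // kronecker delta: 1 where the difference is zero, 0 elsewhere
    let mask: Vec<i128> = diff.iter().map(|d| (*d == 0) as i128).collect();
    (mask, vec![diff])
}

fn main() {
    let (mask, intermediates) = equals(&[1, 2, 3], &[1, 0, 3]);
    assert_eq!(mask, vec![1, 0, 1]);
    assert_eq!(intermediates[0], vec![0, 2, 0]); // kept for later lookup checks
}
```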

/// Greater than operation.
@@ -289,7 +289,7 @@ pub fn equals<
/// ).unwrap();
/// let result = greater(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
pub fn greater<
T: TensorType
@@ -302,7 +302,7 @@ pub fn greater<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -311,7 +311,7 @@ pub fn greater<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok(mask)
Ok((mask, vec![mask_inter]))
}

/// Greater-than-or-equal operation.
@@ -332,7 +332,7 @@ pub fn greater<
/// ).unwrap();
/// let result = greater_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
pub fn greater_equal<
T: TensorType
@@ -345,7 +345,7 @@ pub fn greater_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -354,7 +354,7 @@ pub fn greater_equal<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok(mask)
Ok((mask, vec![mask_inter]))
}

/// Less-than operation.
@@ -375,7 +375,7 @@ pub fn greater_equal<
/// ).unwrap();
/// let result = less(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
///
pub fn less<
@@ -389,7 +389,7 @@ pub fn less<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
// a < b <=> b > a
greater(b, a)
}
@@ -412,7 +412,7 @@ pub fn less<
/// ).unwrap();
/// let result = less_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
///
pub fn less_equal<
@@ -426,7 +426,7 @@ pub fn less_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
// a <= b <=> b >= a
greater_equal(b, a)
}
@@ -2300,12 +2300,12 @@ pub fn deconv<
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
///
/// // This time with normalization
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -2315,7 +2315,7 @@ pub fn sumpool(
stride: (usize, usize),
kernel_shape: (usize, usize),
normalize: bool,
) -> Result<Tensor<i128>, TensorError> {
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
let image_dims = image.dims();
let batch_size = image_dims[0];
let image_channels = image_dims[1];
@@ -2345,12 +2345,15 @@ pub fn sumpool(
let mut combined = res.combine()?;
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;

let mut inter = vec![];

if normalize {
inter.push(combined.clone());
let norm = kernel.len();
combined = nonlinearities::const_div(&combined, norm as f64);
}

Ok(combined)
Ok((combined, inter))
}

/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
@@ -3045,7 +3048,11 @@ pub mod nonlinearities {
}

/// softmax layout
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
pub fn softmax_axes(
a: &Tensor<i128>,
scale: f64,
axes: &[usize],
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();

@@ -3053,6 +3060,8 @@ pub mod nonlinearities {
return softmax(a, scale);
}

let mut intermediate_values = vec![];

let cartesian_coord = dims[..dims.len() - 1]
.iter()
.map(|x| 0..*x)
@@ -3075,7 +3084,8 @@ pub mod nonlinearities {

let res = softmax(&softmax_input, scale);

outputs.push(res);
outputs.push(res.0);
intermediate_values.extend(res.1);
}

let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
@@ -3083,7 +3093,7 @@ pub mod nonlinearities {
.combine()
.unwrap();
res.reshape(dims).unwrap();
res
(res, intermediate_values)
}

/// Applies softmax
@@ -3100,20 +3110,24 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0);
/// let result = softmax(&x, 128.0).0;
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
// the more accurate calculation is commented out; we implement it as below so it matches the steps in layout
let mut intermediate_values = vec![];

intermediate_values.push(a.clone());

let exp = exp(a, scale);

let sum = sum(&exp).unwrap();
intermediate_values.push(sum.clone());
let inv_denom = recip(&sum, scale, scale);

(exp * inv_denom).unwrap()
((exp * inv_denom).unwrap(), intermediate_values)
}
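The quantized softmax above is deliberately built from the same steps the circuit layout uses: exponentiate at the given scale, sum, take a reciprocal, and multiply, keeping the intermediates for lookups. A floating-point sketch of that pipeline, with plain `f64` math standing in for the `i128` fixed-point ops:

```rust
// Softmax decomposed into the layout-matching steps: exp, sum, recip, mul.
// Plain f64 here; the quantized version performs each step on i128 at `scale`.
fn softmax_steps(a: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let exp: Vec<f64> = a.iter().map(|x| x.exp()).collect();
    let sum: f64 = exp.iter().sum();
    let inv_denom = 1.0 / sum; // the `recip` step whose input is kept as an intermediate
    let out: Vec<f64> = exp.iter().map(|e| e * inv_denom).collect();
    // return intermediates alongside the output, as the new signature does
    (out, vec![sum])
}

fn main() {
    let (out, intermediates) = softmax_steps(&[2.0, 2.0, 3.0]);
    assert!((out.iter().sum::<f64>() - 1.0).abs() < 1e-9);
    println!("sum intermediate: {}", intermediates[0]);
}
```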

/// Applies range_check_percent

@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}

/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.pad_to_zero_rem(n, pad)?;
*v = v.pad_to_zero_rem(n)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
@@ -871,30 +871,3 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
}
}

impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
/// inverts the inner values
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
let mut cloned_self = self.clone();

match &mut cloned_self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.map(|x| match x {
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
ValType::AssignedValue(v.value_field().invert())
}
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
});
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(cloned_self)
}
}

@@ -33,7 +33,10 @@ pub enum VarTensor {
impl VarTensor {
///
pub fn is_advice(&self) -> bool {
matches!(self, VarTensor::Advice { .. })
match self {
VarTensor::Advice { .. } => true,
_ => false,
}
}

///

19
src/wasm.rs
@@ -311,19 +311,13 @@ pub fn verify(
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings.clone(),
circuit_settings,
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;

let strategy = KZGSingleStrategy::new(params.verifier_params());

let result = verify_proof_circuit_kzg(
params.verifier_params(),
snark,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);

match result {
Ok(_) => Ok(true),
@@ -393,6 +387,15 @@ pub fn prove(
.into_bytes())
}

/// print hex representation of a proof
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
// VALIDATION FUNCTIONS

/// Witness file validation

@@ -192,86 +192,86 @@ mod native_tests {
];

const TESTS: [&str; 77] = [
"1l_mlp", //0
"1l_mlp",
"1l_slice",
"1l_concat",
"1l_flatten",
// "1l_average",
"1l_div",
"1l_pad", // 5
"1l_pad",
"1l_reshape",
"1l_eltwise_div",
"1l_sigmoid",
"1l_sqrt",
"1l_softmax", //10
"1l_softmax",
// "1l_instance_norm",
"1l_batch_norm",
"1l_prelu",
"1l_leakyrelu",
"1l_gelu_noappx",
// "1l_gelu_tanh_appx",
"1l_relu", //15
"1l_relu",
"1l_downsample",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small", //20
"2l_relu_small",
"2l_relu_sigmoid",
"1l_conv",
"2l_sigmoid_small",
"2l_relu_sigmoid_conv",
"3l_relu_conv_fc", //25
"3l_relu_conv_fc",
"4l_relu_conv_fc",
"1l_erf",
"1l_var",
"1l_elu",
"min", //30
"1l_elu", //30
"min",
"max",
"1l_max_pool",
"1l_conv_transpose",
"1l_upsample",
"1l_identity", //35
"1l_upsample", //35
"1l_identity",
"idolmodel",
"trig",
"prelu_gmm",
"lstm",
"rnn", //40
"lstm", //40
"rnn",
"quantize_dequantize",
"1l_where",
"boolean",
"boolean_identity",
"decision_tree", // 45
"decision_tree", // "variable_cnn",
"random_forest",
"gradient_boosted_trees",
"1l_topk",
"xgboost",
"lightgbm", //50
"xgboost", //50
"lightgbm",
"hummingbird_decision_tree",
"oh_decision_tree",
"linear_svc",
"gather_elements",
"less", //55
"less",
"xgboost_reg",
"1l_powf",
"scatter_elements",
"1l_linear",
"linear_regression", //60
"1l_linear", //60
"linear_regression",
"sklearn_mlp",
"1l_mean",
"rounding_ops",
// "mean_as_constrain",
"arange",
"layernorm", //65
"layernorm",
"bitwise_ops",
"blackman_window",
"softsign", //68
"softsign", //70
"softplus",
"selu", //70
"selu",
"hard_sigmoid",
"log_softmax",
"eye",
"ltsf",
"remainder", //75
"remainder",
"bitshift",
];

@@ -560,7 +560,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap();
crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
test_dir.close().unwrap();
}

@@ -1321,7 +1321,7 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!("--variables=batch_size->{}", batch_size),
format!("--variables=batch_size={}", batch_size),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
@@ -1402,7 +1402,6 @@ mod native_tests {
}

// Mock prove (fast, but does not cover some potential issues)
#[allow(clippy::too_many_arguments)]
fn accuracy_measurement(
test_dir: &str,
example_name: String,
@@ -1456,7 +1455,7 @@ mod native_tests {
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
"-O",
format!("{}/{}/render.png", test_dir, example_name).as_str(),
"--lookup-range=-32768->32768",
"--lookup-range=(-32768,32768)",
"-K=17",
])
.status()
@@ -1746,30 +1745,6 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());

// load settings file
let settings =
std::fs::read_to_string(settings_path.clone()).expect("failed to read settings file");

let graph_settings = serde_json::from_str::<GraphSettings>(&settings)
.expect("failed to parse settings file");

// get_srs for the graph_settings_num_instances
download_srs(graph_settings.log2_total_instances());

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"verify",
format!("--settings-path={}", settings_path).as_str(),
"--proof-path",
&format!("{}/{}/proof.pf", test_dir, example_name),
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
"--reduced-srs",
])
.status()
.expect("failed to execute process");
assert!(status.success());
}

// prove-serialize-verify, the usual full path

@@ -131,7 +131,7 @@ mod py_tests {
"simple_demo_aggregated_proofs.ipynb",
"ezkl_demo.ipynb", // 10
"lstm.ipynb",
"set_membership.ipynb", // 12
"set_membership.ipynb",
"decision_tree.ipynb",
"random_forest.ipynb",
"gradient_boosted_trees.ipynb", // 15

@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
output_path = os.path.join(
@@ -147,13 +147,13 @@ def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
output_path = os.path.join(
@@ -183,7 +183,7 @@ def test_model_compile():
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
compiled_model_path = os.path.join(
@@ -205,7 +205,7 @@ def test_forward():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(
@@ -392,7 +392,9 @@ def test_prove_evm():
assert res['transcript_type'] == 'EVM'
assert os.path.isfile(proof_path)


res = ezkl.print_proof_hex(proof_path)
# TODO: figure out a better way of testing print_proof_hex
assert type(res) == str


def test_create_evm_verifier():

@@ -9,8 +9,8 @@ mod wasm32 {
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove,
settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
prove, settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -258,6 +258,15 @@ mod wasm32 {
assert!(value);
}

#[wasm_bindgen_test]
async fn print_proof_hex_test() {
let proof = printProofHex(wasm_bindgen::Clamped(PROOF.to_vec()))
.map_err(|_| "failed")
.unwrap();

assert!(proof.len() > 0);
}

#[wasm_bindgen_test]
async fn verify_validations() {
// Run witness validation on network (should fail)

Binary file not shown.
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff