Compare commits

...

15 Commits

42 changed files with 1396 additions and 1077 deletions

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Checkout repo

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +212,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -229,13 +229,15 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
- name: kzg inputs
@@ -284,7 +286,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -310,7 +312,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -343,7 +345,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -351,7 +353,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -411,11 +413,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -445,7 +447,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -455,7 +457,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: fuzz tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
# - name: fuzz tests
@@ -468,7 +470,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -486,7 +488,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -503,7 +505,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -520,7 +522,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -530,7 +532,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -541,7 +543,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -563,7 +565,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Install solc
@@ -571,7 +573,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Run pytest
@@ -587,7 +589,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -618,7 +620,7 @@ jobs:
python-version: "3.9"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -628,7 +630,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -30,7 +30,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e

558
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { version = "0.6.0", features = ["derive_serde"] }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
@@ -34,10 +34,13 @@ bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
instant = { version = "0.1" }
@@ -158,6 +161,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidi
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']

View File

@@ -6,6 +6,7 @@ use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -489,6 +490,7 @@ pub fn runconv() {
strategy,
pi_for_real_prover,
&mut transcript,
params.n(),
);
assert!(verify.is_ok());

View File

@@ -11,8 +11,8 @@ use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{error, info};
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
@@ -25,6 +25,7 @@ use std::error::Error;
pub async fn main() -> Result<(), Box<dyn Error>> {
let args = Cli::parse();
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
@@ -32,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
} else {
info!("Running with CPU");
}
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
let res = run(args.command).await;
match &res {
Ok(_) => info!("succeeded"),
@@ -44,7 +45,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
#[cfg(target_arch = "wasm32")]
pub fn main() {}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
fn banner() {
let ell: Vec<&str> = vec![
"for Neural Networks",

View File

@@ -41,7 +41,7 @@ pub struct KZGChip {
}
impl KZGChip {
/// Returns the number of inputs to the hash function
/// Commit to the message using the KZG commitment scheme
pub fn commit(
message: Vec<Fp>,
degree: u32,

View File

@@ -15,7 +15,7 @@ use halo2_proofs::{
Instance, Selector, TableColumn,
},
};
use log::{trace, warn};
use log::{debug, trace};
/// A simple [`FloorPlanner`] that performs minimal optimizations.
#[derive(Debug)]
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
Error::Synthesis
})?;
if !self.regions.contains_key(&index) {
warn!("spawning module {}", index)
debug!("spawning module {}", index)
};
self.current_module = index;
}

View File

@@ -16,6 +16,7 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use crate::{
circuit::ops::base::BaseOp,
@@ -61,6 +62,22 @@ pub enum CheckMode {
UNSAFE,
}
impl std::fmt::Display for CheckMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CheckMode::SAFE => write!(f, "safe"),
CheckMode::UNSAFE => write!(f, "unsafe"),
}
}
}
impl ToFlags for CheckMode {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<String> for CheckMode {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -83,6 +100,19 @@ pub struct Tolerance {
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
@@ -523,14 +553,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
let range_check =
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {

View File

@@ -6,7 +6,6 @@ use crate::{
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
// import run args from model
@@ -69,14 +68,6 @@ pub enum HybridOp {
dim: usize,
num_classes: usize,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
@@ -84,7 +75,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::ScatterElements { .. } => vec![0, 2],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
_ => vec![],
}
@@ -98,162 +88,42 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let (res, intermediate_lookups) = match &self {
HybridOp::ReduceMax { axes, .. } => {
let res = tensor::ops::max_axes(&x, axes)?;
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::ReduceMin { axes, .. } => {
let res = tensor::ops::min_axes(&x, axes)?;
let min_plus_one =
Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(min(x + 1) - x)
let inter_1 = (min_plus_one - x.clone())?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
// if denom is a round number and use_range_check_for_int is true, use range check check
if denom.0.fract() == 0.0 && *use_range_check_for_int {
let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
(res, vec![-divisor.clone(), divisor])
} else {
(res, vec![x])
}
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
let res = crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
);
// if scale is a round number and use_range_check_for_int is true, use range check check
if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
let err_tol = Tensor::from(
vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
);
(res, vec![-err_tol.clone(), err_tol])
} else {
(res, vec![x])
}
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let inter =
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
(res, inter)
}
HybridOp::ReduceArgMin { dim } => {
let res = tensor::ops::argmin_axes(&x, *dim)?;
let inter =
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
(res, inter)
}
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather(&x, idx, *dim)?;
(res.clone(), vec![])
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), vec![])
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => (
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone(),
vec![],
),
HybridOp::TopK { dim, k, largest } => {
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
let mut inter_equals = x
.clone()
.into_iter()
.flat_map(|elem| {
tensor::ops::equals(&res, &vec![elem].into_iter().into())
.unwrap()
.1
})
.collect::<Vec<_>>();
// sort in descending order and take pairwise differences
inter_equals.push(
x.into_iter()
.sorted()
.tuple_windows()
.map(|(a, b)| b - a)
.into(),
);
(res.clone(), inter_equals)
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather_elements(&x, idx, *dim)?;
(res.clone(), vec![])
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), vec![])
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let src = inputs[1].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
(res.clone(), vec![])
} else {
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
let src = inputs[2].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
(res.clone(), vec![])
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => {
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
vec![inter_1, inter_2],
)
}
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
@@ -265,10 +135,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
(
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
vec![],
)
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
@@ -295,10 +162,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output,
intermediate_lookups,
})
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
@@ -352,8 +216,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::TopK { k, dim, largest } => {
format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
}
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
HybridOp::OneHot { dim, num_classes } => {
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
}
@@ -426,9 +288,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values.try_into()?,
&LookupOp::Div {
denom: *denom,
},
&LookupOp::Div { denom: *denom },
)?
}
}
@@ -439,26 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
layouts::gather(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::MaxPool2d {
padding,
stride,

View File

@@ -128,13 +128,7 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
let mut scaled_unit =
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
region.increment(scaled_unit.len());
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
let is_assigned = !input.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
@@ -166,27 +160,14 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
log::debug!("product: {:?}", product.get_int_evals()?);
// this is now of scale 2 * scale hence why we rescaled the unit scale
let diff_with_input = pairwise(
config,
region,
&[product.clone(), scaled_unit.clone()],
BaseOp::Sub,
)?;
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
// debug print the diff
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
log::debug!("range_check_bracket: {:?}", range_check_bracket);
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
&[product],
&(range_check_bracket, 3 * range_check_bracket),
)?;
Ok(claimed_output)
@@ -233,7 +214,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
let mut assigned_len = 0;
for (i, input) in values.iter_mut().enumerate() {
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_with_duplication(
&config.inputs[i],
@@ -1161,7 +1142,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1230,7 +1211,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1873,7 +1854,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
let shape = &res[0].dims()[2..];
let mut last_elem = res[1..]
.iter()
.fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?;
.try_fold(res[0].clone(), |acc, elem| acc.concat(elem.clone()))?;
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
@@ -2488,7 +2469,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
range: &crate::circuit::table::Range,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_range_check(*range);
region.add_used_range_check(*range)?;
// time the entire operation
let timer = instant::Instant::now();
@@ -2534,7 +2515,7 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
nl: &LookupOp,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_lookup(nl.clone());
region.add_used_lookup(nl.clone(), values)?;
// time the entire operation
let timer = instant::Instant::now();
@@ -2972,8 +2953,17 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
return enforce_equality(config, region, values);
}
let mut values = [values[0].clone(), values[1].clone()];
values[0] = region.assign(&config.inputs[0], &values[0])?;
values[1] = region.assign(&config.inputs[1], &values[1])?;
let total_assigned_0 = values[0].len();
let total_assigned_1 = values[1].len();
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
region.increment(total_assigned);
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(

View File

@@ -227,10 +227,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output })
}
/// Returns the name of the operation
@@ -246,10 +243,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
LookupOp::LessThan { .. } => "LESS_THAN".into(),
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,

View File

@@ -29,7 +29,6 @@ pub mod region;
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub(crate) output: Tensor<F>,
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
@@ -178,7 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
intermediate_lookups: vec![],
})
}
@@ -304,10 +302,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {

View File

@@ -1,5 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -9,6 +10,14 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -81,6 +90,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -203,14 +214,36 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
}?;
Ok(ForwardResult {
output: res,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output: res })
}
fn layout(
@@ -251,6 +284,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -352,6 +405,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
vec![0, 2]
} else {
vec![]
}

View File

@@ -16,6 +16,8 @@ use std::{
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Region error
@@ -66,6 +68,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
total_constants: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -87,6 +91,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
/// Create a new region context from a wrapped region
@@ -104,6 +110,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -120,6 +128,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -141,6 +151,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants,
used_lookups,
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -180,6 +192,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Create a new region context per loop iteration
/// hacky but it works
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
@@ -190,6 +203,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
@@ -221,6 +236,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
@@ -234,6 +252,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
})?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
@@ -253,19 +276,54 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
Ok(())
}
/// Check if the region is dummy
pub fn is_dummy(&self) -> bool {
self.region.is_none()
}
/// add used lookup
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
self.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) {
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
self.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
/// Get the offset
@@ -293,6 +351,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
self.min_lookup_inputs
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;

View File

@@ -246,7 +246,7 @@ mod matmul_col_overflow {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -349,8 +349,13 @@ mod matmul_col_ultra_overflow_double_col {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -361,7 +366,7 @@ mod matmul_col_ultra_overflow_double_col {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -463,8 +468,13 @@ mod matmul_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -1140,7 +1150,7 @@ mod conv {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -1262,8 +1272,13 @@ mod conv_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -1275,7 +1290,7 @@ mod conv_col_ultra_overflow {
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -1412,8 +1427,13 @@ mod conv_relu_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -2343,7 +2363,7 @@ mod lookup_ultra_overflow {
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
poly::commitment::ParamsProver,
poly::commitment::{Params, ParamsProver},
};
#[derive(Clone)]
@@ -2447,8 +2467,13 @@ mod lookup_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());

View File

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum};
use clap::{Parser, Subcommand};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
#[cfg(feature = "python-bindings")]
@@ -9,8 +9,9 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, RunArgs};
@@ -85,15 +86,11 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl IntoPy<PyObject> for TranscriptType {
@@ -138,17 +135,27 @@ impl Default for CalibrationTarget {
}
}
impl ToString for CalibrationTarget {
fn to_string(&self) -> String {
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
impl std::fmt::Display for CalibrationTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
)
}
}
impl ToFlags for CalibrationTarget {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
@@ -169,6 +176,36 @@ impl From<&str> for CalibrationTarget {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
inner: H160,
}
#[cfg(not(target_arch = "wasm32"))]
impl From<H160Flag> for H160 {
fn from(val: H160Flag) -> H160 {
val.inner
}
}
#[cfg(not(target_arch = "wasm32"))]
impl ToFlags for H160Flag {
fn to_flags(&self) -> Vec<String> {
vec![format!("{:#x}", self.inner)]
}
}
#[cfg(not(target_arch = "wasm32"))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
inner: H160::from_str(s).unwrap(),
}
}
}
#[cfg(feature = "python-bindings")]
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
impl IntoPy<PyObject> for CalibrationTarget {
@@ -201,7 +238,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
@@ -242,7 +279,7 @@ impl Cli {
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
#[cfg(feature = "empty-cmd")]
/// Creates an empty buffer
@@ -336,9 +373,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
},
/// Generates a dummy SRS
@@ -509,7 +546,7 @@ pub enum Commands {
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEVMData {
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
@@ -537,7 +574,7 @@ pub enum Commands {
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
addr: H160,
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
@@ -587,9 +624,9 @@ pub enum Commands {
check_mode: CheckMode,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEVMVerifier {
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -612,9 +649,9 @@ pub enum Commands {
render_vk_seperately: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEVMVK {
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -632,9 +669,9 @@ pub enum Commands {
abi_path: PathBuf,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEVMDataAttestation {
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
@@ -654,9 +691,9 @@ pub enum Commands {
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for an aggregate proof
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEVMVerifierAggr {
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -695,6 +732,9 @@ pub enum Commands {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
@@ -776,31 +816,23 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Verifies a proof using a local EVM executor, returning accept or reject
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEVM {
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
addr_verifier: H160,
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
/// does the verifier use data attestation ?
#[arg(long)]
addr_da: Option<H160>,
addr_da: Option<H160Flag>,
// is the vk rendered seperately, if so specify an address
#[arg(long)]
addr_vk: Option<H160>,
},
/// Print the proof in hexadecimal
#[command(name = "print-proof-hex")]
PrintProofHex {
/// The path to the proof file
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
addr_vk: Option<H160Flag>,
},
}

View File

@@ -3,6 +3,8 @@ use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::commands::Commands;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::H160Flag;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
#[allow(unused_imports)]
@@ -21,8 +23,6 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -178,7 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
scales,
scale_rebase_multiplier,
max_logrows,
div_rebasing,
only_range_check_rebase,
} => calibrate(
model,
data,
@@ -187,7 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
only_range_check_rebase,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -202,7 +202,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMVerifier {
Commands::CreateEvmVerifier {
vk_path,
srs_path,
settings_path,
@@ -217,7 +217,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
render_vk_seperately,
),
Commands::CreateEVMVK {
Commands::CreateEvmVK {
vk_path,
srs_path,
settings_path,
@@ -225,14 +225,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMDataAttestation {
Commands::CreateEvmDataAttestation {
settings_path,
sol_code_path,
abi_path,
data,
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMVerifierAggr {
Commands::CreateEvmVerifierAggr {
vk_path,
srs_path,
sol_code_path,
@@ -270,7 +270,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compress_selectors,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEVMData {
Commands::SetupTestEvmData {
data,
compiled_circuit,
test_data,
@@ -366,7 +366,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path,
vk_path,
srs_path,
} => verify(proof_path, settings_path, vk_path, srs_path)
reduced_srs,
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
proof_path,
@@ -433,14 +434,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::VerifyEVM {
Commands::VerifyEvm {
proof_path,
addr_verifier,
rpc_url,
addr_da,
addr_vk,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
}
}
@@ -628,7 +628,7 @@ pub(crate) async fn gen_witness(
);
if let Some(output_path) = output {
serde_json::to_writer(&File::create(output_path)?, &witness)?;
witness.save(output_path)?;
}
// print the witness in debug
@@ -780,6 +780,7 @@ impl AccuracyResults {
/// Calibrate the circuit parameters to a given a dataset
#[cfg(not(target_arch = "wasm32"))]
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn calibrate(
model_path: PathBuf,
data: PathBuf,
@@ -788,7 +789,7 @@ pub(crate) fn calibrate(
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
div_rebasing: Option<bool>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -826,14 +827,11 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
match target {
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
}
(10..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
let div_rebasing = if only_range_check_rebase {
vec![false]
} else {
vec![true, false]
};
@@ -1172,16 +1170,6 @@ pub(crate) fn mock(
Ok(String::new())
}
pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<String, Box<dyn Error>> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
for instance in proof.instances {
println!("{:?}", instance);
}
let hex_str = hex::encode(proof.proof);
info!("0x{}", hex_str);
Ok(format!("0x{}", hex_str))
}
#[cfg(feature = "render")]
pub(crate) fn render(
model: PathBuf,
@@ -1400,10 +1388,10 @@ pub(crate) async fn deploy_evm(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160,
addr_verifier: H160Flag,
rpc_url: Option<String>,
addr_da: Option<H160>,
addr_vk: Option<H160>,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::verify_proof_with_data_attestation;
check_solc_requirement();
@@ -1413,14 +1401,20 @@ pub(crate) async fn verify_evm(
let result = if let Some(addr_da) = addr_da {
verify_proof_with_data_attestation(
proof.clone(),
addr_verifier,
addr_da,
addr_vk,
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
} else {
verify_proof_via_solidity(proof.clone(), addr_verifier, addr_vk, rpc_url.as_deref()).await?
verify_proof_via_solidity(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
};
info!("Solidity verification result: {}", result);
@@ -1576,14 +1570,14 @@ pub(crate) async fn setup_test_evm_witness(
use crate::pfsys::ProofType;
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn test_update_account_calls(
addr: H160,
addr: H160Flag,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::update_account_calls;
check_solc_requirement();
update_account_calls(addr, data, rpc_url.as_deref()).await?;
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
Ok(String::new())
}
@@ -1728,6 +1722,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1758,6 +1753,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1794,6 +1790,7 @@ pub(crate) fn fuzz(
proof.clone(),
bad_vk,
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1825,6 +1822,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1860,6 +1858,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -2045,15 +2044,29 @@ pub(crate) fn verify(
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
let vk =
load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
let result = verify_proof_circuit_kzg(
params.verifier_params(),
proof,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
let elapsed = now.elapsed();
info!(
"verify took {}.{}",
@@ -2077,7 +2090,7 @@ pub(crate) fn verify_aggr(
let strategy = AccumulatorStrategy::new(params.verifier_params());
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(vk_path, ())?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy);
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy, 1 << logrows);
let elapsed = now.elapsed();
info!(

View File

@@ -4,6 +4,7 @@ use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
@@ -15,6 +16,8 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
@@ -490,16 +493,20 @@ impl GraphData {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to open input at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let reader = std::fs::File::open(path)?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, self)?;
Ok(())
}
///

View File

@@ -16,6 +16,7 @@ use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
@@ -28,14 +29,17 @@ use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::RunArgs;
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use log::{debug, error, info, trace, warn};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
@@ -46,7 +50,6 @@ use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
@@ -61,6 +64,20 @@ pub const RANGE_MULTIPLIER: i128 = 2;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
/// Max circuit area
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
if let Ok(max_circuit_area) = std::env::var("EZKL_MAX_CIRCUIT_AREA") {
Some(max_circuit_area.parse().unwrap_or(0))
} else {
None
};
}
#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
@@ -191,42 +208,41 @@ impl GraphWitness {
output_scales: Vec<crate::Scale>,
visibility: VarVisibility,
) {
let mut pretty_elements = PrettyElements::default();
pretty_elements.rescaled_inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.inputs = self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
pretty_elements.rescaled_outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.outputs = self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
let mut pretty_elements = PrettyElements {
rescaled_inputs: self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
inputs: self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
rescaled_outputs: self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
outputs: self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
..Default::default()
};
if let Some(processed_inputs) = self.processed_inputs.clone() {
pretty_elements.processed_inputs = processed_inputs
@@ -292,16 +308,20 @@ impl GraphWitness {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load model at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
///
@@ -456,22 +476,33 @@ impl GraphSettings {
instances
}
/// calculate the log2 of the total number of instances
pub fn log2_total_instances(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>();
// max between 1 and the log2 of the sums
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
let encoded = serde_json::to_string(&self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(encoded.as_bytes())
// buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// load params from file
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
let mut file = std::fs::File::open(path).map_err(|e| {
error!("failed to open settings file at {}", e);
e
})?;
let mut data = String::new();
file.read_to_string(&mut data)?;
let res = serde_json::from_str(&data)?;
Ok(res)
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// Export the ezkl configuration as json
@@ -530,6 +561,7 @@ impl GraphSettings {
pub struct GraphConfig {
model_config: ModelConfig,
module_configs: ModuleConfigs,
circuit_size: CircuitSize,
}
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
@@ -566,7 +598,7 @@ impl GraphCircuit {
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::new(f);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
@@ -574,11 +606,10 @@ impl GraphCircuit {
///
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
// read bytes from file
let mut f = std::fs::File::open(&path)?;
let metadata = std::fs::metadata(&path)?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
let result = bincode::deserialize(&buffer)?;
let f = std::fs::File::open(path)?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
Ok(result)
}
}
@@ -593,6 +624,17 @@ pub enum TestDataSource {
OnChain,
}
impl std::fmt::Display for TestDataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestDataSource::File => write!(f, "file"),
TestDataSource::OnChain => write!(f, "on-chain"),
}
}
}
impl ToFlags for TestDataSource {}
impl From<String> for TestDataSource {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -809,7 +851,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
info!("input scales: {:?}", scales);
debug!("input scales: {:?}", scales);
match &data.input_data {
DataSource::File(file_data) => {
@@ -828,7 +870,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
info!("input scales: {:?}", scales);
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
@@ -1028,7 +1070,7 @@ impl GraphCircuit {
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
error!("{}", err_string);
debug!("{}", err_string);
return Err(err_string.into());
}
@@ -1048,7 +1090,7 @@ impl GraphCircuit {
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
error!("{}", err_string);
debug!("{}", err_string);
return Err(err_string.into());
}
@@ -1114,7 +1156,7 @@ impl GraphCircuit {
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
info!(
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
self.settings().run_args.lookup_range,
self.settings().run_args.logrows
@@ -1206,7 +1248,7 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs)?;
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1366,7 +1408,6 @@ impl GraphCircuit {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
num_instances: usize,
@@ -1374,20 +1415,22 @@ struct CircuitSize {
num_fixed: usize,
num_challenges: usize,
num_selectors: usize,
logrows: u32,
}
#[cfg(not(target_arch = "wasm32"))]
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>) -> Self {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
num_fixed: cs.num_fixed_columns(),
num_challenges: cs.num_challenges(),
num_selectors: cs.num_selectors(),
logrows,
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
@@ -1398,6 +1441,25 @@ impl CircuitSize {
};
Ok(serialized)
}
/// number of columns
pub fn num_columns(&self) -> usize {
self.num_instances + self.num_advice_columns + self.num_fixed
}
/// area of the circuit
pub fn area(&self) -> usize {
self.num_columns() * (1 << self.logrows)
}
/// area less than max
pub fn area_less_than_max(&self) -> bool {
if EZKL_MAX_CIRCUIT_AREA.is_some() {
self.area() < EZKL_MAX_CIRCUIT_AREA.unwrap()
} else {
true
}
}
}
impl Circuit<Fp> for GraphCircuit {
@@ -1472,10 +1534,12 @@ impl Circuit<Fp> for GraphCircuit {
(cs.degree() as f32).log2().ceil()
);
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"circuit size: \n {}",
CircuitSize::from_cs(cs)
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
@@ -1485,6 +1549,7 @@ impl Circuit<Fp> for GraphCircuit {
GraphConfig {
model_config,
module_configs,
circuit_size,
}
}
@@ -1497,6 +1562,16 @@ impl Circuit<Fp> for GraphCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fp>,
) -> Result<(), PlonkError> {
// check if the circuit area is less than the max
if !config.circuit_size.area_less_than_max() {
error!(
"circuit area {} is larger than the max allowed area {}",
config.circuit_size.area(),
EZKL_MAX_CIRCUIT_AREA.unwrap()
);
return Err(PlonkError::Synthesis);
}
trace!("Setting input in synthesize");
let input_vis = &self.settings().run_args.input_visibility;
let output_vis = &self.settings().run_args.output_visibility;

View File

@@ -10,7 +10,6 @@ use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -57,6 +56,8 @@ use unzip_n::unzip_n;
unzip_n!(pub 3);
#[cfg(not(target_arch = "wasm32"))]
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
/// The result of a forward pass.
#[derive(Clone, Debug)]
pub struct ForwardResult {
@@ -68,6 +69,16 @@ pub struct ForwardResult {
pub min_lookup_inputs: i128,
}
impl From<DummyPassRes> for ForwardResult {
fn from(res: DummyPassRes) -> Self {
Self {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
}
}
}
/// A circuit configuration for the entirety of a model loaded from an Onnx file.
#[derive(Clone, Debug)]
pub struct ModelConfig {
@@ -93,6 +104,12 @@ pub struct DummyPassRes {
pub lookup_ops: HashSet<LookupOp>,
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
@@ -478,7 +495,7 @@ impl Model {
) -> Result<GraphSettings, Box<dyn Error>> {
let instance_shapes = self.instance_shapes()?;
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {}",
"model has".blue(),
instance_shapes.len().to_string().blue(),
@@ -486,7 +503,25 @@ impl Model {
);
// this is the total number of variables we will need to allocate
// for the circuit
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -521,207 +556,17 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
let mut max_lookup_inputs = 0;
let mut min_lookup_inputs = 0;
let input_shapes = self.graph.input_shapes()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
let mut input = model_inputs[i].clone();
input.reshape(&input_shapes[i])?;
results.insert(input_idx, vec![input]);
}
for (idx, n) in self.graph.nodes.iter() {
let mut inputs = vec![];
if n.is_input() {
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
inputs.push(t);
} else {
for (idx, outlet) in n.inputs().iter() {
match results.get(&idx) {
Some(value) => inputs.push(value[*outlet].clone()),
None => return Err(Box::new(GraphError::MissingNode(*idx))),
}
}
};
debug!("executing {}: {}", idx, n.as_str());
debug!("dims: {:?}", n.out_dims());
debug!(
"input_dims: {:?}",
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);
debug!("input nodes: {:?}", n.inputs());
if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
max = max.max(
i.iter()
.map(|x| felt_to_i128(*x))
.max()
.ok_or("missing max")?,
);
min = min.min(
i.iter()
.map(|x| felt_to_i128(*x))
.min()
.ok_or("missing min")?,
);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("max lookup inputs: {}", max);
debug!("min lookup inputs: {}", min);
}
match n {
NodeType::Node(n) => {
// execute the op
let start = instant::Instant::now();
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
res.output.reshape(&n.out_dims)?;
let elapsed = start.elapsed();
trace!("op took: {:?}", elapsed);
// see if any of the intermediate lookup calcs are the max
if !res.intermediate_lookups.is_empty() {
let (mut min, mut max) = (0, 0);
for i in &res.intermediate_lookups {
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("intermediate max lookup inputs: {}", max);
debug!("intermediate min lookup inputs: {}", min);
}
debug!(
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
idx,
res.output.map(crate::fieldutils::felt_to_i32).show(),
res.output
.map(|x| crate::fieldutils::felt_to_f64(x)
/ scale_to_multiplier(n.out_scale))
.show(),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
n.out_scale
);
results.insert(idx, vec![res.output]);
}
NodeType::SubGraph {
model,
output_mappings,
input_mappings,
inputs: input_tuple,
..
} => {
let orig_inputs = inputs.clone();
let input_mappings = input_mappings.clone();
let input_dims = inputs.iter().map(|inp| inp.dims());
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, input_tuple, model.graph.inputs
);
debug!("input_mappings: {:?}", input_mappings);
let mut full_results: Vec<Tensor<Fp>> = vec![];
for i in 0..num_iter {
// replace the Stacked input with the current chunk iter
for ((mapping, inp), og_input) in
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
{
if let InputMapping::Stacked { axis, chunk } = mapping {
let start = i * chunk;
let end = (i + 1) * chunk;
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
*inp = t;
}
}
let res = model.forward(&inputs)?;
// recursively get the max lookup inputs for subgraphs
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);
let mut outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
for mapping in mappings {
match mapping {
OutputMapping::Single { outlet, .. } => {
outlets.insert(outlet, outlet_res.clone());
}
OutputMapping::Stacked { outlet, axis, .. } => {
if !full_results.is_empty() {
let stacked_res = crate::tensor::ops::concat(
&[&full_results[*outlet], &outlet_res],
*axis,
)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
}
full_results = outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
inputs[*input_idx] = full_results[output_idx].clone();
}
}
trace!(
"------------ output subgraph node {}: {:?}",
idx,
full_results
.iter()
.map(|x|
// convert to tensor i32
x.map(crate::fieldutils::felt_to_i32).show())
.collect_vec()
);
results.insert(idx, full_results);
}
}
}
let output_nodes = self.graph.outputs.iter();
debug!(
"model outputs are nodes: {:?}",
output_nodes.clone().collect_vec()
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
let res = ForwardResult {
outputs,
max_lookup_inputs,
min_lookup_inputs,
};
Ok(res)
pub fn forward(
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(&run_args, &valtensor_inputs)?;
Ok(res.into())
}
/// Loads an Onnx model from a specified path.
@@ -733,7 +578,7 @@ impl Model {
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
) -> Result<TractResult, Box<dyn Error>> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
@@ -772,7 +617,7 @@ impl Model {
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
info!("set {} to {}", symbol, value);
debug!("set {} to {}", symbol, value);
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
@@ -1293,7 +1138,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
@@ -1335,18 +1180,29 @@ impl Model {
};
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants()
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1489,28 +1345,13 @@ impl Model {
pub fn dummy_layout(
&self,
run_args: &RunArgs,
input_shapes: &[Vec<usize>],
inputs: &[ValTensor<Fp>],
) -> Result<DummyPassRes, Box<dyn Error>> {
info!("calculating num of constraints using dummy model layout...");
debug!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = input_shapes
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
results.insert(*input_idx, vec![inputs[i].clone()]);
@@ -1534,27 +1375,26 @@ impl Model {
ValType::Constant(Fp::ONE)
};
let comparator = outputs
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|x| {
let mut v: ValTensor<Fp> =
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
v.reshape(x.dims())?;
Ok(v)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
let _ = outputs
.into_iter()
.zip(comparator)
.map(|(o, c)| {
dummy_config.layout(
&mut region,
&[o, c],
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<Vec<_>, _>>();
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
@@ -1566,7 +1406,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
@@ -1575,12 +1415,23 @@ impl Model {
region.total_constants().to_string().red()
);
let outputs = outputs
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
outputs,
};
Ok(res)

View File

@@ -219,8 +219,6 @@ impl Op<Fp> for RebaseScale {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
res.intermediate_lookups
.extend(rebase_res.intermediate_lookups);
Ok(res)
}
@@ -499,6 +497,7 @@ impl Node {
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(not(target_arch = "wasm32"))]
#[allow(clippy::too_many_arguments)]
pub fn new(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,

View File

@@ -439,17 +439,16 @@ pub fn new_op_from_onnx(
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: None,
});
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -478,17 +477,16 @@ pub fn new_op_from_onnx(
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: None,
});
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})

View File

@@ -1,4 +1,5 @@
use std::error::Error;
use std::fmt::Display;
use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
@@ -14,6 +15,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use super::*;
@@ -40,6 +42,33 @@ pub enum Visibility {
Fixed,
}
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed {
hash_is_public,
outlets,
} => {
if *hash_is_public {
write!(f, "hashed/public")
} else {
write!(f, "hashed/private/{}", outlets.iter().join(","))
}
}
}
}
}
impl ToFlags for Visibility {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl<'a> From<&'a str> for Visibility {
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
@@ -202,17 +231,6 @@ impl Visibility {
vec![]
}
}
impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed { .. } => write!(f, "hashed"),
}
}
}
/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
@@ -410,7 +428,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
info!("number of blinding factors: {}", cs.blinding_factors());
debug!("number of blinding factors: {}", cs.blinding_factors());
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))

View File

@@ -7,7 +7,6 @@
overflowing_literals,
path_statements,
patterns_in_fns_without_body,
private_in_public,
unconditional_recursion,
unused,
unused_allocation,
@@ -33,6 +32,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use graph::Visibility;
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
pub mod circuit;
@@ -71,11 +71,34 @@ pub mod tensor;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
/// The denominator in the fixed point representation used when quantizing inputs
pub type Scale = i32;
#[cfg(not(target_arch = "wasm32"))]
// Buf writer capacity
lazy_static! {
/// The capacity of the buffer used for writing to disk
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(target_arch = "wasm32")]
const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(target_arch = "wasm32")]
const EZKL_BUF_CAPACITY: &usize = &8000;
/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
@@ -90,7 +113,7 @@ pub struct RunArgs {
#[arg(long, default_value = "1")]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
@@ -99,7 +122,7 @@ pub struct RunArgs {
#[arg(short = 'N', long, default_value = "2")]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, hashed
#[arg(long, default_value = "private")]
@@ -181,34 +204,15 @@ fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T: std::str::FromStr + std::fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr,
U: std::str::FromStr + std::fmt::Debug,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
/// Parse a tuple
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr + Clone,
T::Err: std::error::Error + Send + Sync + 'static,
{
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');
let res = res
.map(|x| {
// remove blank space
let x = x.trim();
x.parse::<T>()
})
.collect::<Result<Vec<_>, _>>()?;
if res.len() != 2 {
return Err("invalid tuple".into());
}
Ok((res[0].clone(), res[1].clone()))
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
}

View File

@@ -8,6 +8,7 @@ use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation::PoseidonTranscript;
use crate::tensor::TensorType;
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
@@ -39,9 +40,19 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
use tosubcommand::ToFlags;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
_ => panic!("invalid serde format"),
}
}
#[allow(missing_docs)]
#[derive(
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
@@ -52,6 +63,25 @@ pub enum ProofType {
ForAggr,
}
impl std::fmt::Display for ProofType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ProofType::Single => "single",
ProofType::ForAggr => "for-aggr",
}
)
}
}
impl ToFlags for ProofType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<ProofType> for TranscriptType {
fn from(val: ProofType) -> Self {
match val {
@@ -154,6 +184,21 @@ pub enum TranscriptType {
EVM,
}
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
}
)
}
}
impl ToFlags for TranscriptType {}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {
fn to_object(&self, py: Python) -> PyObject {
@@ -315,7 +360,7 @@ where
/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let file = std::fs::File::create(proof_path)?;
let mut writer = BufWriter::new(file);
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(&mut writer, &self)?;
Ok(())
}
@@ -328,8 +373,10 @@ where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let data = std::fs::read_to_string(proof_path)?;
serde_json::from_str(&data).map_err(|e| e.into())
let file = std::fs::File::open(proof_path)?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
Ok(proof)
}
}
@@ -555,6 +602,7 @@ where
verifier_params,
pk.get_vk(),
strategy,
verifier_params.n(),
)?;
}
let elapsed = now.elapsed();
@@ -642,6 +690,7 @@ pub fn verify_proof_circuit<
params: &'params Scheme::ParamsVerifier,
vk: &VerifyingKey<Scheme::Curve>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
Scheme::Scalar: SerdeObject
@@ -662,7 +711,7 @@ where
trace!("instances {:?}", instances);
let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
}
/// Loads a [VerifyingKey] at `path`.
@@ -678,13 +727,14 @@ where
info!("loading verification key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
let mut reader = BufReader::new(f);
VerifyingKey::<Scheme::Curve>::read::<_, C>(
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading verification key ✅");
Ok(vk)
}
/// Loads a [ProvingKey] at `path`.
@@ -700,19 +750,20 @@ where
info!("loading proving key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
let mut reader = BufReader::new(f);
ProvingKey::<Scheme::Curve>::read::<_, C>(
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading proving key ✅");
Ok(pk)
}
/// Saves a [ProvingKey] to `path`.
pub fn save_pk<Scheme: CommitmentScheme>(
path: &PathBuf,
vk: &ProvingKey<Scheme::Curve>,
pk: &ProvingKey<Scheme::Curve>,
) -> Result<(), io::Error>
where
Scheme::Curve: SerdeObject + CurveAffine,
@@ -720,9 +771,10 @@ where
{
info!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving proving key ✅");
Ok(())
}
@@ -737,9 +789,10 @@ where
{
info!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving verification key ✅");
Ok(())
}
@@ -750,7 +803,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
) -> Result<(), io::Error> {
info!("saving parameters 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
params.write(&mut writer)?;
writer.flush()?;
Ok(())
@@ -840,6 +893,7 @@ pub(crate) fn verify_proof_circuit_kzg<
proof: Snark<Fr, G1Affine>,
vk: &VerifyingKey<G1Affine>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
@@ -849,7 +903,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, params, vk, strategy),
>(&proof, params, vk, strategy, orig_n),
TranscriptType::Poseidon => verify_proof_circuit::<
Fr,
VerifierSHPLONK<'_, Bn256>,
@@ -857,7 +911,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, params, vk, strategy),
>(&proof, params, vk, strategy, orig_n),
}
}

View File

@@ -15,10 +15,9 @@ use crate::graph::{
use crate::pfsys::evm::aggregation::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
Snark, TranscriptType,
TranscriptType,
};
use crate::RunArgs;
use ethers::types::H160;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -26,7 +25,6 @@ use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use snark_verifier::util::arithmetic::PrimeField;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
use tokio::runtime::Runtime;
@@ -525,7 +523,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
div_rebasing = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
data: PathBuf,
@@ -536,7 +534,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
div_rebasing: Option<bool>,
only_range_check_rebase: bool,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -546,7 +544,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
only_range_check_rebase,
max_logrows,
)
.map_err(|e| {
@@ -689,14 +687,23 @@ fn prove(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_path=PathBuf::from(DEFAULT_VK),
srs_path=None,
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
non_reduced_srs: bool,
) -> Result<bool, PyErr> {
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
crate::execute::verify(
proof_path,
settings_path,
vk_path,
srs_path,
non_reduced_srs,
)
.map_err(|e| {
let err_str = format!("Failed to run verify: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -1022,24 +1029,15 @@ fn verify_evm(
addr_da: Option<&str>,
addr_vk: Option<&str>,
) -> Result<bool, PyErr> {
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_verifier = H160Flag::from(addr_verifier);
let addr_da = if let Some(addr_da) = addr_da {
let addr_da = H160::from_str(addr_da).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_da = H160Flag::from(addr_da);
Some(addr_da)
} else {
None
};
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160::from_str(addr_vk).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
@@ -1097,16 +1095,6 @@ fn create_evm_verifier_aggr(
Ok(true)
}
/// print hex representation of a proof
#[pyfunction(signature = (proof_path))]
fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -1145,7 +1133,6 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;

View File

@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
let mut inner = self.inner.clone();
let remainder = self.len() % n;
if remainder != 0 {
inner.resize(self.len() + n - remainder, T::zero().unwrap());
inner.resize(self.len() + n - remainder, pad);
}
Tensor::new(Some(&inner), &[inner.len()])
}

View File

@@ -243,7 +243,7 @@ pub fn and<
/// Some(&[1, 0, 1, 0, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = equals(&a, &b).unwrap().0;
/// let result = equals(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -260,7 +260,7 @@ pub fn equals<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let a = a.clone();
let b = b.clone();
@@ -268,7 +268,7 @@ pub fn equals<
let result = nonlinearities::kronecker_delta(&diff);
Ok((result, vec![diff]))
Ok(result)
}
/// Greater than operation.
@@ -289,7 +289,7 @@ pub fn equals<
/// ).unwrap();
/// let result = greater(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
pub fn greater<
T: TensorType
@@ -302,7 +302,7 @@ pub fn greater<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -311,7 +311,7 @@ pub fn greater<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok((mask, vec![mask_inter]))
Ok(mask)
}
/// Greater equals than operation.
@@ -332,7 +332,7 @@ pub fn greater<
/// ).unwrap();
/// let result = greater_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
pub fn greater_equal<
T: TensorType
@@ -345,7 +345,7 @@ pub fn greater_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -354,7 +354,7 @@ pub fn greater_equal<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok((mask, vec![mask_inter]))
Ok(mask)
}
/// Less than to operation.
@@ -375,7 +375,7 @@ pub fn greater_equal<
/// ).unwrap();
/// let result = less(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
///
pub fn less<
@@ -389,7 +389,7 @@ pub fn less<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
// a < b <=> b > a
greater(b, a)
}
@@ -412,7 +412,7 @@ pub fn less<
/// ).unwrap();
/// let result = less_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
///
pub fn less_equal<
@@ -426,7 +426,7 @@ pub fn less_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
// a < b <=> b > a
greater_equal(b, a)
}
@@ -2300,12 +2300,12 @@ pub fn deconv<
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
///
/// // This time with normalization
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -2315,7 +2315,7 @@ pub fn sumpool(
stride: (usize, usize),
kernel_shape: (usize, usize),
normalize: bool,
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
) -> Result<Tensor<i128>, TensorError> {
let image_dims = image.dims();
let batch_size = image_dims[0];
let image_channels = image_dims[1];
@@ -2345,15 +2345,12 @@ pub fn sumpool(
let mut combined = res.combine()?;
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;
let mut inter = vec![];
if normalize {
inter.push(combined.clone());
let norm = kernel.len();
combined = nonlinearities::const_div(&combined, norm as f64);
}
Ok((combined, inter))
Ok(combined)
}
/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
@@ -3048,11 +3045,7 @@ pub mod nonlinearities {
}
/// softmax layout
pub fn softmax_axes(
a: &Tensor<i128>,
scale: f64,
axes: &[usize],
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();
@@ -3060,8 +3053,6 @@ pub mod nonlinearities {
return softmax(a, scale);
}
let mut intermediate_values = vec![];
let cartesian_coord = dims[..dims.len() - 1]
.iter()
.map(|x| 0..*x)
@@ -3084,8 +3075,7 @@ pub mod nonlinearities {
let res = softmax(&softmax_input, scale);
outputs.push(res.0);
intermediate_values.extend(res.1);
outputs.push(res);
}
let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
@@ -3093,7 +3083,7 @@ pub mod nonlinearities {
.combine()
.unwrap();
res.reshape(dims).unwrap();
(res, intermediate_values)
res
}
/// Applies softmax
@@ -3110,24 +3100,20 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0).0;
/// let result = softmax(&x, 128.0);
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let mut intermediate_values = vec![];
intermediate_values.push(a.clone());
let exp = exp(a, scale);
let sum = sum(&exp).unwrap();
intermediate_values.push(sum.clone());
let inv_denom = recip(&sum, scale, scale);
((exp * inv_denom).unwrap(), intermediate_values)
(exp * inv_denom).unwrap()
}
/// Applies range_check_percent

View File

@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.pad_to_zero_rem(n)?;
*v = v.pad_to_zero_rem(n, pad)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
if indices.is_empty() {
return Ok(());
} else {
return Err(TensorError::WrongMethod);
}
}
}
Ok(())

View File

@@ -33,10 +33,7 @@ pub enum VarTensor {
impl VarTensor {
///
pub fn is_advice(&self) -> bool {
match self {
VarTensor::Advice { .. } => true,
_ => false,
}
matches!(self, VarTensor::Advice { .. })
}
///

View File

@@ -311,13 +311,19 @@ pub fn verify(
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
circuit_settings.clone(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);
let result = verify_proof_circuit_kzg(
params.verifier_params(),
snark,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
match result {
Ok(_) => Ok(true),
@@ -387,15 +393,6 @@ pub fn prove(
.into_bytes())
}
/// print hex representation of a proof
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
// VALIDATION FUNCTIONS
/// Witness file validation

View File

@@ -477,6 +477,7 @@ mod native_tests {
use crate::native_tests::kzg_fuzz;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization_different_binaries;
use rand::Rng;
use tempdir::TempDir;
#[test]
@@ -496,7 +497,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
test_dir.close().unwrap();
}
});
@@ -560,7 +561,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
test_dir.close().unwrap();
}
@@ -569,7 +570,18 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_tolerance_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
test_dir.close().unwrap();
}
@@ -580,7 +592,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -589,7 +601,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -598,7 +610,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -607,7 +619,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -616,7 +628,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -625,7 +637,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -634,7 +646,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None);
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -644,7 +656,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -654,7 +666,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -663,7 +675,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -673,7 +685,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None);
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -682,7 +694,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -692,7 +704,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None);
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -702,7 +714,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None);
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -712,7 +724,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -722,7 +734,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -732,7 +744,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
@@ -876,7 +888,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
});
@@ -1273,6 +1285,7 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
) {
gen_circuit_settings_and_witness(
test_dir,
@@ -1285,6 +1298,7 @@ mod native_tests {
scales_to_use,
2,
false,
tolerance,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -1312,6 +1326,7 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
div_rebasing: bool,
tolerance: f32,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1321,11 +1336,12 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!("--variables=batch_size={}", batch_size),
format!("--variables=batch_size->{}", batch_size),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
];
if div_rebasing {
@@ -1402,6 +1418,7 @@ mod native_tests {
}
// Mock prove (fast, but does not cover some potential issues)
#[allow(clippy::too_many_arguments)]
fn accuracy_measurement(
test_dir: &str,
example_name: String,
@@ -1424,6 +1441,7 @@ mod native_tests {
None,
2,
div_rebasing,
0.0,
);
println!(
@@ -1455,7 +1473,7 @@ mod native_tests {
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
"-O",
format!("{}/{}/render.png", test_dir, example_name).as_str(),
"--lookup-range=(-32768,32768)",
"--lookup-range=-32768->32768",
"-K=17",
])
.status()
@@ -1683,6 +1701,7 @@ mod native_tests {
scales_to_use,
num_inner_columns,
false,
0.0,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -1745,6 +1764,30 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
// load settings file
let settings =
std::fs::read_to_string(settings_path.clone()).expect("failed to read settings file");
let graph_settings = serde_json::from_str::<GraphSettings>(&settings)
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(graph_settings.log2_total_instances());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"verify",
format!("--settings-path={}", settings_path).as_str(),
"--proof-path",
&format!("{}/{}/proof.pf", test_dir, example_name),
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
"--reduced-srs",
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
// prove-serialize-verify, the usual full path
@@ -1760,6 +1803,7 @@ mod native_tests {
None,
2,
false,
0.0,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -2036,6 +2080,7 @@ mod native_tests {
Some(vec![4]),
1,
false,
0.0,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);

View File

@@ -131,7 +131,7 @@ mod py_tests {
"simple_demo_aggregated_proofs.ipynb",
"ezkl_demo.ipynb", // 10
"lstm.ipynb",
"set_membership.ipynb",
"set_membership.ipynb", // 12
"decision_tree.ipynb",
"random_forest.ipynb",
"gradient_boosted_trees.ipynb", // 15

View File

@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
@@ -147,13 +147,13 @@ def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
@@ -183,7 +183,7 @@ def test_model_compile():
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
compiled_model_path = os.path.join(
@@ -205,7 +205,7 @@ def test_forward():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
@@ -392,9 +392,7 @@ def test_prove_evm():
assert res['transcript_type'] == 'EVM'
assert os.path.isfile(proof_path)
res = ezkl.print_proof_hex(proof_path)
# to figure out a better way of testing print_proof_hex
assert type(res) == str
def test_create_evm_verifier():

View File

@@ -9,8 +9,8 @@ mod wasm32 {
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
prove, settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove,
settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -258,15 +258,6 @@ mod wasm32 {
assert!(value);
}
#[wasm_bindgen_test]
async fn print_proof_hex_test() {
let proof = printProofHex(wasm_bindgen::Clamped(PROOF.to_vec()))
.map_err(|_| "failed")
.unwrap();
assert!(proof.len() > 0);
}
#[wasm_bindgen_test]
async fn verify_validations() {
// Run witness validation on network (should fail)

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.