mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c0c17c9be | ||
|
|
bf69b16fc1 | ||
|
|
74feb829da | ||
|
|
d429e7edab | ||
|
|
f0e5b82787 | ||
|
|
3f7261f50b | ||
|
|
678a249dcb | ||
|
|
0291eb2d0f | ||
|
|
1b637a70b0 | ||
|
|
abcd5380db |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -236,6 +236,8 @@ jobs:
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
|
||||
- name: kzg inputs
|
||||
|
||||
79
Cargo.lock
generated
79
Cargo.lock
generated
@@ -112,7 +112,7 @@ checksum = "c0391754c09fab4eae3404d19d0d297aa1c670c1775ab51d8a5312afeca23157"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -464,7 +464,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -490,7 +490,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -855,7 +855,7 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1491,7 +1491,7 @@ checksum = "48016319042fb7c87b78d2993084a831793a897a5cd1a2a67cab9d1eeb4b7d76"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1657,7 +1657,7 @@ dependencies = [
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
"toml",
|
||||
"walkdir",
|
||||
]
|
||||
@@ -1675,7 +1675,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde_json",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1701,7 +1701,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tiny-keccak",
|
||||
@@ -1902,6 +1902,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tosubcommand",
|
||||
"tract-onnx",
|
||||
"unzip-n",
|
||||
"wasm-bindgen",
|
||||
@@ -2103,7 +2104,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2262,7 +2263,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec 1.0.1",
|
||||
@@ -2279,7 +2280,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
@@ -2975,7 +2976,7 @@ checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3285,7 +3286,7 @@ dependencies = [
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3369,7 +3370,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3570,7 +3571,7 @@ dependencies = [
|
||||
"pest_meta",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3647,7 +3648,7 @@ dependencies = [
|
||||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3685,7 +3686,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3821,7 +3822,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9825a04601d60621feed79c4e6b56d65db77cdca55cef43b46b0de1096d1c282"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4011,7 +4012,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4023,7 +4024,7 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4735,7 +4736,7 @@ checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5044,7 +5045,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5079,9 +5080,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.48"
|
||||
version = "2.0.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
|
||||
checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -5259,7 +5260,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5348,7 +5349,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5444,6 +5445,24 @@ dependencies = [
|
||||
"winnow 0.5.39",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tosubcommand"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/enum_to_subcommand#42e9870f1f757932bab64ab30ebf1ff08a392265"
|
||||
dependencies = [
|
||||
"tosubcommand_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tosubcommand_derive"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/enum_to_subcommand#42e9870f1f757932bab64ab30ebf1ff08a392265"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.2"
|
||||
@@ -5470,7 +5489,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5854,7 +5873,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@@ -5898,7 +5917,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@@ -6309,5 +6328,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"syn 2.0.50",
|
||||
]
|
||||
|
||||
@@ -35,6 +35,8 @@ ark-std = { version = "^0.3.0", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
@@ -159,6 +161,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidi
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
|
||||
@@ -6,6 +6,7 @@ use ezkl::fieldutils;
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
@@ -489,6 +490,7 @@ pub fn runconv() {
|
||||
strategy,
|
||||
pi_for_real_prover,
|
||||
&mut transcript,
|
||||
params.n(),
|
||||
);
|
||||
assert!(verify.is_ok());
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ use ezkl::execute::run;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ezkl::logger::init_logger;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{error, info};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{debug, error, info};
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
use rand::prelude::SliceRandom;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(feature = "icicle")]
|
||||
@@ -25,6 +25,7 @@ use std::error::Error;
|
||||
pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let args = Cli::parse();
|
||||
init_logger();
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
banner();
|
||||
#[cfg(feature = "icicle")]
|
||||
if env::var("ENABLE_ICICLE_GPU").is_ok() {
|
||||
@@ -32,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
} else {
|
||||
info!("Running with CPU");
|
||||
}
|
||||
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
let res = run(args.command).await;
|
||||
match &res {
|
||||
Ok(_) => info!("succeeded"),
|
||||
@@ -44,7 +45,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn main() {}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
fn banner() {
|
||||
let ell: Vec<&str> = vec![
|
||||
"for Neural Networks",
|
||||
|
||||
@@ -15,7 +15,7 @@ use halo2_proofs::{
|
||||
Instance, Selector, TableColumn,
|
||||
},
|
||||
};
|
||||
use log::{trace, warn};
|
||||
use log::{debug, trace};
|
||||
|
||||
/// A simple [`FloorPlanner`] that performs minimal optimizations.
|
||||
#[derive(Debug)]
|
||||
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
|
||||
Error::Synthesis
|
||||
})?;
|
||||
if !self.regions.contains_key(&index) {
|
||||
warn!("spawning module {}", index)
|
||||
debug!("spawning module {}", index)
|
||||
};
|
||||
self.current_module = index;
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use pyo3::{
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use crate::{
|
||||
circuit::ops::base::BaseOp,
|
||||
@@ -61,6 +62,22 @@ pub enum CheckMode {
|
||||
UNSAFE,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CheckMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CheckMode::SAFE => write!(f, "safe"),
|
||||
CheckMode::UNSAFE => write!(f, "unsafe"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for CheckMode {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for CheckMode {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
@@ -83,6 +100,19 @@ pub struct Tolerance {
|
||||
pub scale: utils::F32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Tolerance {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:.2}", self.val)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for Tolerance {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Tolerance {
|
||||
type Err = String;
|
||||
|
||||
@@ -523,14 +553,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
let range_check =
|
||||
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for x in 0..input.num_blocks() {
|
||||
for y in 0..input.num_inner_cols() {
|
||||
|
||||
@@ -1854,7 +1854,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
let shape = &res[0].dims()[2..];
|
||||
let mut last_elem = res[1..]
|
||||
.iter()
|
||||
.fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?;
|
||||
.try_fold(res[0].clone(), |acc, elem| acc.concat(elem.clone()))?;
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
@@ -2953,8 +2953,17 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
return enforce_equality(config, region, values);
|
||||
}
|
||||
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
|
||||
values[0] = region.assign(&config.inputs[0], &values[0])?;
|
||||
values[1] = region.assign(&config.inputs[1], &values[1])?;
|
||||
let total_assigned_0 = values[0].len();
|
||||
let total_assigned_1 = values[1].len();
|
||||
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
|
||||
region.increment(total_assigned);
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
|
||||
let recip = nonlinearity(
|
||||
|
||||
@@ -243,10 +243,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Sign => "SIGN".into(),
|
||||
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
|
||||
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
|
||||
LookupOp::LessThan { .. } => "LESS_THAN".into(),
|
||||
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
|
||||
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
|
||||
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
|
||||
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
|
||||
@@ -234,7 +234,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
|
||||
let src = if let Some(_) = constant_idx {
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
|
||||
@@ -203,10 +203,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let constants = AtomicUsize::new(self.total_constants());
|
||||
let max_lookup_inputs =
|
||||
AtomicInt::new(self.max_lookup_inputs().try_into().unwrap_or_default());
|
||||
let min_lookup_inputs =
|
||||
AtomicInt::new(self.min_lookup_inputs().try_into().unwrap_or_default());
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
|
||||
@@ -239,13 +237,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
let local_max_lookup_inputs =
|
||||
local_reg.max_lookup_inputs().try_into().unwrap_or_default();
|
||||
let local_min_lookup_inputs =
|
||||
local_reg.min_lookup_inputs().try_into().unwrap_or_default();
|
||||
|
||||
max_lookup_inputs.fetch_max(local_max_lookup_inputs, Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_min_lookup_inputs, Ordering::SeqCst);
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
@@ -261,8 +254,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner() as i128;
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner() as i128;
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner();
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner();
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
|
||||
@@ -246,7 +246,7 @@ mod matmul_col_overflow {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow_double_col {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -349,8 +349,13 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -361,7 +366,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -463,8 +468,13 @@ mod matmul_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -1140,7 +1150,7 @@ mod conv {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1262,8 +1272,13 @@ mod conv_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -1275,7 +1290,7 @@ mod conv_col_ultra_overflow {
|
||||
// not wasm 32 unknown
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_relu_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1412,8 +1427,13 @@ mod conv_relu_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -2343,7 +2363,7 @@ mod lookup_ultra_overflow {
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
poly::commitment::ParamsProver,
|
||||
poly::commitment::{Params, ParamsProver},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -2447,8 +2467,13 @@ mod lookup_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
|
||||
128
src/commands.rs
128
src/commands.rs
@@ -1,4 +1,4 @@
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use clap::{Parser, Subcommand};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ethers::types::H160;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -9,8 +9,9 @@ use pyo3::{
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::path::PathBuf;
|
||||
use std::{error::Error, str::FromStr};
|
||||
use tosubcommand::{ToFlags, ToSubcommand};
|
||||
|
||||
use crate::{pfsys::ProofType, RunArgs};
|
||||
|
||||
@@ -85,15 +86,11 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.to_possible_value()
|
||||
.expect("no values are skipped")
|
||||
.get_name()
|
||||
.fmt(f)
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for TranscriptType {
|
||||
@@ -138,17 +135,27 @@ impl Default for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for CalibrationTarget {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_string()
|
||||
impl std::fmt::Display for CalibrationTarget {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_string()
|
||||
}
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_string(),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_string(),
|
||||
}
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_string(),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_string(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for CalibrationTarget {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,6 +176,36 @@ impl From<&str> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
inner: H160,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<H160Flag> for H160 {
|
||||
fn from(val: H160Flag) -> H160 {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl ToFlags for H160Flag {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{:#x}", self.inner)]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<&str> for H160Flag {
|
||||
fn from(s: &str) -> Self {
|
||||
Self {
|
||||
inner: H160::from_str(s).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
|
||||
impl IntoPy<PyObject> for CalibrationTarget {
|
||||
@@ -201,7 +238,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
@@ -242,7 +279,7 @@ impl Cli {
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
|
||||
pub enum Commands {
|
||||
#[cfg(feature = "empty-cmd")]
|
||||
/// Creates an empty buffer
|
||||
@@ -336,9 +373,9 @@ pub enum Commands {
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
|
||||
#[arg(long)]
|
||||
div_rebasing: Option<bool>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
|
||||
only_range_check_rebase: bool,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -509,7 +546,7 @@ pub enum Commands {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEVMData {
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
@@ -537,7 +574,7 @@ pub enum Commands {
|
||||
TestUpdateAccountCalls {
|
||||
/// The path to the verifier contract's address
|
||||
#[arg(long)]
|
||||
addr: H160,
|
||||
addr: H160Flag,
|
||||
/// The path to the .json data file.
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
@@ -587,9 +624,9 @@ pub enum Commands {
|
||||
check_mode: CheckMode,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for a single proof
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEVMVerifier {
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -612,9 +649,9 @@ pub enum Commands {
|
||||
render_vk_seperately: bool,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for a single proof
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-vk")]
|
||||
CreateEVMVK {
|
||||
CreateEvmVK {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -632,9 +669,9 @@ pub enum Commands {
|
||||
abi_path: PathBuf,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
#[command(name = "create-evm-da")]
|
||||
CreateEVMDataAttestation {
|
||||
CreateEvmDataAttestation {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
@@ -654,9 +691,9 @@ pub enum Commands {
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for an aggregate proof
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEVMVerifierAggr {
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -695,6 +732,9 @@ pub enum Commands {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
|
||||
reduced_srs: bool,
|
||||
},
|
||||
/// Verifies an aggregate proof, returning accept or reject
|
||||
VerifyAggr {
|
||||
@@ -776,31 +816,23 @@ pub enum Commands {
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Verifies a proof using a local EVM executor, returning accept or reject
|
||||
/// Verifies a proof using a local Evm executor, returning accept or reject
|
||||
#[command(name = "verify-evm")]
|
||||
VerifyEVM {
|
||||
VerifyEvm {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
/// The path to verifier contract's address
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
|
||||
addr_verifier: H160,
|
||||
addr_verifier: H160Flag,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
/// does the verifier use data attestation ?
|
||||
#[arg(long)]
|
||||
addr_da: Option<H160>,
|
||||
addr_da: Option<H160Flag>,
|
||||
// is the vk rendered seperately, if so specify an address
|
||||
#[arg(long)]
|
||||
addr_vk: Option<H160>,
|
||||
},
|
||||
|
||||
/// Print the proof in hexadecimal
|
||||
#[command(name = "print-proof-hex")]
|
||||
PrintProofHex {
|
||||
/// The path to the proof file
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
addr_vk: Option<H160Flag>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use crate::circuit::CheckMode;
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::commands::Commands;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::H160Flag;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(unused_imports)]
|
||||
@@ -21,8 +23,6 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
|
||||
use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
use crate::RunArgs;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ethers::types::H160;
|
||||
use gag::Gag;
|
||||
use halo2_proofs::dev::VerifyFailure;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
@@ -178,7 +178,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
div_rebasing,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -187,7 +187,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -202,7 +202,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(model, witness),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifier {
|
||||
Commands::CreateEvmVerifier {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
@@ -217,7 +217,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
),
|
||||
Commands::CreateEVMVK {
|
||||
Commands::CreateEvmVK {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
@@ -225,14 +225,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
abi_path,
|
||||
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMDataAttestation {
|
||||
Commands::CreateEvmDataAttestation {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
data,
|
||||
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifierAggr {
|
||||
Commands::CreateEvmVerifierAggr {
|
||||
vk_path,
|
||||
srs_path,
|
||||
sol_code_path,
|
||||
@@ -270,7 +270,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
compress_selectors,
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SetupTestEVMData {
|
||||
Commands::SetupTestEvmData {
|
||||
data,
|
||||
compiled_circuit,
|
||||
test_data,
|
||||
@@ -366,7 +366,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
settings_path,
|
||||
vk_path,
|
||||
srs_path,
|
||||
} => verify(proof_path, settings_path, vk_path, srs_path)
|
||||
reduced_srs,
|
||||
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::VerifyAggr {
|
||||
proof_path,
|
||||
@@ -433,14 +434,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::VerifyEVM {
|
||||
Commands::VerifyEvm {
|
||||
proof_path,
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
|
||||
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -628,7 +628,7 @@ pub(crate) async fn gen_witness(
|
||||
);
|
||||
|
||||
if let Some(output_path) = output {
|
||||
serde_json::to_writer(&File::create(output_path)?, &witness)?;
|
||||
witness.save(output_path)?;
|
||||
}
|
||||
|
||||
// print the witness in debug
|
||||
@@ -780,6 +780,7 @@ impl AccuracyResults {
|
||||
/// Calibrate the circuit parameters to a given a dataset
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(trivial_casts)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn calibrate(
|
||||
model_path: PathBuf,
|
||||
data: PathBuf,
|
||||
@@ -788,7 +789,7 @@ pub(crate) fn calibrate(
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
use std::collections::HashMap;
|
||||
@@ -829,8 +830,8 @@ pub(crate) fn calibrate(
|
||||
(10..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
|
||||
vec![div_rebasing]
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
@@ -1169,16 +1170,6 @@ pub(crate) fn mock(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
for instance in proof.instances {
|
||||
println!("{:?}", instance);
|
||||
}
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
info!("0x{}", hex_str);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
|
||||
#[cfg(feature = "render")]
|
||||
pub(crate) fn render(
|
||||
model: PathBuf,
|
||||
@@ -1397,10 +1388,10 @@ pub(crate) async fn deploy_evm(
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn verify_evm(
|
||||
proof_path: PathBuf,
|
||||
addr_verifier: H160,
|
||||
addr_verifier: H160Flag,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<H160>,
|
||||
addr_vk: Option<H160>,
|
||||
addr_da: Option<H160Flag>,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::verify_proof_with_data_attestation;
|
||||
check_solc_requirement();
|
||||
@@ -1410,14 +1401,20 @@ pub(crate) async fn verify_evm(
|
||||
let result = if let Some(addr_da) = addr_da {
|
||||
verify_proof_with_data_attestation(
|
||||
proof.clone(),
|
||||
addr_verifier,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
addr_verifier.into(),
|
||||
addr_da.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
verify_proof_via_solidity(proof.clone(), addr_verifier, addr_vk, rpc_url.as_deref()).await?
|
||||
verify_proof_via_solidity(
|
||||
proof.clone(),
|
||||
addr_verifier.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!("Solidity verification result: {}", result);
|
||||
@@ -1573,14 +1570,14 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
use crate::pfsys::ProofType;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn test_update_account_calls(
|
||||
addr: H160,
|
||||
addr: H160Flag,
|
||||
data: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::update_account_calls;
|
||||
|
||||
check_solc_requirement();
|
||||
update_account_calls(addr, data, rpc_url.as_deref()).await?;
|
||||
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
@@ -1725,6 +1722,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1755,6 +1753,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1791,6 +1790,7 @@ pub(crate) fn fuzz(
|
||||
proof.clone(),
|
||||
bad_vk,
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1822,6 +1822,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1857,6 +1858,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -2042,15 +2044,29 @@ pub(crate) fn verify(
|
||||
settings_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
reduced_srs: bool,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
|
||||
|
||||
let params = if reduced_srs {
|
||||
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
|
||||
} else {
|
||||
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
|
||||
};
|
||||
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
|
||||
let vk =
|
||||
load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?;
|
||||
let now = Instant::now();
|
||||
let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
&vk,
|
||||
strategy,
|
||||
1 << circuit_settings.run_args.logrows,
|
||||
);
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
"verify took {}.{}",
|
||||
@@ -2074,7 +2090,7 @@ pub(crate) fn verify_aggr(
|
||||
let strategy = AccumulatorStrategy::new(params.verifier_params());
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(vk_path, ())?;
|
||||
let now = Instant::now();
|
||||
let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy, 1 << logrows);
|
||||
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::circuit::InputType;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::tensor::Tensor;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use postgres::{Client, NoTls};
|
||||
@@ -15,6 +16,8 @@ use pyo3::types::PyDict;
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -490,16 +493,20 @@ impl GraphData {
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to open input at {}", path.display()))?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
let reader = std::fs::File::open(path)?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
|
||||
let mut buf = String::new();
|
||||
reader.read_to_string(&mut buf)?;
|
||||
let graph_input = serde_json::from_str(&buf)?;
|
||||
Ok(graph_input)
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
161
src/graph/mod.rs
161
src/graph/mod.rs
@@ -16,6 +16,7 @@ use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
pub use input::DataSource;
|
||||
use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use self::input::OnChainSource;
|
||||
@@ -28,7 +29,8 @@ use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::RunArgs;
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
|
||||
@@ -37,7 +39,7 @@ use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
|
||||
use halo2curves::ff::PrimeField;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use log::{debug, error, trace, warn};
|
||||
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
@@ -48,7 +50,6 @@ use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{Read, Write};
|
||||
use std::ops::Deref;
|
||||
use thiserror::Error;
|
||||
pub use utilities::*;
|
||||
@@ -207,42 +208,41 @@ impl GraphWitness {
|
||||
output_scales: Vec<crate::Scale>,
|
||||
visibility: VarVisibility,
|
||||
) {
|
||||
let mut pretty_elements = PrettyElements::default();
|
||||
pretty_elements.rescaled_inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = input_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
pretty_elements.inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect();
|
||||
|
||||
pretty_elements.rescaled_outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = output_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
pretty_elements.outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect();
|
||||
let mut pretty_elements = PrettyElements {
|
||||
rescaled_inputs: self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = input_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
inputs: self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect(),
|
||||
rescaled_outputs: self
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = output_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
outputs: self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(processed_inputs) = self.processed_inputs.clone() {
|
||||
pretty_elements.processed_inputs = processed_inputs
|
||||
@@ -308,16 +308,20 @@ impl GraphWitness {
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut file = std::fs::File::open(path.clone())
|
||||
let file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load model at {}", path.display()))?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
|
||||
// use buf writer
|
||||
let writer =
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
|
||||
serde_json::to_writer(writer, &self).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
///
|
||||
@@ -472,22 +476,33 @@ impl GraphSettings {
|
||||
instances
|
||||
}
|
||||
|
||||
/// calculate the log2 of the total number of instances
|
||||
pub fn log2_total_instances(&self) -> u32 {
|
||||
let sum = self.total_instances().iter().sum::<usize>();
|
||||
|
||||
// max between 1 and the log2 of the sums
|
||||
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
|
||||
}
|
||||
|
||||
/// save params to file
|
||||
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
|
||||
let encoded = serde_json::to_string(&self)?;
|
||||
let mut file = std::fs::File::create(path)?;
|
||||
file.write_all(encoded.as_bytes())
|
||||
// buf writer
|
||||
let writer =
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, &self).map_err(|e| {
|
||||
error!("failed to save settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
}
|
||||
/// load params from file
|
||||
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
|
||||
let mut file = std::fs::File::open(path).map_err(|e| {
|
||||
error!("failed to open settings file at {}", e);
|
||||
e
|
||||
})?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
let res = serde_json::from_str(&data)?;
|
||||
Ok(res)
|
||||
// buf reader
|
||||
let reader =
|
||||
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
|
||||
serde_json::from_reader(reader).map_err(|e| {
|
||||
error!("failed to load settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
@@ -583,7 +598,7 @@ impl GraphCircuit {
|
||||
///
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let f = std::fs::File::create(path)?;
|
||||
let writer = std::io::BufWriter::new(f);
|
||||
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
bincode::serialize_into(writer, &self)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -591,11 +606,10 @@ impl GraphCircuit {
|
||||
///
|
||||
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
// read bytes from file
|
||||
let mut f = std::fs::File::open(&path)?;
|
||||
let metadata = std::fs::metadata(&path)?;
|
||||
let mut buffer = vec![0; metadata.len() as usize];
|
||||
f.read_exact(&mut buffer)?;
|
||||
let result = bincode::deserialize(&buffer)?;
|
||||
let f = std::fs::File::open(path)?;
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let result: GraphCircuit = bincode::deserialize_from(reader)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -610,6 +624,17 @@ pub enum TestDataSource {
|
||||
OnChain,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestDataSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TestDataSource::File => write!(f, "file"),
|
||||
TestDataSource::OnChain => write!(f, "on-chain"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for TestDataSource {}
|
||||
|
||||
impl From<String> for TestDataSource {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
@@ -826,7 +851,7 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
info!("input scales: {:?}", scales);
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
match &data.input_data {
|
||||
DataSource::File(file_data) => {
|
||||
@@ -845,7 +870,7 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
info!("input scales: {:?}", scales);
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
.await
|
||||
@@ -1045,7 +1070,7 @@ impl GraphCircuit {
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
min_logrows
|
||||
);
|
||||
error!("{}", err_string);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
@@ -1065,7 +1090,7 @@ impl GraphCircuit {
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
max_logrows
|
||||
);
|
||||
error!("{}", err_string);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
@@ -1131,7 +1156,7 @@ impl GraphCircuit {
|
||||
|
||||
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
|
||||
|
||||
info!(
|
||||
debug!(
|
||||
"setting lookup_range to: {:?}, setting logrows to: {}",
|
||||
self.settings().run_args.lookup_range,
|
||||
self.settings().run_args.logrows
|
||||
@@ -1223,7 +1248,7 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
let mut model_results = self.model().forward(inputs)?;
|
||||
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
@@ -1512,7 +1537,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"circuit size: \n {}",
|
||||
circuit_size
|
||||
.as_json()
|
||||
|
||||
@@ -56,6 +56,8 @@ use unzip_n::unzip_n;
|
||||
|
||||
unzip_n!(pub 3);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
|
||||
/// The result of a forward pass.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ForwardResult {
|
||||
@@ -493,7 +495,7 @@ impl Model {
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
let instance_shapes = self.instance_shapes()?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {}",
|
||||
"model has".blue(),
|
||||
instance_shapes.len().to_string().blue(),
|
||||
@@ -554,12 +556,16 @@ impl Model {
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
|
||||
/// * `run_args` - [RunArgs]
|
||||
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
|
||||
let res = self.dummy_layout(&run_args, &valtensor_inputs)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -572,7 +578,7 @@ impl Model {
|
||||
fn load_onnx_using_tract(
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
|
||||
) -> Result<TractResult, Box<dyn Error>> {
|
||||
use tract_onnx::{
|
||||
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
|
||||
};
|
||||
@@ -611,7 +617,7 @@ impl Model {
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
let symbol = model.symbol_table.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
info!("set {} to {}", symbol, value);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
}
|
||||
|
||||
// Note: do not optimize the model, as the layout will depend on underlying hardware
|
||||
@@ -1132,7 +1138,7 @@ impl Model {
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
num_rows.to_string().blue(),
|
||||
@@ -1341,7 +1347,7 @@ impl Model {
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
info!("calculating num of constraints using dummy model layout...");
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
@@ -1369,27 +1375,26 @@ impl Model {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let comparator = outputs
|
||||
let output_scales = self.graph.get_output_scales()?;
|
||||
let res = outputs
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let mut v: ValTensor<Fp> =
|
||||
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
|
||||
v.reshape(x.dims())?;
|
||||
Ok(v)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let mut comparator: ValTensor<Fp> =
|
||||
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
let _ = outputs
|
||||
.iter()
|
||||
.zip(comparator)
|
||||
.map(|(o, c)| {
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[o.clone(), c],
|
||||
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
|
||||
&[output.clone(), comparator],
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
res?;
|
||||
} else if !self.visibility.output.is_private() {
|
||||
for output in &outputs {
|
||||
region.increment_total_constants(output.num_constants());
|
||||
@@ -1401,7 +1406,7 @@ impl Model {
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
region.row().to_string().blue(),
|
||||
|
||||
@@ -497,6 +497,7 @@ impl Node {
|
||||
/// * `public_params` - flag if parameters of model are public
|
||||
/// * `idx` - The node's unique identifier.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
other_nodes: &mut BTreeMap<usize, super::NodeType>,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::error::Error;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::tensor::TensorType;
|
||||
use crate::tensor::{ValTensor, VarTensor};
|
||||
@@ -14,6 +15,7 @@ use pyo3::{
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -40,6 +42,33 @@ pub enum Visibility {
|
||||
Fixed,
|
||||
}
|
||||
|
||||
impl Display for Visibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Visibility::KZGCommit => write!(f, "kzgcommit"),
|
||||
Visibility::Private => write!(f, "private"),
|
||||
Visibility::Public => write!(f, "public"),
|
||||
Visibility::Fixed => write!(f, "fixed"),
|
||||
Visibility::Hashed {
|
||||
hash_is_public,
|
||||
outlets,
|
||||
} => {
|
||||
if *hash_is_public {
|
||||
write!(f, "hashed/public")
|
||||
} else {
|
||||
write!(f, "hashed/private/{}", outlets.iter().join(","))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for Visibility {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Visibility {
|
||||
fn from(s: &'a str) -> Self {
|
||||
if s.contains("hashed/private") {
|
||||
@@ -202,17 +231,6 @@ impl Visibility {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
impl std::fmt::Display for Visibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Visibility::KZGCommit => write!(f, "kzgcommit"),
|
||||
Visibility::Private => write!(f, "private"),
|
||||
Visibility::Public => write!(f, "public"),
|
||||
Visibility::Fixed => write!(f, "fixed"),
|
||||
Visibility::Hashed { .. } => write!(f, "hashed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the scale of the model input, model parameters.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
@@ -410,7 +428,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
|
||||
num_constants: usize,
|
||||
module_requires_fixed: bool,
|
||||
) -> Self {
|
||||
info!("number of blinding factors: {}", cs.blinding_factors());
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
|
||||
|
||||
63
src/lib.rs
63
src/lib.rs
@@ -32,6 +32,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use graph::Visibility;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
|
||||
pub mod circuit;
|
||||
@@ -70,11 +71,34 @@ pub mod tensor;
|
||||
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
|
||||
pub mod wasm;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
/// The denominator in the fixed point representation used when quantizing inputs
|
||||
pub type Scale = i32;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
// Buf writer capacity
|
||||
lazy_static! {
|
||||
/// The capacity of the buffer used for writing to disk
|
||||
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_KEY_FORMAT: &str = "raw-bytes";
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_BUF_CAPACITY: &usize = &8000;
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
|
||||
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
|
||||
pub struct RunArgs {
|
||||
/// The tolerance for error on model outputs
|
||||
#[arg(short = 'T', long, default_value = "0")]
|
||||
@@ -89,7 +113,7 @@ pub struct RunArgs {
|
||||
#[arg(long, default_value = "1")]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17")]
|
||||
@@ -98,7 +122,7 @@ pub struct RunArgs {
|
||||
#[arg(short = 'N', long, default_value = "2")]
|
||||
pub num_inner_cols: usize,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Flags whether inputs are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
@@ -180,34 +204,15 @@ fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr,
|
||||
T: std::str::FromStr + std::fmt::Debug,
|
||||
T::Err: std::error::Error + Send + Sync + 'static,
|
||||
U: std::str::FromStr,
|
||||
U: std::str::FromStr + std::fmt::Debug,
|
||||
U::Err: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let pos = s
|
||||
.find('=')
|
||||
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
|
||||
}
|
||||
|
||||
/// Parse a tuple
|
||||
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr + Clone,
|
||||
T::Err: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');
|
||||
|
||||
let res = res
|
||||
.map(|x| {
|
||||
// remove blank space
|
||||
let x = x.trim();
|
||||
x.parse::<T>()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if res.len() != 2 {
|
||||
return Err("invalid tuple".into());
|
||||
}
|
||||
Ok((res[0].clone(), res[1].clone()))
|
||||
.find("->")
|
||||
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
|
||||
let a = s[..pos].parse()?;
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::circuit::CheckMode;
|
||||
use crate::graph::GraphWitness;
|
||||
use crate::pfsys::evm::aggregation::PoseidonTranscript;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
@@ -39,9 +40,19 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
|
||||
use std::ops::Deref;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error as thisError;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
match s {
|
||||
"processed" => halo2_proofs::SerdeFormat::Processed,
|
||||
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
|
||||
_ => panic!("invalid serde format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(
|
||||
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
|
||||
@@ -52,6 +63,25 @@ pub enum ProofType {
|
||||
ForAggr,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProofType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
ProofType::Single => "single",
|
||||
ProofType::ForAggr => "for-aggr",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for ProofType {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ProofType> for TranscriptType {
|
||||
fn from(val: ProofType) -> Self {
|
||||
match val {
|
||||
@@ -154,6 +184,21 @@ pub enum TranscriptType {
|
||||
EVM,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
TranscriptType::Poseidon => "poseidon",
|
||||
TranscriptType::EVM => "evm",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for TranscriptType {}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for TranscriptType {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -315,7 +360,7 @@ where
|
||||
/// Saves the Proof to a specified `proof_path`.
|
||||
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
|
||||
let file = std::fs::File::create(proof_path)?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(&mut writer, &self)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -328,8 +373,10 @@ where
|
||||
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
|
||||
{
|
||||
trace!("reading proof");
|
||||
let data = std::fs::read_to_string(proof_path)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
let file = std::fs::File::open(proof_path)?;
|
||||
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let proof: Self = serde_json::from_reader(reader)?;
|
||||
Ok(proof)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,6 +602,7 @@ where
|
||||
verifier_params,
|
||||
pk.get_vk(),
|
||||
strategy,
|
||||
verifier_params.n(),
|
||||
)?;
|
||||
}
|
||||
let elapsed = now.elapsed();
|
||||
@@ -642,6 +690,7 @@ pub fn verify_proof_circuit<
|
||||
params: &'params Scheme::ParamsVerifier,
|
||||
vk: &VerifyingKey<Scheme::Curve>,
|
||||
strategy: Strategy,
|
||||
orig_n: u64,
|
||||
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
|
||||
where
|
||||
Scheme::Scalar: SerdeObject
|
||||
@@ -662,7 +711,7 @@ where
|
||||
trace!("instances {:?}", instances);
|
||||
|
||||
let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
|
||||
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
|
||||
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
|
||||
}
|
||||
|
||||
/// Loads a [VerifyingKey] at `path`.
|
||||
@@ -678,13 +727,14 @@ where
|
||||
info!("loading verification key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)
|
||||
)?;
|
||||
info!("done loading verification key ✅");
|
||||
Ok(vk)
|
||||
}
|
||||
|
||||
/// Loads a [ProvingKey] at `path`.
|
||||
@@ -700,19 +750,20 @@ where
|
||||
info!("loading proving key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)
|
||||
)?;
|
||||
info!("done loading proving key ✅");
|
||||
Ok(pk)
|
||||
}
|
||||
|
||||
/// Saves a [ProvingKey] to `path`.
|
||||
pub fn save_pk<Scheme: CommitmentScheme>(
|
||||
path: &PathBuf,
|
||||
vk: &ProvingKey<Scheme::Curve>,
|
||||
pk: &ProvingKey<Scheme::Curve>,
|
||||
) -> Result<(), io::Error>
|
||||
where
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
@@ -720,9 +771,10 @@ where
|
||||
{
|
||||
info!("saving proving key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
writer.flush()?;
|
||||
info!("done saving proving key ✅");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -737,9 +789,10 @@ where
|
||||
{
|
||||
info!("saving verification key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
writer.flush()?;
|
||||
info!("done saving verification key ✅");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -750,7 +803,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
|
||||
) -> Result<(), io::Error> {
|
||||
info!("saving parameters 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
params.write(&mut writer)?;
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
@@ -840,6 +893,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
vk: &VerifyingKey<G1Affine>,
|
||||
strategy: Strategy,
|
||||
orig_n: u64,
|
||||
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
@@ -849,7 +903,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
_,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, params, vk, strategy),
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => verify_proof_circuit::<
|
||||
Fr,
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
@@ -857,7 +911,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
_,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, params, vk, strategy),
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,9 @@ use crate::graph::{
|
||||
use crate::pfsys::evm::aggregation::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
|
||||
Snark, TranscriptType,
|
||||
TranscriptType,
|
||||
};
|
||||
use crate::RunArgs;
|
||||
use ethers::types::H160;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
@@ -26,7 +25,6 @@ use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
@@ -525,7 +523,7 @@ fn gen_settings(
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
div_rebasing = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
data: PathBuf,
|
||||
@@ -536,7 +534,7 @@ fn calibrate_settings(
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
div_rebasing: Option<bool>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
@@ -546,7 +544,7 @@ fn calibrate_settings(
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
@@ -689,14 +687,23 @@ fn prove(
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
srs_path=None,
|
||||
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
|
||||
))]
|
||||
fn verify(
|
||||
proof_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
non_reduced_srs: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
|
||||
crate::execute::verify(
|
||||
proof_path,
|
||||
settings_path,
|
||||
vk_path,
|
||||
srs_path,
|
||||
non_reduced_srs,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
@@ -1022,24 +1029,15 @@ fn verify_evm(
|
||||
addr_da: Option<&str>,
|
||||
addr_vk: Option<&str>,
|
||||
) -> Result<bool, PyErr> {
|
||||
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
let addr_verifier = H160Flag::from(addr_verifier);
|
||||
let addr_da = if let Some(addr_da) = addr_da {
|
||||
let addr_da = H160::from_str(addr_da).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
let addr_da = H160Flag::from(addr_da);
|
||||
Some(addr_da)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let addr_vk = if let Some(addr_vk) = addr_vk {
|
||||
let addr_vk = H160::from_str(addr_vk).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
let addr_vk = H160Flag::from(addr_vk);
|
||||
Some(addr_vk)
|
||||
} else {
|
||||
None
|
||||
@@ -1097,16 +1095,6 @@ fn create_evm_verifier_aggr(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// print hex representation of a proof
|
||||
#[pyfunction(signature = (proof_path))]
|
||||
fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
|
||||
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
|
||||
// Python Module
|
||||
#[pymodule]
|
||||
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
@@ -1145,7 +1133,6 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
|
||||
|
||||
@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
Ok(indices)
|
||||
}
|
||||
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
|
||||
ValTensor::Instance { .. } => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
Ok(indices)
|
||||
}
|
||||
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
|
||||
ValTensor::Instance { .. } => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
if indices.is_empty() {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
||||
@@ -33,10 +33,7 @@ pub enum VarTensor {
|
||||
impl VarTensor {
|
||||
///
|
||||
pub fn is_advice(&self) -> bool {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
19
src/wasm.rs
19
src/wasm.rs
@@ -311,13 +311,19 @@ pub fn verify(
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
circuit_settings,
|
||||
circuit_settings.clone(),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
|
||||
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
snark,
|
||||
&vk,
|
||||
strategy,
|
||||
1 << circuit_settings.run_args.logrows,
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
@@ -387,15 +393,6 @@ pub fn prove(
|
||||
.into_bytes())
|
||||
}
|
||||
|
||||
/// print hex representation of a proof
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
// VALIDATION FUNCTIONS
|
||||
|
||||
/// Witness file validation
|
||||
|
||||
@@ -477,6 +477,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
use rand::Rng;
|
||||
use tempdir::TempDir;
|
||||
|
||||
#[test]
|
||||
@@ -496,7 +497,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -569,7 +570,18 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_tolerance_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -580,7 +592,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -589,7 +601,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -598,7 +610,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -607,7 +619,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -616,7 +628,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -625,7 +637,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -634,7 +646,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -644,7 +656,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -654,7 +666,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -663,7 +675,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -673,7 +685,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -682,7 +694,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -692,7 +704,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -702,7 +714,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None);
|
||||
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -712,7 +724,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -722,7 +734,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -732,7 +744,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -876,7 +888,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1273,6 +1285,7 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1285,6 +1298,7 @@ mod native_tests {
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
tolerance,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1312,6 +1326,7 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: f32,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1321,11 +1336,12 @@ mod native_tests {
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
format!("--variables=batch_size={}", batch_size),
|
||||
format!("--variables=batch_size->{}", batch_size),
|
||||
format!("--input-visibility={}", input_visibility),
|
||||
format!("--param-visibility={}", param_visibility),
|
||||
format!("--output-visibility={}", output_visibility),
|
||||
format!("--num-inner-cols={}", num_inner_columns),
|
||||
format!("--tolerance={}", tolerance),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
@@ -1402,6 +1418,7 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// Mock prove (fast, but does not cover some potential issues)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn accuracy_measurement(
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
@@ -1424,6 +1441,7 @@ mod native_tests {
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
0.0,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -1455,7 +1473,7 @@ mod native_tests {
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
"-O",
|
||||
format!("{}/{}/render.png", test_dir, example_name).as_str(),
|
||||
"--lookup-range=(-32768,32768)",
|
||||
"--lookup-range=-32768->32768",
|
||||
"-K=17",
|
||||
])
|
||||
.status()
|
||||
@@ -1683,6 +1701,7 @@ mod native_tests {
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -1745,6 +1764,30 @@ mod native_tests {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// load settings file
|
||||
let settings =
|
||||
std::fs::read_to_string(settings_path.clone()).expect("failed to read settings file");
|
||||
|
||||
let graph_settings = serde_json::from_str::<GraphSettings>(&settings)
|
||||
.expect("failed to parse settings file");
|
||||
|
||||
// get_srs for the graph_settings_num_instances
|
||||
download_srs(graph_settings.log2_total_instances());
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"verify",
|
||||
format!("--settings-path={}", settings_path).as_str(),
|
||||
"--proof-path",
|
||||
&format!("{}/{}/proof.pf", test_dir, example_name),
|
||||
"--vk-path",
|
||||
&format!("{}/{}/key.vk", test_dir, example_name),
|
||||
"--reduced-srs",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
@@ -1760,6 +1803,7 @@ mod native_tests {
|
||||
None,
|
||||
2,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -2036,6 +2080,7 @@ mod native_tests {
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -392,9 +392,7 @@ def test_prove_evm():
|
||||
assert res['transcript_type'] == 'EVM'
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
res = ezkl.print_proof_hex(proof_path)
|
||||
# to figure out a better way of testing print_proof_hex
|
||||
assert type(res) == str
|
||||
|
||||
|
||||
|
||||
def test_create_evm_verifier():
|
||||
|
||||
@@ -9,8 +9,8 @@ mod wasm32 {
|
||||
use ezkl::pfsys;
|
||||
use ezkl::wasm::{
|
||||
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
|
||||
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
|
||||
prove, settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
|
||||
genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove,
|
||||
settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
|
||||
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
|
||||
};
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
@@ -258,15 +258,6 @@ mod wasm32 {
|
||||
assert!(value);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn print_proof_hex_test() {
|
||||
let proof = printProofHex(wasm_bindgen::Clamped(PROOF.to_vec()))
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
|
||||
assert!(proof.len() > 0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn verify_validations() {
|
||||
// Run witness validation on network (should fail)
|
||||
|
||||
Reference in New Issue
Block a user