Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 16:27:59 -05:00)

Compare commits

5 Commits

| SHA1 |
|---|
| 1e60ba4585 |
| d2b683b527 |
| a06b09ef1f |
| e5aa48fbd6 |
| 64fbc8a1c9 |

77 .github/workflows/rust.yml (vendored)
@@ -65,40 +65,40 @@ jobs:
- name: Library tests (original lookup)
run: cargo nextest run --lib --verbose --no-default-features --features ezkl

ultra-overflow-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# ultra-overflow-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@v2
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
# - name: lookup overflow
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Matmul overflow
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv overflow
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv + relu overflow
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored

ultra-overflow-tests_og-lookup:
runs-on: non-gpu
@@ -184,7 +184,6 @@ jobs:

wasm32-tests:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -244,7 +243,7 @@ jobs:
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
@@ -334,8 +333,8 @@ jobs:
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -687,3 +686,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Reusable verifier tutorial
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
20 Cargo.lock (generated)
@@ -2341,7 +2341,7 @@ dependencies = [
[[package]]
name = "halo2_solidity_verifier"
version = "0.1.0"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#3082fda94151fc6760a3cb2be4741ddbeef04c03"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#eede1db7f3e599112bd1186e9d1913286bdcb539"
dependencies = [
"askama",
"blake2b_simd",
@@ -2950,11 +2950,11 @@ dependencies = [

[[package]]
name = "lazy_static"
version = "1.4.0"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
dependencies = [
"spin 0.5.2",
"spin",
]

[[package]]
@@ -3532,9 +3532,9 @@ dependencies = [

[[package]]
name = "parking_lot"
version = "0.12.1"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
dependencies = [
"lock_api",
"parking_lot_core",
@@ -4426,7 +4426,7 @@ dependencies = [
"cfg-if",
"getrandom",
"libc",
"spin 0.9.8",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
@@ -4980,12 +4980,6 @@ dependencies = [
"unicode-xid",
]

[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"

[[package]]
name = "spin"
version = "0.9.8"
32 Cargo.toml
@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde"
"derive_serde",
] }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
rand = { version = "0.8", default_features = false }
@@ -33,7 +33,7 @@ log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
@@ -47,8 +47,15 @@ semver = "1.0.22"

# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
"provider-http",
"signers",
"contract",
"rpc-types-eth",
"signer-wallet",
"node-bindings",
] }
foundry-compilers = { version = "0.4.1", features = ["svm-solc"] }
ethabi = "18"
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
@@ -73,8 +80,8 @@ pyo3 = { version = "0.21.2", features = [
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
"attributes",
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }

@@ -162,6 +169,10 @@ harness = false
name = "relu"
harness = false

[[bench]]
name = "relu_lookupless"
harness = false

[[bench]]
name = "accum_matmul_relu"
harness = false
@@ -179,7 +190,13 @@ required-features = ["ezkl"]

[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"]
default = [
"ezkl",
"mv-lookup",
"precompute-coset",
"no-banner",
"parallel-poly-read",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -220,4 +237,3 @@ rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|
||||
// sets up a new relu table
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&b,
|
||||
&output,
|
||||
&a,
|
||||
BITS,
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
MyConfig { base_config }
|
||||
@@ -75,14 +83,18 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|
||||
// sets up a new relu table
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&b,
|
||||
&output,
|
||||
&a,
|
||||
BITS,
|
||||
k,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
MyConfig { base_config }
|
||||
@@ -76,14 +84,18 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
|
||||
.unwrap();
|
||||
|
||||
@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
|
||||
.unwrap();
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
@@ -63,9 +63,13 @@ impl Circuit<Fr> for NLCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
143 benches/relu_lookupless.rs (Normal file)
@@ -0,0 +1,143 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;

static mut LEN: usize = 4;
const K: usize = 16;

#[derive(Clone)]
struct NLCircuit {
pub input: ValTensor<Fr>,
}

impl Circuit<Fr> for NLCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();

fn without_witnesses(&self) -> Self {
self.clone()
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unsafe {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

let mut config = Config::default();

config
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
.unwrap();

config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
.unwrap();

let _constant = VarTensor::constant_cols(cs, K, LEN, false);

config
}
}

fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_range_checks(&mut layouter).unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.unwrap();
Ok(())
},
)?;
Ok(())
}
}

fn runrelu(c: &mut Criterion) {
let mut group = c.benchmark_group("relu");

let mut rng = rand::thread_rng();
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 8].iter() {
unsafe {
LEN = len;
};

let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();

let circuit = NLCircuit {
input: ValTensor::from(input.clone()),
};

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});

let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}

criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runrelu
}
criterion_main!(benches);
@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==14.1.0
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

@@ -1,7 +1,7 @@
import ezkl

project = 'ezkl'
release = '0.0.0'
release = '14.1.0'
version = release

@@ -163,7 +163,7 @@ where
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

@@ -199,7 +199,7 @@ where
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);

let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
@@ -221,7 +221,11 @@ where

let x = config
.layer_config
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();

let mut x = config

@@ -69,7 +69,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

@@ -108,7 +108,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let x = config
.layer_config
.layout(
@@ -141,7 +141,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let mut x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap()
.unwrap();
println!("3");
@@ -177,7 +181,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
println!("6");
println!("offset: {}", region.row());
771 examples/notebooks/ezkl_demo_batch.ipynb (Normal file)
File diff suppressed because one or more lines are too long
@@ -232,7 +232,7 @@
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
@@ -404,7 +404,7 @@
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"
"run_args.logrows = 15\n"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},
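For context on the `logrows` bump above: the circuit is laid out over 2**logrows rows, and the KZG structured reference string fetched here must support at least that many rows, so going from 8 to 15 buys row capacity at the cost of a larger SRS and slower setup. A minimal sketch of the calls involved, mirroring the notebook's own API use (depending on the ezkl version, `get_srs` may need to be awaited):

```python
import ezkl

run_args = ezkl.PyRunArgs()
run_args.input_scale = 2
run_args.logrows = 15  # 2**15 = 32768 rows available to the circuit

# the SRS must support at least 2**logrows rows
ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)
```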
339 examples/notebooks/reusable_verifier.ipynb (Normal file)
@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reusable Verifiers \n",
"\n",
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
"\n",
"- `1l_mlp sigmoid`\n",
"- `1l_mlp relu`\n",
"- `1l_conv sigmoid`\n",
"- `1l_conv relu`\n",
"\n",
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
"\n",
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we deploy a unique and much smaller verifying key artifact (VKA) contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well as circuit-specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier."
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.onnx\n",
|
||||
"\n",
|
||||
"# Define the models\n",
|
||||
"class MLP_Sigmoid(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MLP_Sigmoid, self).__init__()\n",
|
||||
" self.fc = nn.Linear(3, 3)\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc(x)\n",
|
||||
" x = self.sigmoid(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class MLP_Relu(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MLP_Relu, self).__init__()\n",
|
||||
" self.fc = nn.Linear(3, 3)\n",
|
||||
" self.relu = nn.ReLU()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class Conv_Sigmoid(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Conv_Sigmoid, self).__init__()\n",
|
||||
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.conv(x)\n",
|
||||
" x = self.sigmoid(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class Conv_Relu(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Conv_Relu, self).__init__()\n",
|
||||
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
|
||||
" self.relu = nn.ReLU()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.conv(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Instantiate the models\n",
|
||||
"mlp_sigmoid = MLP_Sigmoid()\n",
|
||||
"mlp_relu = MLP_Relu()\n",
|
||||
"conv_sigmoid = Conv_Sigmoid()\n",
|
||||
"conv_relu = Conv_Relu()\n",
|
||||
"\n",
|
||||
"# Dummy input tensor for mlp\n",
|
||||
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
|
||||
"input_mlp_path = 'mlp_input.json'\n",
|
||||
"\n",
|
||||
"# Dummy input tensor for conv\n",
|
||||
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
|
||||
"input_conv_path = 'conv_input.json'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
|
||||
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
|
||||
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
|
||||
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
|
||||
" # Create a new directory for the model if it doesn't exist\n",
|
||||
" if not os.path.exists(name):\n",
|
||||
" os.mkdir(name)\n",
|
||||
" # Store the paths in each of their respective directories\n",
|
||||
" model_path = os.path.join(name, \"network.onnx\")\n",
|
||||
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
|
||||
" pk_path = os.path.join(name, \"test.pk\")\n",
|
||||
" vk_path = os.path.join(name, \"test.vk\")\n",
|
||||
" settings_path = os.path.join(name, \"settings.json\")\n",
|
||||
"\n",
|
||||
" witness_path = os.path.join(name, \"witness.json\")\n",
|
||||
" sol_code_path = os.path.join(name, 'test.sol')\n",
|
||||
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
|
||||
" abi_path = os.path.join(name, 'test.abi')\n",
|
||||
" proof_path = os.path.join(name, \"proof.json\")\n",
|
||||
"\n",
|
||||
" # Flips the neural net into inference mode\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
|
||||
" do_constant_folding=True, input_names=['input'],\n",
|
||||
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
" data = dict(input_data=[data_array])\n",
|
||||
" json.dump(data, open(input_path, 'w'))\n",
|
||||
"\n",
|
||||
" py_run_args = ezkl.PyRunArgs()\n",
|
||||
" py_run_args.input_visibility = \"private\"\n",
|
||||
" py_run_args.output_visibility = \"public\"\n",
|
||||
" py_run_args.param_visibility = \"fixed\" # private by default\n",
|
||||
"\n",
|
||||
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" await ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
|
||||
"\n",
|
||||
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" res = await ezkl.get_srs(settings_path)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" # now generate the witness file\n",
|
||||
" res = await ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
|
||||
" assert os.path.isfile(witness_path) == True\n",
|
||||
"\n",
|
||||
" # SETUP \n",
|
||||
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
|
||||
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
|
||||
" assert res == True\n",
|
||||
" assert os.path.isfile(vk_path)\n",
|
||||
" assert os.path.isfile(pk_path)\n",
|
||||
" assert os.path.isfile(settings_path)\n",
|
||||
"\n",
|
||||
" # GENERATE A PROOF\n",
|
||||
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
|
||||
" assert os.path.isfile(proof_path)\n",
|
||||
"\n",
|
||||
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" res = await ezkl.create_evm_vka(vk_path, settings_path, sol_key_code_path, abi_path)\n",
|
||||
" assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"RPC_URL = \"http://localhost:3030\"\n",
|
||||
"\n",
|
||||
"# Save process globally\n",
|
||||
"anvil_process = None\n",
|
||||
"\n",
|
||||
"def start_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is None:\n",
|
||||
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
|
||||
" if anvil_process.returncode is not None:\n",
|
||||
" raise Exception(\"failed to start anvil process\")\n",
|
||||
" time.sleep(3)\n",
|
||||
"\n",
|
||||
"def stop_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is not None:\n",
|
||||
" anvil_process.terminate()\n",
|
||||
" anvil_process = None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Check that the generated verifiers are identical for all models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"start_anvil()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import filecmp\n",
|
||||
"\n",
|
||||
"def compare_files(file1, file2):\n",
|
||||
" return filecmp.cmp(file1, file2, shallow=False)\n",
|
||||
"\n",
|
||||
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
|
||||
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
|
||||
"\n",
|
||||
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
|
||||
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
|
||||
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we deploy separate verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os \n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" \"verifier/reusable\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"with open(addr_path_verifier, 'r') as file:\n",
|
||||
" addr = file.read().rstrip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in names:\n",
|
||||
" addr_path_vk = \"addr_vk.txt\"\n",
|
||||
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
|
||||
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" with open(addr_path_vk, 'r') as file:\n",
|
||||
" addr_vk = file.read().rstrip()\n",
|
||||
" \n",
|
||||
" proof_path = os.path.join(name, \"proof.json\")\n",
|
||||
" sol_code_path = os.path.join(name, 'vk.sol')\n",
|
||||
" res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" addr_vk = addr_vk\n",
|
||||
" )\n",
|
||||
" assert res == True"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
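Condensing the notebook above: one verifier contract is rendered in reusable mode and deployed once, each model then gets only a small verifying-key-artifact (VKA) contract, and verification takes both addresses. The sketch below strings together the same ezkl calls the notebook uses (top-level `await` as in a notebook cell; the `model_a/...` paths are placeholders, not files from the notebook):

```python
import ezkl

RPC = "http://127.0.0.1:3030"  # local anvil node, as in the notebook

# 1) render + deploy the shared verifier once (reusable=True makes it model-agnostic)
await ezkl.create_evm_verifier("model_a/test.vk", "model_a/settings.json",
                               "reusable_verifier.sol", "verifier.abi", reusable=True)
await ezkl.deploy_evm("addr_verifier.txt", "reusable_verifier.sol", RPC, "verifier/reusable")

# 2) render + deploy a per-model VKA contract (VK plus circuit metadata only)
await ezkl.create_evm_vka("model_a/test.vk", "model_a/settings.json",
                          "model_a/test_key.sol", "verifier.abi")
await ezkl.deploy_evm("addr_vk.txt", "model_a/test_key.sol", RPC, "vka")

# 3) verify a proof against the shared verifier, pointing it at this model's VKA
addr = open("addr_verifier.txt").read().rstrip()
addr_vk = open("addr_vk.txt").read().rstrip()
await ezkl.verify_evm(addr, "model_a/proof.json", RPC, addr_vk=addr_vk)
```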
@@ -9,7 +9,9 @@ class MyModel(nn.Module):
super(MyModel, self).__init__()

def forward(self, w, x, y, z):
return [((x & y)) == (x & (y | (z ^ w)))]
a = (x & y)
b = (y & (z ^ w))
return [a & b]


circuit = MyModel()

@@ -1 +1 @@
{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}
@@ -1,21 +1,17 @@
[Binary ONNX graph diff not shown here: the model was re-exported with pytorch 2.2.2 (previously pytorch 1.12.1), and the And/Xor node names in the graph changed to match the updated forward pass above.]
@@ -71,12 +71,17 @@ pub enum HybridOp {
},
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
_ => vec![],
}
}
@@ -135,10 +140,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
HybridOp::GreaterEqual => "GREATEREQUAL".into(),
HybridOp::Less => "LESS".into(),
HybridOp::LessEqual => "LESSEQUAL".into(),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
HybridOp::LessEqual => "LESSEQUAL".to_string(),
HybridOp::Equals => "EQUALS".into(),
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
HybridOp::TopK { k, dim, largest } => {
@@ -240,7 +240,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::layouts::dot;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -374,7 +374,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// // matmul case
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -827,7 +827,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1760,7 +1760,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1864,7 +1864,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2027,7 +2027,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2063,7 +2063,7 @@ pub fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2099,7 +2099,7 @@ pub fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2141,7 +2141,7 @@ pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
@@ -2177,7 +2177,7 @@ pub fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2223,7 +2223,7 @@ pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2415,7 +2415,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2472,7 +2472,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 12, 6, 4, 5, 6]),
|
||||
@@ -2500,12 +2500,9 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;

nonlinearity(
config,
region,
&[diff],
&LookupOp::GreaterThan { a: utils::F32(0.) },
)
let sign = sign(config, region, &[diff])?;

equals(config, region, &[sign, create_unit_tensor(1)])
}

/// Greater equals than operation.
@@ -2524,7 +2521,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
///
///
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
@@ -2544,21 +2541,17 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone());
let (lhs, rhs) = (values[0].clone(), values[1].clone());

let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;

lhs.expand(&broadcasted_shape)?;
rhs.expand(&broadcasted_shape)?;

let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;

nonlinearity(
// add 1 to lhs
let lhs_plus_one = pairwise(
config,
region,
&[diff],
&LookupOp::GreaterThanEqual { a: utils::F32(0.) },
)
&[lhs.clone(), create_unit_tensor(1)],
BaseOp::Add,
)?;

greater(config, region, &[lhs_plus_one, rhs])
}

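The rewritten comparison ops drop the `GreaterThan` / `GreaterThanEqual` lookups: `greater` now takes the sign of `lhs - rhs` and checks that it equals 1, and `greater_equal` simply calls `greater(lhs + 1, rhs)`, which works because the wires hold integer-valued quantities (for integers, a >= b exactly when a + 1 > b). A plain-integer sketch of that identity, outside the circuit:

```python
def sign(x: int) -> int:
    # mirrors the circuit's sign helper: -1, 0, or 1
    return (x > 0) - (x < 0)

def greater(a: int, b: int) -> bool:
    return sign(a - b) == 1

def greater_equal(a: int, b: int) -> bool:
    # integers only: a >= b  <=>  a + 1 > b
    return greater(a + 1, b)

assert all(greater_equal(a, b) == (a >= b)
           for a in range(-3, 4) for b in range(-3, 4))
```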
/// Less than to operation.
|
||||
@@ -2578,7 +2571,7 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2619,7 +2612,7 @@ pub fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2660,7 +2653,7 @@ pub fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2704,7 +2697,7 @@ pub fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2752,7 +2745,7 @@ pub fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2826,7 +2819,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2882,7 +2875,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2925,7 +2918,7 @@ pub fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let mask = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 1, 0, 1, 0]),
|
||||
@@ -2983,7 +2976,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
@@ -3015,7 +3008,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3097,7 +3090,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3198,7 +3191,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3426,7 +3419,7 @@ pub fn deconv<
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -4155,6 +4148,128 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        .map_err(|e| e.into())
}

pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
    base: &usize,
    n: &usize,
) -> Result<ValTensor<F>, CircuitError> {
    let input = values[0].clone();

    let is_assigned = !input.all_prev_assigned();

    let bases: ValTensor<F> = Tensor::from(
        (0..*n)
            .rev()
            .map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
    )
    .into();

    let cartesian_coord = input
        .dims()
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, input.dims())?;

    let inner_loop_function =
        |i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
            let coord = cartesian_coord[i].clone();
            let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let mut sliced_input = input.get_slice(&slice)?;
            sliced_input.flatten();

            if !is_assigned {
                sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
            }

            let mut claimed_output_slice = if region.witness_gen() {
                sliced_input.decompose(*base, *n)?
            } else {
                Tensor::from(vec![ValType::Value(Value::unknown()); *n + 1].into_iter()).into()
            };

            claimed_output_slice =
                region.assign(&config.custom_gates.inputs[1], &claimed_output_slice)?;
            claimed_output_slice.flatten();

            region.increment(claimed_output_slice.len());

            // get the sign bit and make sure it is valid
            let sign = claimed_output_slice.first()?;
            let sign = range_check(config, region, &[sign], &(-1, 1))?;

            // get the rest of the thing and make sure it is in the correct range
            let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?;

            let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?;

            let prod_decomp = dot(config, region, &[rest, bases.clone()])?;

            let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;

            enforce_equality(config, region, &[sliced_input, signed_decomp])?;

            Ok(claimed_output_slice.get_inner_tensor()?.clone())
        };

    region.apply_in_loop(&mut output, inner_loop_function)?;

    let mut combined_output = output.combine()?;
    let mut output_dims = input.dims().to_vec();
    output_dims.push(*n + 1);
    combined_output.reshape(&output_dims)?;

    Ok(combined_output.into())
}
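Outside the circuit, the relation this gadget pins down can be mirrored in plain Rust. The sketch below is illustrative only: plain `i128` stands in for ezkl's `IntegerRep`, and the helper names are hypothetical. Each value is claimed as a sign in {-1, 0, 1} plus `n` base-`base` limbs (most-significant first), and the recomposition `sign * dot(limbs, powers_of_base)` must equal the input, which is what the `dot`, `pairwise(Mult)`, and `enforce_equality` calls above constrain.

```rust
// Witness-side sketch of the relation `decompose` enforces in-circuit.
// Illustrative only: `i128` stands in for ezkl's `IntegerRep`.
fn decompose_scalar(x: i128, base: u64, n: u32) -> (i128, Vec<i128>) {
    let sign = x.signum(); // constrained to {-1, 0, 1} by the first range check
    let mut mag = x.unsigned_abs();
    let mut limbs = vec![0i128; n as usize];
    // most-significant limb first, matching the reversed powers in `bases`
    for i in (0..n as usize).rev() {
        limbs[i] = (mag % base as u128) as i128; // each limb in [0, base - 1]
        mag /= base as u128;
    }
    (sign, limbs)
}

fn recompose(sign: i128, limbs: &[i128], base: i128) -> i128 {
    // dot product with [base^(n-1), ..., base, 1], then re-apply the sign:
    // the `dot` + `pairwise(Mult)` + `enforce_equality` chain above checks
    // that this equals the original input
    sign * limbs.iter().fold(0i128, |acc, d| acc * base + *d)
}

// e.g. with base 128 and 2 legs: -300 -> (-1, [2, 44]) and -1 * (2 * 128 + 44) == -300
```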

pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
    let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?;
    // get every n elements now, which correspond to the sign bit

    decomp.get_every_n(region.legs() + 1)?;
    decomp.reshape(values[0].dims())?;

    Ok(decomp)
}
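`decompose` appends a trailing dimension of `legs + 1` slots per element, with the sign in slot 0, so `sign` only has to stride through the flattened output. A plain-vector sketch of that selection (hypothetical helper; it assumes `get_every_n(k)` keeps every k-th element starting from index 0):

```rust
// Plain-vector sketch of `get_every_n(legs + 1)`: keep the leading (sign)
// slot of each `legs + 1`-wide chunk produced by `decompose`.
fn extract_signs(decomposed: &[i128], legs: usize) -> Vec<i128> {
    decomposed
        .chunks(legs + 1)
        .map(|chunk| chunk[0]) // slot 0 is the sign; slots 1..=legs are the limbs
        .collect()
}
```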

pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
    let sign = sign(config, region, values)?;

    pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
}

pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
    let sign = sign(config, region, values)?;

    let mut unit = create_unit_tensor(sign.len());
    unit.reshape(sign.dims())?;

    let relu_mask = equals(config, region, &[sign, unit])?;

    pairwise(
        config,
        region,
        &[values[0].clone(), relu_mask],
        BaseOp::Mult,
    )
}
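ReLU is therefore derived from the sign decomposition rather than a lookup table: the output is the input times an indicator of `sign(x) == 1`. A scalar sketch of that logic (illustrative, not the in-circuit code):

```rust
// Scalar sketch of the sign-mask ReLU used above.
fn relu_scalar(x: i128) -> i128 {
    let sign = x.signum();          // sign(x) in {-1, 0, 1}
    let mask = (sign == 1) as i128; // `equals(sign, unit)` -> 0/1 mask
    x * mask                        // `pairwise(Mult)` with the mask
}

// relu_scalar(5) == 5, relu_scalar(-3) == 0, relu_scalar(0) == 0
```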

fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
@@ -4324,7 +4439,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
@@ -4372,7 +4487,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true());
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[100, 200, 300, 400, 500, 600]),
|
||||
|
||||
@@ -15,14 +15,12 @@ use halo2curves::ff::PrimeField;
|
||||
/// An enum representing the operations that can be used to express more complex operations via accumulation
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Abs,
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ReLU,
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
@@ -104,19 +102,6 @@ pub enum LookupOp {
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
GreaterThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
GreaterThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
Sign,
|
||||
KroneckerDelta,
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
@@ -138,7 +123,6 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Abs => "abs".into(),
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
@@ -147,18 +131,12 @@ impl LookupOp {
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Sign => "sign".into(),
|
||||
LookupOp::LessThan { a } => format!("less_than_{}", a),
|
||||
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
|
||||
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::ReLU => "relu".to_string(),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
@@ -188,91 +166,107 @@ impl LookupOp {
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
|
||||
LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())),
|
||||
LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())),
|
||||
LookupOp::RoundHalfToEven { scale } => Ok(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)),
|
||||
LookupOp::Max { scale, a } => Ok(tensor::ops::nonlinearities::max(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::Min { scale, a } => Ok(tensor::ops::nonlinearities::min(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)),
|
||||
LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::GreaterThanEqual { a } => Ok(
|
||||
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
|
||||
&x,
|
||||
f32::from(*denom).into(),
|
||||
)),
|
||||
LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
|
||||
&x,
|
||||
f32::from(*scale).into(),
|
||||
)),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
|
||||
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())),
|
||||
LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())),
|
||||
LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())),
|
||||
LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())),
|
||||
LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())),
|
||||
LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())),
|
||||
LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())),
|
||||
LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())),
|
||||
LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())),
|
||||
LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())),
|
||||
LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())),
|
||||
LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())),
|
||||
LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())),
|
||||
LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())),
|
||||
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
|
||||
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
|
||||
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ceil { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Floor { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Round { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
|
||||
}
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::KroneckerDelta => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
|
||||
}
|
||||
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Rsqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Erf { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Exp { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ACos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::acos(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cosh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cosh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ACosh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::acosh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sin { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sin(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ASin { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::asin(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sinh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sinh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ASinh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::asinh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Tan { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::tan(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ATan { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::atan(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ATanh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::atanh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Tanh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::tanh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| integer_rep_to_felt(x));
|
||||
|
||||
@@ -289,7 +283,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Abs => "ABS".into(),
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
@@ -298,11 +291,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Sign => "SIGN".into(),
|
||||
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
|
||||
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
|
||||
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
@@ -313,7 +301,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::ReLU => "RELU".to_string(),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
@@ -358,12 +345,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::Sign
|
||||
| LookupOp::GreaterThan { .. }
|
||||
| LookupOp::LessThan { .. }
|
||||
| LookupOp::GreaterThanEqual { .. }
|
||||
| LookupOp::LessThanEqual { .. }
|
||||
| LookupOp::KroneckerDelta => 0,
|
||||
LookupOp::KroneckerDelta => 0,
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -9,6 +9,9 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
    ReLU,
    Abs,
    Sign,
    GatherElements {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
@@ -99,8 +102,7 @@ impl<
        + PartialOrd
        + std::hash::Hash
        + Serialize
        + for<'de> Deserialize<'de>
        ,
        + for<'de> Deserialize<'de>,
    > Op<F> for PolyOp
{
    /// Returns a reference to the Any trait.
@@ -110,6 +112,9 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::Abs => "ABS".to_string(),
|
||||
PolyOp::Sign => "SIGN".to_string(),
|
||||
PolyOp::ReLU => "RELU".to_string(),
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
@@ -191,6 +196,9 @@ impl<
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
@@ -368,6 +376,7 @@ impl<
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
|
||||
PolyOp::Sign { .. } => 0,
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -93,6 +93,10 @@ pub struct RegionSettings {
    pub witness_gen: bool,
    /// whether we should check range checks for validity
    pub check_range: bool,
    /// base for decompositions
    pub base: usize,
    /// number of legs for decompositions
    pub legs: usize,
}

#[allow(unsafe_code)]
@@ -102,26 +106,32 @@ unsafe impl Send for RegionSettings {}

impl RegionSettings {
    /// Create a new region settings
    pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
    pub fn new(witness_gen: bool, check_range: bool, base: usize, legs: usize) -> RegionSettings {
        RegionSettings {
            witness_gen,
            check_range,
            base,
            legs,
        }
    }

    /// Create a new region settings with all true
    pub fn all_true() -> RegionSettings {
    pub fn all_true(base: usize, legs: usize) -> RegionSettings {
        RegionSettings {
            witness_gen: true,
            check_range: true,
            base,
            legs,
        }
    }

    /// Create a new region settings with all false
    pub fn all_false() -> RegionSettings {
    pub fn all_false(base: usize, legs: usize) -> RegionSettings {
        RegionSettings {
            witness_gen: false,
            check_range: false,
            base,
            legs,
        }
    }
}
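Callers now have to thread the decomposition base and number of legs through `RegionSettings`. A minimal sketch, using the `(128, 2)` values that the updated tests use; the import path is an assumption for illustration:

```rust
// Sketch only: the module path is assumed, not confirmed by this diff.
use ezkl::circuit::region::RegionSettings;

fn settings_example() {
    // witness generation and range checks on, base-128 decomposition with 2 legs
    let live = RegionSettings::all_true(128, 2);
    assert!(live.witness_gen && live.check_range);
    assert_eq!((live.base, live.legs), (128, 2));

    // dummy/dry-run settings keep the same decomposition shape
    let dry = RegionSettings::all_false(128, 2);
    assert!(!dry.witness_gen && !dry.check_range);
}
```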
@@ -173,6 +183,16 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
|
||||
/// get the region's decomposition base
|
||||
pub fn base(&self) -> usize {
|
||||
self.settings.base
|
||||
}
|
||||
|
||||
/// get the region's decomposition legs
|
||||
pub fn legs(&self) -> usize {
|
||||
self.settings.legs
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
///
|
||||
pub fn debug_report(&self) {
|
||||
@@ -234,7 +254,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
|
||||
pub fn new(
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
decomp_base: usize,
|
||||
decomp_legs: usize,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = Some(RefCell::new(region));
|
||||
let linear_coord = row * num_inner_cols;
|
||||
|
||||
@@ -246,7 +272,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(),
|
||||
settings: RegionSettings::all_true(decomp_base, decomp_legs),
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -256,9 +282,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
decomp_base: usize,
|
||||
decomp_legs: usize,
|
||||
constants: ConstantsMap<F>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let mut new_self = Self::new(region, row, num_inner_cols);
|
||||
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
|
||||
new_self.assigned_constants = constants;
|
||||
new_self
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::Fr as F;
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
use ops::lookup::LookupOp;
|
||||
use ops::region::RegionCtx;
|
||||
use rand::rngs::OsRng;
|
||||
@@ -55,7 +56,7 @@ mod matmul {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -132,7 +133,7 @@ mod matmul_col_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -206,7 +207,7 @@ mod matmul_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -290,7 +291,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -407,7 +408,7 @@ mod matmul_col_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -518,7 +519,7 @@ mod dot {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -595,7 +596,7 @@ mod dot_col_overflow_triple_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 3);
|
||||
let mut region = RegionCtx::new(region, 0, 3, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -668,7 +669,7 @@ mod dot_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -741,7 +742,7 @@ mod sum {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -811,7 +812,7 @@ mod sum_col_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -880,7 +881,7 @@ mod sum_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -951,7 +952,7 @@ mod composition {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
let _ = config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1042,7 +1043,7 @@ mod conv {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1193,7 +1194,7 @@ mod conv_col_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1297,7 +1298,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const K: usize = 8;
|
||||
const LEN: usize = 15;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1317,15 +1318,23 @@ mod conv_relu_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let mut base_config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// sets up a new relu table
|
||||
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU)
|
||||
.configure_range_check(cs, &a, &b, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
base_config
|
||||
.configure_range_check(cs, &a, &b, (0, 1), K)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
base_config.clone()
|
||||
}
|
||||
|
||||
@@ -1334,12 +1343,12 @@ mod conv_relu_col_ultra_overflow {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
config.layout_range_checks(&mut layouter).unwrap();
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
let output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1355,7 +1364,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(LookupOp::ReLU),
|
||||
Box::new(PolyOp::ReLU),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -1476,7 +1485,7 @@ mod add_w_shape_casting {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1543,7 +1552,7 @@ mod add {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1627,7 +1636,7 @@ mod dynamic_lookup {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
for i in 0..NUM_LOOP {
|
||||
layouts::dynamic_lookup(
|
||||
&config,
|
||||
@@ -1769,7 +1778,7 @@ mod shuffle {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
for i in 0..NUM_LOOP {
|
||||
layouts::shuffles(
|
||||
&config,
|
||||
@@ -1884,7 +1893,7 @@ mod add_with_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1986,7 +1995,7 @@ mod add_with_overflow_and_poseidon {
|
||||
layouter.assign_region(
|
||||
|| "model",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.base
|
||||
.layout(&mut region, &inputs, Box::new(PolyOp::Add))
|
||||
@@ -2092,7 +2101,7 @@ mod sub {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2159,7 +2168,7 @@ mod mult {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2226,7 +2235,7 @@ mod pow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2258,7 +2267,6 @@ mod matmul_relu {
|
||||
|
||||
const K: usize = 18;
|
||||
const LEN: usize = 32;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -2288,11 +2296,17 @@ mod matmul_relu {
|
||||
|
||||
let mut base_config =
|
||||
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// sets up a new relu table
|
||||
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU)
|
||||
.configure_range_check(cs, &a, &b, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
base_config
|
||||
.configure_range_check(cs, &a, &b, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
MyConfig { base_config }
|
||||
}
|
||||
|
||||
@@ -2301,11 +2315,14 @@ mod matmul_relu {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.base_config.layout_tables(&mut layouter).unwrap();
|
||||
config
|
||||
.base_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
@@ -2315,7 +2332,7 @@ mod matmul_relu {
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -2354,6 +2371,8 @@ mod relu {
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
|
||||
const K: u32 = 8;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
@@ -2370,16 +2389,26 @@ mod relu {
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.map(|_| VarTensor::new_advice(cs, 8, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
let mut config = BaseConfig::configure(
|
||||
cs,
|
||||
&[advices[0].clone(), advices[1].clone()],
|
||||
&advices[2],
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
config
|
||||
.configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl)
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K as usize)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (0, 1), K as usize)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K as usize, 8, false);
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
@@ -2388,15 +2417,15 @@ mod relu {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
config.layout_range_checks(&mut layouter).unwrap();
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
Ok(config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.unwrap())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@@ -2414,7 +2443,7 @@ mod relu {
|
||||
input: ValTensor::from(input),
|
||||
};
|
||||
|
||||
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
|
||||
let prover = MockProver::run(K, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
@@ -2453,7 +2482,7 @@ mod lookup_ultra_overflow {
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
|
||||
@@ -2481,9 +2510,13 @@ mod lookup_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -141,23 +141,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn f32_eq() {
|
||||
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
|
||||
assert!(F32(std::f32::NAN) != F32(5.0));
|
||||
assert!(F32(5.0) != F32(std::f32::NAN));
|
||||
assert!(F32(f32::NAN) == F32(f32::NAN));
|
||||
assert!(F32(f32::NAN) != F32(5.0));
|
||||
assert!(F32(5.0) != F32(f32::NAN));
|
||||
assert!(F32(0.0) == F32(-0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f32_cmp() {
|
||||
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
|
||||
assert!(F32(std::f32::NAN) < F32(5.0));
|
||||
assert!(F32(5.0) > F32(std::f32::NAN));
|
||||
assert!(F32(f32::NAN) == F32(f32::NAN));
|
||||
assert!(F32(f32::NAN) < F32(5.0));
|
||||
assert!(F32(5.0) > F32(f32::NAN));
|
||||
assert!(F32(0.0) == F32(-0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f32_hash() {
|
||||
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
|
||||
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
|
||||
assert!(calculate_hash(&F32(f32::NAN)) == calculate_hash(&F32(-f32::NAN)));
|
||||
}
|
||||
}
|
||||
|
||||
148
src/commands.rs
@@ -81,8 +81,10 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
|
||||
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
|
||||
/// Default Compress selectors
|
||||
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
|
||||
/// Default render vk separately
|
||||
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
|
||||
/// Default render reusable verifier
|
||||
pub const DEFAULT_RENDER_REUSABLE: &str = "false";
|
||||
/// Default contract deployment type
|
||||
pub const DEFAULT_CONTRACT_DEPLOYMENT_TYPE: &str = "verifier";
|
||||
/// Default VK sol path
|
||||
pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
@@ -181,6 +183,67 @@ impl From<&str> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// Determines what type of contract (verifier, verifier/reusable, vka) should be deployed
pub enum ContractType {
    /// Deploys a verifier contract tailored to the circuit and not reusable
    Verifier {
        /// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit, which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits).
        /// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
        reusable: bool,
    },
    /// Deploys a verifying key artifact that the reusable verifier loads into memory at runtime. Encodes the circuit-specific data that was otherwise hardcoded onto the stack.
    VerifyingKeyArtifact,
}

impl Default for ContractType {
    fn default() -> Self {
        ContractType::Verifier {
            reusable: false,
        }
    }
}

impl std::fmt::Display for ContractType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                ContractType::Verifier { reusable: true } => {
                    "verifier/reusable".to_string()
                },
                ContractType::Verifier {
                    reusable: false,
                } => "verifier".to_string(),
                ContractType::VerifyingKeyArtifact => "vka".to_string(),
            }
        )
    }
}

impl ToFlags for ContractType {
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl From<&str> for ContractType {
    fn from(s: &str) -> Self {
        match s {
            "verifier" => ContractType::Verifier { reusable: false },
            "verifier/reusable" => ContractType::Verifier { reusable: true },
            "vka" => ContractType::VerifyingKeyArtifact,
            _ => {
                log::error!("Invalid value for ContractType");
                log::warn!("Defaulting to verifier");
                ContractType::default()
            }
        }
    }
}
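For reference, a minimal sketch of how the new contract-type flag round-trips between its string form and the enum; the `ezkl::commands` import path is an assumption for illustration:

```rust
// Illustrative only: the module path is assumed, not confirmed by this diff.
use ezkl::commands::ContractType;

fn main() {
    // `From<&str>` maps the CLI flag value onto the enum...
    let deployed = ContractType::from("verifier/reusable");
    assert_eq!(deployed, ContractType::Verifier { reusable: true });

    // ...and `Display` maps it back to the flag string used by `ToFlags`.
    assert_eq!(format!("{}", deployed), "verifier/reusable");

    // Unknown values fall back to the default, plain verifier.
    assert_eq!(ContractType::from("bogus"), ContractType::default());
}
```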
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
@@ -243,6 +306,39 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts ContractType into a PyObject (Required for ContractType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for ContractType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
"verifier/reusable".to_object(py)
|
||||
}
|
||||
ContractType::Verifier {
|
||||
reusable: false,
|
||||
} => "verifier".to_object(py),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for ContractType {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
match strval.to_lowercase().as_str() {
|
||||
"verifier" => Ok(ContractType::Verifier {
|
||||
reusable: false,
|
||||
}),
|
||||
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
|
||||
"vka" => Ok(ContractType::VerifyingKeyArtifact),
|
||||
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
|
||||
}
|
||||
}
|
||||
}
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -666,16 +762,14 @@ pub enum Commands {
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
|
||||
abi_path: Option<PathBuf>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
/// Whether to render the verifier as reusable. If true, you will need to deploy a VK artifact and pass it as part of the calldata to the verifier.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
|
||||
reusable: Option<bool>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-vk")]
|
||||
CreateEvmVK {
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -739,11 +833,9 @@ pub enum Commands {
|
||||
// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
|
||||
logrows: Option<u32>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
/// Whether to render the verifier as reusable. If true, you will need to deploy a VK artifact and pass it as part of the calldata to the verifier.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
|
||||
reusable: Option<bool>,
|
||||
},
|
||||
/// Verifies a proof, returning accept or reject
|
||||
Verify {
|
||||
@@ -785,8 +877,8 @@ pub enum Commands {
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVerifier {
|
||||
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
|
||||
DeployEvm {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
@@ -802,25 +894,9 @@ pub enum Commands {
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVK {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
|
||||
/// The path to output the contract address
|
||||
addr_path: Option<PathBuf>,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
|
||||
optimizer_runs: usize,
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
/// Contract type to be deployed
|
||||
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
|
||||
contract: ContractType,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
|
||||
@@ -195,7 +195,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
} => {
|
||||
create_evm_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
@@ -203,7 +203,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
reusable.unwrap_or(DEFAULT_RENDER_REUSABLE.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -219,14 +219,14 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
|
||||
Commands::CreateEvmVK {
|
||||
Commands::CreateEvmVKArtifact {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => {
|
||||
create_evm_vk(
|
||||
create_evm_vka(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
@@ -260,7 +260,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
} => {
|
||||
create_evm_aggregate_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
@@ -269,7 +269,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
|
||||
aggregation_settings,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
reusable.unwrap_or(DEFAULT_RENDER_REUSABLE.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -434,12 +434,13 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVerifier {
|
||||
Commands::DeployEvm {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
contract,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
@@ -447,25 +448,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVK {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
|
||||
rpc_url,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_VK.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
contract,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -786,7 +769,8 @@ pub(crate) async fn gen_witness(
|
||||
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
|
||||
let region_settings = RegionSettings::all_true();
|
||||
let region_settings =
|
||||
RegionSettings::all_true(settings.run_args.decomp_base, settings.run_args.decomp_legs);
|
||||
|
||||
let start_time = Instant::now();
|
||||
let witness = if settings.module_requires_polycommit() {
|
||||
@@ -1178,7 +1162,10 @@ pub(crate) async fn calibrate(
|
||||
&mut data.clone(),
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(),
|
||||
RegionSettings::all_true(
|
||||
settings.run_args.decomp_base,
|
||||
settings.run_args.decomp_legs,
|
||||
),
|
||||
)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
@@ -1372,8 +1359,10 @@ pub(crate) async fn calibrate(
|
||||
let module_log_row = best_params.module_constraint_logrows_with_blinding();
|
||||
let instance_logrows = best_params.log2_total_instances_with_blinding();
|
||||
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
let range_check_logrows = best_params.range_check_log_rows_with_blinding();
|
||||
|
||||
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
|
||||
reduction = std::cmp::max(reduction, range_check_logrows);
|
||||
reduction = std::cmp::max(reduction, instance_logrows);
|
||||
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
|
||||
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);
|
||||
@@ -1426,7 +1415,7 @@ pub(crate) async fn create_evm_verifier(
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> Result<String, EZKLError> {
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
@@ -1448,16 +1437,16 @@ pub(crate) async fn create_evm_verifier(
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
let (verifier_solidity, name) = if reusable {
|
||||
(generator.render_separately()?.0, "Halo2VerifierReusable") // ignore the rendered vk artifact for now and generate it in create_evm_vka
|
||||
} else {
|
||||
generator.render()?
|
||||
(generator.render()?, "Halo2Verifier")
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, name, 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
@@ -1465,7 +1454,7 @@ pub(crate) async fn create_evm_verifier(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_vk(
|
||||
pub(crate) async fn create_evm_vka(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
@@ -1498,7 +1487,7 @@ pub(crate) async fn create_evm_vk(
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0).await?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
@@ -1616,8 +1605,13 @@ pub(crate) async fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
contract: ContractType,
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_name = match contract {
|
||||
ContractType::Verifier { reusable: false } => "Halo2Verifier",
|
||||
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
|
||||
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
|
||||
};
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
@@ -1708,7 +1702,7 @@ pub(crate) async fn create_evm_aggregate_verifier(
|
||||
abi_path: PathBuf,
|
||||
circuit_settings: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> Result<String, EZKLError> {
|
||||
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
|
||||
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
|
||||
@@ -1746,8 +1740,8 @@ pub(crate) async fn create_evm_aggregate_verifier(
|
||||
|
||||
generator = generator.set_acc_encoding(Some(acc_encoding));
|
||||
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
let verifier_solidity = if reusable {
|
||||
generator.render_separately()?.0 // ignore the rendered vk artifact for now and generate it in create_evm_vka
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
|
||||
@@ -451,6 +451,18 @@ impl GraphSettings {
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// Calc the number of rows required for the range checks
|
||||
pub fn range_check_log_rows_with_blinding(&self) -> u32 {
|
||||
let max_range = self
|
||||
.required_range_checks
|
||||
.iter()
|
||||
.map(|x| x.1 - x.0)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
(max_range as f32).log2().ceil() as u32
|
||||
}
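For intuition, a worked example mirroring the function above (not part of the diff; the example ranges are the `required_range_checks` that appear in the settings file near the end of this comparison):

# Minimal sketch of range_check_log_rows_with_blinding, assuming the ranges [-1, 1] and [0, 127].
import math

required_range_checks = [(-1, 1), (0, 127)]
max_range = max(hi - lo for (lo, hi) in required_range_checks)   # 127
log_rows = math.ceil(math.log2(max_range))                        # 7, i.e. 2**7 = 128 rows
print(log_rows)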
|
||||
|
||||
fn model_constraint_logrows_with_blinding(&self) -> u32 {
|
||||
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
|
||||
@@ -547,7 +547,11 @@ impl Model {
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;
|
||||
let res = self.dummy_layout(
|
||||
run_args,
|
||||
&inputs,
|
||||
RegionSettings::all_false(run_args.decomp_base, run_args.decomp_legs),
|
||||
)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -883,16 +887,8 @@ impl Model {
|
||||
);
|
||||
}
|
||||
None => {
|
||||
let mut n = Node::new(
|
||||
n.clone(),
|
||||
&mut nodes,
|
||||
scales,
|
||||
&run_args.param_visibility,
|
||||
i,
|
||||
symbol_values,
|
||||
run_args.div_rebasing,
|
||||
run_args.rebase_frac_zero_constants,
|
||||
)?;
|
||||
let mut n =
|
||||
Node::new(n.clone(), &mut nodes, scales, i, symbol_values, run_args)?;
|
||||
if let Some(ref scales) = override_input_scales {
|
||||
if let Some(inp) = n.opkind.get_input() {
|
||||
let scale = scales[input_idx];
|
||||
@@ -1112,6 +1108,8 @@ impl Model {
|
||||
region,
|
||||
0,
|
||||
run_args.num_inner_cols,
|
||||
run_args.decomp_base,
|
||||
run_args.decomp_legs,
|
||||
original_constants.clone(),
|
||||
);
|
||||
// we need to do this as this loop is called multiple times
|
||||
|
||||
@@ -125,6 +125,7 @@ impl RebaseScale {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
&& !inner.is_input()
|
||||
&& !inner.is_identity()
|
||||
{
|
||||
let multiplier =
|
||||
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
|
||||
@@ -326,6 +327,19 @@ impl SupportedOp {
|
||||
SupportedOp::RebaseScale(op) => op,
|
||||
}
|
||||
}
|
||||
|
||||
/// check if is the identity operation
|
||||
/// # Returns
|
||||
/// * `true` if the operation is the identity operation
|
||||
/// * `false` otherwise
|
||||
pub fn is_identity(&self) -> bool {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }),
|
||||
SupportedOp::Rescaled(op) => op.inner.is_identity(),
|
||||
SupportedOp::RebaseScale(op) => op.inner.is_identity(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
@@ -473,11 +487,9 @@ impl Node {
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
other_nodes: &mut BTreeMap<usize, super::NodeType>,
|
||||
scales: &VarScales,
|
||||
param_visibility: &Visibility,
|
||||
idx: usize,
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
rebase_frac_zero_constants: bool,
|
||||
run_args: &crate::RunArgs,
|
||||
) -> Result<Self, GraphError> {
|
||||
trace!("Create {:?}", node);
|
||||
trace!("Create op {:?}", node.op);
|
||||
@@ -517,11 +529,10 @@ impl Node {
|
||||
let (mut opkind, deleted_indices) = new_op_from_onnx(
|
||||
idx,
|
||||
scales,
|
||||
param_visibility,
|
||||
node.clone(),
|
||||
&mut inputs,
|
||||
symbol_values,
|
||||
rebase_frac_zero_constants,
|
||||
run_args,
|
||||
)?; // parses the op name
|
||||
|
||||
// we can only take the inputs as mutable once -- so we need to collect them first
|
||||
@@ -569,7 +580,7 @@ impl Node {
|
||||
rescale_const_with_single_use(
|
||||
constant,
|
||||
in_scales.clone(),
|
||||
param_visibility,
|
||||
&run_args.param_visibility,
|
||||
input_node.num_uses(),
|
||||
)?;
|
||||
input_node.replace_opkind(constant.clone_dyn().into());
|
||||
@@ -589,7 +600,7 @@ impl Node {
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
div_rebasing,
|
||||
run_args.div_rebasing,
|
||||
);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
@@ -41,7 +41,7 @@ use tract_onnx::tract_hir::{
|
||||
ops::konst::Const,
|
||||
ops::nn::DataFormat,
|
||||
tract_core::ops::cast::Cast,
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, PaddingSpec, SumPool},
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
|
||||
@@ -94,17 +94,18 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
/// extract padding from a onnx node.
|
||||
pub fn extract_padding(
|
||||
pool_spec: &PoolSpec,
|
||||
num_dims: usize,
|
||||
image_size: &[usize],
|
||||
) -> Result<Vec<(usize, usize)>, GraphError> {
|
||||
let padding = match &pool_spec.padding {
|
||||
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
}
|
||||
PaddingSpec::Valid => vec![(0, 0); num_dims],
|
||||
_ => {
|
||||
return Err(GraphError::MissingParams("padding".to_string()));
|
||||
}
|
||||
};
|
||||
let num_relevant_dims = pool_spec.kernel_shape.len();
|
||||
|
||||
// get the last num_relevant_dims of the image size
|
||||
let image_size = &image_size[image_size.len() - num_relevant_dims..];
|
||||
|
||||
let dims = pool_spec.computed_padding(image_size);
|
||||
let mut padding = Vec::new();
|
||||
for dim in dims {
|
||||
padding.push((dim.pad_before, dim.pad_after));
|
||||
}
|
||||
Ok(padding)
|
||||
}
|
||||
|
||||
@@ -273,11 +274,10 @@ fn load_op<C: tract_onnx::prelude::Op + Clone>(
|
||||
pub fn new_op_from_onnx(
|
||||
idx: usize,
|
||||
scales: &VarScales,
|
||||
param_visibility: &Visibility,
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
inputs: &mut [super::NodeType],
|
||||
symbol_values: &SymbolValues,
|
||||
rebase_frac_zero_constants: bool,
|
||||
run_args: &crate::RunArgs,
|
||||
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
|
||||
use tract_onnx::tract_core::ops::array::Trilu;
|
||||
|
||||
@@ -664,13 +664,16 @@ pub fn new_op_from_onnx(
|
||||
|
||||
// if all raw_values are round then set scale to 0
|
||||
let all_round = raw_value.iter().all(|x| (x).fract() == 0.0);
|
||||
if all_round && rebase_frac_zero_constants {
|
||||
if all_round && run_args.rebase_frac_zero_constants {
|
||||
constant_scale = 0;
|
||||
}
|
||||
|
||||
// Quantize the raw value
|
||||
let quantized_value =
|
||||
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
|
||||
let quantized_value = quantize_tensor(
|
||||
raw_value.clone(),
|
||||
constant_scale,
|
||||
&run_args.param_visibility,
|
||||
)?;
|
||||
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
|
||||
// Create a constant op
|
||||
SupportedOp::Constant(c)
|
||||
@@ -782,7 +785,7 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
if unit == 0. {
|
||||
SupportedOp::Nonlinear(LookupOp::ReLU)
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
@@ -871,7 +874,7 @@ pub fn new_op_from_onnx(
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => {
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
|
||||
"Abs" => SupportedOp::Linear(PolyOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
@@ -1014,8 +1017,13 @@ pub fn new_op_from_onnx(
|
||||
if raw_values.log2().fract() == 0.0 {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
// get the non constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
op = SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
|
||||
out_scale: Some(
|
||||
input_scales[non_const_idx] + raw_values.log2() as i32,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
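The rebase above relies on a fixed-point identity: multiplying a value quantized at scale s by a constant 2^k is the same as reinterpreting it at scale s + k, which is why the op collapses to an Identity with out_scale = input_scales[non_const_idx] + log2(constant). A small worked illustration (an assumption-free arithmetic check, not part of the diff):

# Scale bookkeeping for multiplication by a power of two.
scale = 7                                   # value stored as round(x * 2**scale)
x = 3.25
quantized = round(x * 2**scale)             # 416
constant = 4                                # power of two, log2 = 2
product = quantized * constant              # 1664
assert product == round(x * 2**(scale + 2)) # same value at scale 7 + log2(4) = 9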
|
||||
@@ -1106,7 +1114,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
|
||||
let stride = extract_strides(pool_spec)?;
|
||||
let padding = extract_padding(pool_spec, input_dims[0].len())?;
|
||||
let padding = extract_padding(pool_spec, &input_dims[0])?;
|
||||
let kernel_shape = &pool_spec.kernel_shape;
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::MaxPool {
|
||||
@@ -1127,7 +1135,7 @@ pub fn new_op_from_onnx(
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Sign" => SupportedOp::Nonlinear(LookupOp::Sign),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
// Extract the slope layer hyperparams from a const
|
||||
|
||||
@@ -1176,7 +1184,7 @@ pub fn new_op_from_onnx(
|
||||
let pool_spec = &conv_node.pool_spec;
|
||||
|
||||
let stride = extract_strides(pool_spec)?;
|
||||
let padding = extract_padding(pool_spec, input_dims[0].len())?;
|
||||
let padding = extract_padding(pool_spec, &input_dims[0])?;
|
||||
|
||||
// if bias exists then rescale it to the input + kernel scale
|
||||
if input_scales.len() == 3 {
|
||||
@@ -1234,7 +1242,7 @@ pub fn new_op_from_onnx(
|
||||
let pool_spec = &deconv_node.pool_spec;
|
||||
|
||||
let stride = extract_strides(pool_spec)?;
|
||||
let padding = extract_padding(pool_spec, input_dims[0].len())?;
|
||||
let padding = extract_padding(pool_spec, &input_dims[0])?;
|
||||
// if bias exists then rescale it to the input + kernel scale
|
||||
if input_scales.len() == 3 {
|
||||
let bias_scale = input_scales[2];
|
||||
@@ -1347,7 +1355,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
|
||||
let stride = extract_strides(pool_spec)?;
|
||||
let padding = extract_padding(pool_spec, input_dims[0].len())?;
|
||||
let padding = extract_padding(pool_spec, &input_dims[0])?;
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::SumPool {
|
||||
padding,
|
||||
@@ -1356,11 +1364,6 @@ pub fn new_op_from_onnx(
|
||||
normalized: sumpool_node.normalize,
|
||||
})
|
||||
}
|
||||
// "GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool {
|
||||
// padding: [(0, 0); 2],
|
||||
// stride: (1, 1),
|
||||
// kernel_shape: (inputs[0].out_dims()[0][1], inputs[0].out_dims()[0][2]),
|
||||
// }),
|
||||
"Pad" => {
|
||||
let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
|
||||
Some(b) => b,
|
||||
|
||||
@@ -150,6 +150,7 @@ lazy_static! {
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
@@ -277,6 +278,12 @@ pub struct RunArgs {
|
||||
/// commitment scheme
|
||||
#[arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other)]
|
||||
pub commitment: Option<Commitments>,
|
||||
/// the base used for decompositions
|
||||
#[arg(long, default_value = "16384", value_hint = clap::ValueHint::Other)]
|
||||
pub decomp_base: usize,
|
||||
#[arg(long, default_value = "2", value_hint = clap::ValueHint::Other)]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
}
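As a rough consequence of these defaults (a sketch based on the get_rep check later in this diff, not a new constraint): the decomposition rejects any integer whose magnitude exceeds decomp_base ** decomp_legs, so the default base 16384 with 2 legs covers magnitudes up to 16384**2.

# Largest magnitude representable by the default decomposition settings;
# get_rep returns DecompositionError::TooLarge above this bound.
decomp_base = 16384
decomp_legs = 2
max_representable = decomp_base ** decomp_legs   # 268_435_456
print(max_representable)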
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -297,6 +304,8 @@ impl Default for RunArgs {
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: None,
|
||||
decomp_base: 16384,
|
||||
decomp_legs: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,6 +191,12 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// str: commitment type, accepts `kzg`, `ipa`
|
||||
pub commitment: PyCommitments,
|
||||
/// int: The base used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_base: usize,
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -221,6 +227,8 @@ impl From<PyRunArgs> for RunArgs {
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
decomp_base: py_run_args.decomp_base,
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,6 +251,8 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
decomp_base: self.decomp_base,
|
||||
decomp_legs: self.decomp_legs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1490,8 +1500,8 @@ fn encode_evm_calldata<'a>(
|
||||
/// srs_path: str
|
||||
/// The path to the SRS file
|
||||
///
|
||||
/// render_vk_separately: bool
|
||||
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
|
||||
/// reusable: bool
|
||||
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -1503,7 +1513,7 @@ fn encode_evm_calldata<'a>(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier(
|
||||
py: Python,
|
||||
@@ -1512,7 +1522,7 @@ fn create_evm_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_verifier(
|
||||
@@ -1521,7 +1531,7 @@ fn create_evm_verifier(
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1533,7 +1543,8 @@ fn create_evm_verifier(
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an Evm verifier key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
|
||||
/// Creates an Evm VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable Halo2 verifier.
|
||||
/// This is useful for deploying verifiers that were otherwise too big to fit on chain and required aggregation.
|
||||
///
|
||||
/// Arguments
|
||||
/// ---------
|
||||
@@ -1563,7 +1574,7 @@ fn create_evm_verifier(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None
|
||||
))]
|
||||
fn create_evm_vk(
|
||||
fn create_evm_vka(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1572,7 +1583,7 @@ fn create_evm_vk(
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
@@ -1703,6 +1714,7 @@ fn setup_test_evm_witness(
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
rpc_url=None,
|
||||
contract_type=ContractType::default(),
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
@@ -1711,6 +1723,7 @@ fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
contract_type: ContractType,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
@@ -1721,42 +1734,7 @@ fn deploy_evm(
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// deploys the solidity vk verifier
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_vk_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
contract_type,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1892,8 +1870,8 @@ fn verify_evm<'a>(
|
||||
/// srs_path: str
|
||||
/// The path to the SRS file
|
||||
///
|
||||
/// render_vk_separately: bool
|
||||
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
|
||||
/// reusable: bool
|
||||
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -1906,7 +1884,7 @@ fn verify_evm<'a>(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier_aggr(
|
||||
py: Python,
|
||||
@@ -1916,7 +1894,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path: PathBuf,
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
@@ -1926,7 +1904,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1975,9 +1953,8 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ops::DecompositionError;
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorError {
|
||||
@@ -33,4 +35,7 @@ pub enum TensorError {
|
||||
/// File load error
|
||||
#[error("load error: {0}")]
|
||||
FileLoadError(String),
|
||||
/// Decomposition error
|
||||
#[error("decomposition error: {0}")]
|
||||
DecompositionError(#[from] DecompositionError),
|
||||
}
|
||||
|
||||
@@ -420,7 +420,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
|
||||
let mut buf_writer = std::io::BufWriter::new(writer);
|
||||
|
||||
self.inner.iter().map(|x| x.clone()).for_each(|x| {
|
||||
self.inner.iter().copied().for_each(|x| {
|
||||
let x = x.to_repr();
|
||||
buf_writer.write_all(x.as_ref()).unwrap();
|
||||
});
|
||||
@@ -764,6 +764,53 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
index
|
||||
}
|
||||
|
||||
/// Fetches every nth element
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5]), &[3]).unwrap();
|
||||
/// assert_eq!(a.get_every_n(2).unwrap(), expected);
|
||||
/// assert_eq!(a.get_every_n(1).unwrap(), a);
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 6]), &[2]).unwrap();
|
||||
/// assert_eq!(a.get_every_n(5).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
if i % n == 0 {
|
||||
inner.push(elem.clone());
|
||||
}
|
||||
}
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
}
|
||||
|
||||
/// Excludes every nth element
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 4, 6]), &[3]).unwrap();
|
||||
/// assert_eq!(a.exclude_every_n(2).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 3, 4, 5]), &[4]).unwrap();
|
||||
/// assert_eq!(a.exclude_every_n(5).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
if i % n != 0 {
|
||||
inner.push(elem.clone());
|
||||
}
|
||||
}
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
}
|
||||
|
||||
/// Duplicates every nth element
|
||||
///
|
||||
/// ```
|
||||
@@ -1217,6 +1264,31 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
Tensor::new(Some(&[res]), &[1])
|
||||
}
|
||||
|
||||
/// Get first elem from Tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1]), &[1]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.first().unwrap(), b);
|
||||
/// ```
|
||||
pub fn first(&self) -> Result<Tensor<T>, TensorError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
let res = match self.inner.first() {
|
||||
Some(e) => e.clone(),
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get first element of empty tensor".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Tensor::new(Some(&[res]), &[1])
|
||||
}
|
||||
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
|
||||
@@ -7,6 +7,131 @@ use itertools::Itertools;
|
||||
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
|
||||
pub use std::ops::{Add, Mul, Neg, Sub};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
|
||||
/// Decomposition error
|
||||
pub enum DecompositionError {
|
||||
/// Integer is too large to be represented by base and n
|
||||
#[error("integer {} is too large to be represented by base {} and n {}", .0, .1, .2)]
|
||||
TooLarge(IntegerRep, usize, usize),
|
||||
}
|
||||
|
||||
/// Helper function to get the base decomp of an integer
|
||||
/// # Arguments
|
||||
/// * `x` - IntegerRep
|
||||
/// * `n` - usize
|
||||
/// * `base` - usize
|
||||
///
|
||||
pub fn get_rep(
|
||||
x: &IntegerRep,
|
||||
base: usize,
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
// sign bit
|
||||
rep[0] = if *x < 0 {
|
||||
-1
|
||||
} else if *x > 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let mut x = x.abs();
|
||||
//
|
||||
for i in (1..rep.len()).rev() {
|
||||
rep[i] = x % base as i128;
|
||||
x /= base as i128;
|
||||
}
|
||||
|
||||
Ok(rep)
|
||||
}
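For readers skimming the diff, a minimal Python sketch of the same signed decomposition (an illustration of get_rep above, not part of the change):

def get_rep(x: int, base: int, n: int) -> list[int]:
    # Python mirror of the Rust helper above: a sign leg followed by n base-`base`
    # digits, most significant digit first.
    if abs(x) > base ** n:
        raise ValueError(f"{x} is too large for base {base} with {n} legs")
    rep = [0] * (n + 1)
    rep[0] = (x > 0) - (x < 0)   # sign leg: -1, 0 or 1
    mag = abs(x)
    for i in range(n, 0, -1):    # fill digits from least significant (index n) upwards
        rep[i] = mag % base
        mag //= base
    return rep

# These match the doctest examples for `decompose` below.
assert get_rep(23, 2, 5) == [1, 1, 0, 1, 1, 1]
assert get_rep(-1, 16, 2) == [-1, 0, 1]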
|
||||
|
||||
/// Decompose a tensor of integers into a larger tensor with added dimension of size `n + 1` with the binary (or OTHER base) representation of the integer.
|
||||
/// # Arguments
|
||||
/// * `x` - Tensor
|
||||
/// * `n` - usize
|
||||
/// * `base` - usize
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::decompose;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[0, 1, 2, -1]),
|
||||
/// &[2, 2]).unwrap();
|
||||
///
|
||||
/// let result = decompose(&x, 2, 2).unwrap();
|
||||
/// // result will have dims [2, 2, 3]
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0,
|
||||
/// 1, 0, 1,
|
||||
/// 1, 1, 0,
|
||||
/// -1, 0, 1]), &[2, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = decompose(&x, 3, 1).unwrap();
|
||||
///
|
||||
///
|
||||
/// // result will have dims [2, 2, 2]
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0,
|
||||
/// 1, 1,
|
||||
/// 1, 2,
|
||||
/// -1, 1]), &[2, 2, 2]).unwrap();
|
||||
///
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[0, 11, 23, -1]),
|
||||
/// &[2, 2]).unwrap();
|
||||
///
|
||||
/// let result = decompose(&x, 2, 5).unwrap();
|
||||
/// // result will have dims [2, 2, 6]
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0,
|
||||
/// 1, 0, 1, 0, 1, 1,
|
||||
/// 1, 1, 0, 1, 1, 1,
|
||||
/// -1, 0, 0, 0, 0, 1]), &[2, 2, 6]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = decompose(&x, 16, 2).unwrap();
|
||||
/// // result will have dims [2, 2, 3]
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0,
|
||||
/// 1, 0, 11,
|
||||
/// 1, 1, 7,
|
||||
/// -1, 0, 1]), &[2, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn decompose(
|
||||
x: &Tensor<IntegerRep>,
|
||||
base: usize,
|
||||
n: usize,
|
||||
) -> Result<Tensor<IntegerRep>, TensorError> {
|
||||
let mut dims = x.dims().to_vec();
|
||||
dims.push(n + 1);
|
||||
|
||||
if n == 0 {
|
||||
let mut x = x.clone();
|
||||
x.reshape(&dims)?;
|
||||
return Ok(x);
|
||||
}
|
||||
|
||||
let resp = x
|
||||
.par_iter()
|
||||
.map(|val| get_rep(val, base, n))
|
||||
// now collect the results into a Result<Vec<Vec<IntegerRep>>, DecompositionError>
|
||||
.collect::<Result<Vec<Vec<IntegerRep>>, DecompositionError>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<IntegerRep>>();
|
||||
|
||||
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Trilu operation.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
@@ -429,7 +554,7 @@ pub fn downsample<T: TensorType + Send + Sync>(
|
||||
|
||||
output = output.par_enum_map(|i, _: T| {
|
||||
let coord = indices[i].clone();
|
||||
Ok(input.get(&coord))
|
||||
Ok::<_, TensorError>(input.get(&coord))
|
||||
})?;
|
||||
|
||||
Ok(output)
|
||||
@@ -489,7 +614,7 @@ pub fn gather<T: TensorType + Send + Sync>(
|
||||
.map(|(i, x)| if i == dim { index_val } else { *x })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(input.get(&new_coord))
|
||||
Ok::<_, TensorError>(input.get(&new_coord))
|
||||
})?;
|
||||
|
||||
// Reshape the output tensor
|
||||
@@ -613,7 +738,7 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
|
||||
|
||||
let val = input.get(&new_coord);
|
||||
|
||||
Ok(val)
|
||||
Ok::<_, TensorError>(val)
|
||||
})?;
|
||||
|
||||
// Reshape the output tensor
|
||||
@@ -927,7 +1052,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
|
||||
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let src_val = src.get_slice(&slice)?;
|
||||
output.set_slice(&index_slice, &src_val)?;
|
||||
Ok(())
|
||||
Ok::<_, TensorError>(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
@@ -1234,7 +1359,7 @@ pub fn concat<T: TensorType + Send + Sync>(
|
||||
index += x;
|
||||
}
|
||||
|
||||
Ok(inputs[input_index].get(&input_coord))
|
||||
Ok::<_, TensorError>(inputs[input_index].get(&input_coord))
|
||||
})?;
|
||||
|
||||
// Reshape the output tensor
|
||||
|
||||
@@ -520,6 +520,54 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the sign of the inner values
|
||||
pub fn sign(&self) -> Result<Self, TensorError> {
|
||||
let evals = self.int_evals()?;
|
||||
Ok(evals
|
||||
.par_enum_map(|_, val| {
|
||||
Ok::<_, TensorError>(ValType::Value(Value::known(integer_rep_to_felt(
|
||||
val.signum(),
|
||||
))))
|
||||
})?
|
||||
.into())
|
||||
}
|
||||
|
||||
/// Decompose the inner values into base `base` and `n` legs.
|
||||
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
|
||||
let res = self
|
||||
.get_inner()?
|
||||
.par_iter()
|
||||
.map(|x| {
|
||||
let mut is_empty = true;
|
||||
x.map(|_| is_empty = false);
|
||||
if is_empty {
|
||||
return Ok::<_, TensorError>(vec![Value::<F>::unknown(); n + 1]);
|
||||
} else {
|
||||
let mut res = vec![Value::unknown(); n + 1];
|
||||
let mut int_rep = 0;
|
||||
|
||||
x.map(|f| {
|
||||
int_rep = crate::fieldutils::felt_to_integer_rep(f);
|
||||
});
|
||||
let decompe = crate::tensor::ops::get_rep(&int_rep, base, n)?;
|
||||
|
||||
for (i, x) in decompe.iter().enumerate() {
|
||||
res[i] = Value::known(crate::fieldutils::integer_rep_to_felt(*x));
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims.push(n + 1);
|
||||
|
||||
tensor.reshape(&dims)?;
|
||||
|
||||
Ok(tensor.into())
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
|
||||
// finally convert to vector of integers
|
||||
@@ -574,7 +622,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `get_slice` on the inner tensor.
|
||||
/// Calls `last` on the inner tensor.
|
||||
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
|
||||
let slice = match self {
|
||||
ValTensor::Value {
|
||||
@@ -595,6 +643,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
Ok(slice)
|
||||
}
|
||||
|
||||
/// Calls `first` on the inner tensor.
|
||||
pub fn first(&self) -> Result<ValTensor<F>, TensorError> {
|
||||
let slice = match self {
|
||||
ValTensor::Value {
|
||||
inner: v,
|
||||
dims: _,
|
||||
scale,
|
||||
} => {
|
||||
let inner = v.first()?;
|
||||
let dims = inner.dims().to_vec();
|
||||
ValTensor::Value {
|
||||
inner,
|
||||
dims,
|
||||
scale: *scale,
|
||||
}
|
||||
}
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
Ok(slice)
|
||||
}
|
||||
|
||||
/// Calls `get_slice` on the inner tensor.
|
||||
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
|
||||
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
|
||||
@@ -775,6 +844,38 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `get_every_n` on the inner [Tensor].
|
||||
pub fn get_every_n(&mut self, n: usize) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.get_every_n(n)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Calls `exclude_every_n` on the inner [Tensor].
|
||||
pub fn exclude_every_n(&mut self, n: usize) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.exclude_every_n(n)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// remove constant zero values constants
|
||||
pub fn remove_const_zero_values(&mut self) {
|
||||
match self {
|
||||
|
||||
10 src/wasm.rs
@@ -286,7 +286,15 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(
|
||||
circuit.settings().run_args.decomp_base,
|
||||
circuit.settings().run_args.decomp_legs,
|
||||
),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
|
||||
@@ -7,11 +7,15 @@ mod native_tests {
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::Bn256;
|
||||
use lazy_static::lazy_static;
|
||||
use rand::Rng;
|
||||
use std::env::var;
|
||||
use std::io::{Read, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command};
|
||||
use std::sync::Once;
|
||||
static COMPILE: Once = Once::new();
|
||||
@@ -241,8 +245,8 @@ mod native_tests {
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample",
|
||||
"1l_identity", //35
|
||||
"idolmodel",
|
||||
"trig",
|
||||
"idolmodel", // too big evm
|
||||
"trig", // too big evm
|
||||
"prelu_gmm",
|
||||
"lstm",
|
||||
"rnn", //40
|
||||
@@ -983,16 +987,26 @@ mod native_tests {
|
||||
mod tests_evm {
|
||||
use seq_macro::seq;
|
||||
use crate::native_tests::TESTS_EVM;
|
||||
use crate::native_tests::TESTS;
|
||||
use crate::native_tests::TESTS_EVM_AGGR;
|
||||
use test_case::test_case;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify_render_seperately;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify_reusable_verifier;
|
||||
|
||||
use crate::native_tests::kzg_evm_on_chain_input_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_aggr_prove_and_verify;
|
||||
use tempdir::TempDir;
|
||||
use crate::native_tests::Hardfork;
|
||||
use crate::native_tests::run_js_tests;
|
||||
use ezkl::logger::init_logger;
|
||||
use crate::native_tests::lazy_static;
|
||||
|
||||
// Global variables to store verifier hashes and identical verifiers
|
||||
lazy_static! {
|
||||
// stores the address of the reusable verifier once it has been deployed
|
||||
static ref REUSABLE_VERIFIER_ADDR: std::sync::Mutex<Option<String>> = std::sync::Mutex::new(None);
|
||||
}
|
||||
|
||||
|
||||
/// Currently only on chain inputs that return a non-negative value are supported.
|
||||
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
|
||||
@@ -1104,6 +1118,70 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_ for test: {}", test);
|
||||
// default vis
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), false);
|
||||
// public/public vis
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
// hashed input
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_with_overflow_(test: &str) {
|
||||
// verifier too big to fit on chain with overflow calibration target
|
||||
if test == "1l_eltwise_div" || test == "lenet_5" || test == "ltsf" || test == "lstm_large" {
|
||||
return;
|
||||
}
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_with_overflow_ for test: {}", test);
|
||||
// default vis
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), true);
|
||||
// public/public vis
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
// hashed input
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
seq!(N in 0..=22 {
|
||||
|
||||
@@ -1120,19 +1198,6 @@ mod native_tests {
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_prove_and_verify_render_seperately_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", true);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_hashed_input_prove_and_verify_(test: &str) {
|
||||
@@ -1883,7 +1948,7 @@ mod native_tests {
|
||||
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm-verifier",
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
"--sol-code-path",
|
||||
@@ -2114,11 +2179,7 @@ mod native_tests {
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier
|
||||
let mut args = vec![
|
||||
"deploy-evm-verifier",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
];
|
||||
let mut args = vec!["deploy-evm", rpc_arg.as_str(), addr_path_arg.as_str()];
|
||||
|
||||
args.push("--sol-code-path");
|
||||
args.push(sol_arg.as_str());
|
||||
@@ -2160,14 +2221,16 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
fn kzg_evm_prove_and_verify_render_seperately(
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
input_visibility: &str,
|
||||
param_visibility: &str,
|
||||
output_visibility: &str,
|
||||
) {
|
||||
reusable_verifier_address: &mut Option<String>,
|
||||
overflow: bool,
|
||||
) -> String {
|
||||
let anvil_url = ANVIL_URL.as_str();
|
||||
|
||||
prove_and_verify(
|
||||
@@ -2179,7 +2242,7 @@ mod native_tests {
|
||||
output_visibility,
|
||||
num_inner_columns,
|
||||
None,
|
||||
false,
|
||||
overflow,
|
||||
"single",
|
||||
Commitments::KZG,
|
||||
2,
|
||||
@@ -2194,27 +2257,58 @@ mod native_tests {
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
let sol_arg = format!("--sol-code-path={}/{}/kzg.sol", test_dir, example_name);
|
||||
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--render-vk-seperately",
|
||||
];
|
||||
// if the reusable verifier address is not set, create the verifier
|
||||
let deployed_addr_arg = match reusable_verifier_address {
|
||||
Some(addr) => addr.clone(),
|
||||
None => {
|
||||
// create the reusable verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr =
|
||||
std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
// set the reusable verifier address
|
||||
*reusable_verifier_address = Some(addr);
|
||||
deployed_addr_arg
|
||||
}
|
||||
};
|
||||
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
let sol_arg_vk: String = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-vk",
|
||||
"create-evm-vka",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
@@ -2227,32 +2321,13 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier
|
||||
// deploy the vka
|
||||
let args = vec![
|
||||
"deploy-evm-verifier",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr = std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
|
||||
// deploy the vk
|
||||
let args = vec![
|
||||
"deploy-evm-vk",
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg_vk.as_str(),
|
||||
sol_arg_vk.as_str(),
|
||||
"-C=vka",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -2282,7 +2357,7 @@ mod native_tests {
|
||||
|
||||
// now verify the proof
|
||||
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
|
||||
let mut args = vec![
|
||||
let args = vec![
|
||||
"verify-evm",
|
||||
"--proof-path",
|
||||
pf_arg.as_str(),
|
||||
@@ -2296,13 +2371,52 @@ mod native_tests {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// As sanity check, add example that should fail.
|
||||
args[2] = PF_FAILURE;
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
// Read the original proof file
|
||||
let original_proof_data: ezkl::pfsys::Snark<
|
||||
halo2curves::bn256::Fr,
|
||||
halo2curves::bn256::G1Affine,
|
||||
> = Snark::load::<KZGCommitmentScheme<Bn256>>(&PathBuf::from(format!(
|
||||
"{}/{}/proof.pf",
|
||||
test_dir, example_name
|
||||
)))
|
||||
.expect("Failed to read proof file");
|
||||
|
||||
for i in 0..1 {
|
||||
// Create a copy of the original proof data
|
||||
let mut modified_proof_data = original_proof_data.clone();
|
||||
|
||||
// Flip a random bit
|
||||
let random_byte = rand::thread_rng().gen_range(0..modified_proof_data.proof.len());
|
||||
let random_bit = rand::thread_rng().gen_range(0..8);
|
||||
modified_proof_data.proof[random_byte] ^= 1 << random_bit;
|
||||
|
||||
// Write the modified proof to a new file
|
||||
let modified_pf_arg = format!("{}/{}/modified_proof_{}.pf", test_dir, example_name, i);
|
||||
modified_proof_data
|
||||
.save(&PathBuf::from(modified_pf_arg.clone()))
|
||||
.expect("Failed to save modified proof file");
|
||||
|
||||
// Verify the modified proof (should fail)
|
||||
let mut args_mod = args.clone();
|
||||
args_mod[2] = &modified_pf_arg;
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args_mod)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
if status.success() {
|
||||
log::error!("Verification unexpectedly succeeded for modified proof {}. Flipped bit {} in byte {}", i, random_bit, random_byte);
|
||||
}
|
||||
|
||||
assert!(
|
||||
!status.success(),
|
||||
"Modified proof {} should have failed verification",
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
// Return deployed_addr_arg for the reusable verifier
|
||||
deployed_addr_arg
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
@@ -2504,7 +2618,7 @@ mod native_tests {
|
||||
|
||||
// deploy the verifier
|
||||
let mut args = vec![
|
||||
"deploy-evm-verifier",
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_verifier_arg.as_str(),
|
||||
];
|
||||
|
||||
@@ -11,7 +11,7 @@ mod py_tests {
|
||||
static ENV_SETUP: Once = Once::new();
|
||||
static DOWNLOAD_VOICE_DATA: Once = Once::new();
|
||||
|
||||
//Sure to run this once
|
||||
// Make sure this runs only once
|
||||
|
||||
lazy_static! {
|
||||
static ref CARGO_TARGET_DIR: String =
|
||||
@@ -123,7 +123,8 @@ mod py_tests {
|
||||
}
|
||||
}
|
||||
|
||||
const TESTS: [&str; 33] = [
|
||||
const TESTS: [&str; 34] = [
|
||||
"ezkl_demo_batch.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
@@ -201,6 +202,18 @@ mod py_tests {
|
||||
anvil_child.kill().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reusable_verifier_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
let mut anvil_child = crate::py_tests::start_anvil(false);
|
||||
let test_dir: TempDir = TempDir::new("reusable_verifier").unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::py_tests::mv_test_(path, "reusable_verifier.ipynb");
|
||||
run_notebook(path, "reusable_verifier.ipynb");
|
||||
test_dir.close().unwrap();
|
||||
anvil_child.kill().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn postgres_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
|
||||
@@ -23,7 +23,7 @@ examples_path = os.path.abspath(
|
||||
|
||||
srs_path = os.path.join(folder_path, 'kzg_test.params')
|
||||
params_k17_path = os.path.join(folder_path, 'kzg_test_k17.params')
|
||||
params_k20_path = os.path.join(folder_path, 'kzg_test_k20.params')
|
||||
params_k21_path = os.path.join(folder_path, 'kzg_test_k21.params')
|
||||
anvil_url = "http://localhost:3030"
|
||||
|
||||
|
||||
@@ -104,8 +104,8 @@ def test_gen_srs():
|
||||
ezkl.gen_srs(params_k17_path, 17)
|
||||
assert os.path.isfile(params_k17_path)
|
||||
|
||||
ezkl.gen_srs(params_k20_path, 20)
|
||||
assert os.path.isfile(params_k20_path)
|
||||
ezkl.gen_srs(params_k21_path, 21)
|
||||
assert os.path.isfile(params_k21_path)
|
||||
|
||||
|
||||
|
||||
@@ -450,10 +450,10 @@ async def test_create_evm_verifier_separate_vk():
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
srs_path=srs_path,
|
||||
render_vk_seperately=True
|
||||
reusable=True
|
||||
)
|
||||
|
||||
res = await ezkl.create_evm_vk(
|
||||
res = await ezkl.create_evm_vka(
|
||||
vk_path,
|
||||
settings_path,
|
||||
vk_code_path,
|
||||
@@ -465,9 +465,9 @@ async def test_create_evm_verifier_separate_vk():
|
||||
assert os.path.isfile(sol_code_path)
|
||||
|
||||
|
||||
async def test_deploy_evm_separate_vk():
|
||||
async def test_deploy_evm_reusable_and_vka():
|
||||
"""
|
||||
Test deployment of the separate verifier smart contract + vk
|
||||
Test deployment of the reusable verifier smart contract + vka
|
||||
In order to run this you will need to install solc in your environment
|
||||
"""
|
||||
addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
|
||||
@@ -481,13 +481,15 @@ async def test_deploy_evm_separate_vk():
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path_verifier,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
anvil_url,
|
||||
"verifier/reusable",
|
||||
)
|
||||
|
||||
res = await ezkl.deploy_vk_evm(
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path_vk,
|
||||
vk_code_path,
|
||||
rpc_url=anvil_url,
|
||||
anvil_url,
|
||||
"vka",
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -506,7 +508,7 @@ async def test_deploy_evm():
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
anvil_url,
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -677,7 +679,7 @@ async def test_aggregate_and_verify_aggr():
|
||||
)
|
||||
|
||||
# mock aggregate
|
||||
res = ezkl.mock_aggregate([proof_path], 20)
|
||||
res = ezkl.mock_aggregate([proof_path], 21)
|
||||
assert res == True
|
||||
|
||||
aggregate_proof_path = os.path.join(folder_path, 'aggr_1l_relu.pf')
|
||||
@@ -688,8 +690,8 @@ async def test_aggregate_and_verify_aggr():
|
||||
[proof_path],
|
||||
aggregate_vk_path,
|
||||
aggregate_pk_path,
|
||||
20,
|
||||
srs_path=params_k20_path,
|
||||
21,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
|
||||
res = ezkl.gen_vk_from_pk_aggr(aggregate_pk_path, aggregate_vk_path)
|
||||
@@ -701,9 +703,9 @@ async def test_aggregate_and_verify_aggr():
|
||||
aggregate_proof_path,
|
||||
aggregate_pk_path,
|
||||
"poseidon",
|
||||
20,
|
||||
21,
|
||||
"unsafe",
|
||||
srs_path=params_k20_path,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -713,8 +715,8 @@ async def test_aggregate_and_verify_aggr():
|
||||
res = ezkl.verify_aggr(
|
||||
aggregate_proof_path,
|
||||
aggregate_vk_path,
|
||||
20,
|
||||
srs_path=params_k20_path,
|
||||
21,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
assert res == True
|
||||
|
||||
@@ -793,8 +795,8 @@ async def test_evm_aggregate_and_verify_aggr():
|
||||
[proof_path],
|
||||
aggregate_vk_path,
|
||||
aggregate_pk_path,
|
||||
20,
|
||||
srs_path=params_k20_path,
|
||||
21,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
|
||||
res = ezkl.aggregate(
|
||||
@@ -802,9 +804,9 @@ async def test_evm_aggregate_and_verify_aggr():
|
||||
aggregate_proof_path,
|
||||
aggregate_pk_path,
|
||||
"evm",
|
||||
20,
|
||||
21,
|
||||
"unsafe",
|
||||
srs_path=params_k20_path,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -819,8 +821,8 @@ async def test_evm_aggregate_and_verify_aggr():
|
||||
aggregate_vk_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
logrows=20,
|
||||
srs_path=params_k20_path,
|
||||
logrows=21,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -838,8 +840,8 @@ async def test_evm_aggregate_and_verify_aggr():
|
||||
res = ezkl.verify_aggr(
|
||||
aggregate_proof_path,
|
||||
aggregate_vk_path,
|
||||
20,
|
||||
srs_path=params_k20_path,
|
||||
21,
|
||||
srs_path=params_k21_path,
|
||||
)
|
||||
assert res == True
|
||||
|
||||
|
||||
@@ -16,14 +16,12 @@ mod wasm32 {
|
||||
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
|
||||
};
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::commitment::CommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::Bn256;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use wasm_bindgen::JsError;
|
||||
#[cfg(feature = "web")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
use wasm_bindgen_test::*;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"param_scale": 0,
|
||||
"scale_rebase_multiplier": 10,
|
||||
"lookup_range": [
|
||||
-2,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"logrows": 6,
|
||||
@@ -24,15 +24,18 @@
|
||||
"param_visibility": "Private",
|
||||
"div_rebasing": false,
|
||||
"rebase_frac_zero_constants": false,
|
||||
"check_mode": "UNSAFE"
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2
|
||||
},
|
||||
"num_rows": 16,
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
"total_const_size": 3,
|
||||
"total_dynamic_col_size": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
"num_shuffles": 0,
|
||||
"total_shuffle_col_size": 0,
|
||||
"total_assignments": 32,
|
||||
"total_const_size": 8,
|
||||
"model_instance_shapes": [
|
||||
[
|
||||
1,
|
||||
@@ -54,12 +57,19 @@
|
||||
]
|
||||
]
|
||||
},
|
||||
"required_lookups": [
|
||||
"ReLU"
|
||||
"required_lookups": [],
|
||||
"required_range_checks": [
|
||||
[
|
||||
-1,
|
||||
1
|
||||
],
|
||||
[
|
||||
0,
|
||||
127
|
||||
]
|
||||
],
|
||||
"required_range_checks": [],
|
||||
"check_mode": "UNSAFE",
|
||||
"version": "0.0.0",
|
||||
"num_blinding_factors": null,
|
||||
"timestamp": 1702474230544
|
||||
"timestamp": 1726429587279
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}