Compare commits

...

34 Commits

Author SHA1 Message Date
dante
14786acb95 feat: dynamic lookups (#727) 2024-03-01 01:44:45 +00:00
dante
80a3c44cb4 feat: lookup-less recip by default (#725) 2024-02-28 16:35:20 +00:00
dante
1656846d1a fix: transcript should serialize as lc flag (#726) 2024-02-26 22:02:47 +00:00
dante
88098b8190 fix!: cleanup felt serialization language in python and wasm (#724)
BREAKING CHANGE: python and wasm felt utilities have new names
2024-02-25 14:06:48 +00:00
dante
6c0c17c9be fix: include tol check in fwd pass (#723) 2024-02-23 01:28:59 +00:00
dante
bf69b16fc1 fix: rm optional bool flags (#722) 2024-02-21 12:45:42 +00:00
dante
74feb829da feat: parse command ast into flag strings (#720) 2024-02-21 00:38:26 +00:00
dante
d429e7edab fix: buffer data read and writes (#719) 2024-02-19 11:49:15 +00:00
dante
f0e5b82787 refactor: selectable key ser (#718) 2024-02-19 11:26:18 +00:00
dante
3f7261f50b fix: set buf capacity for witness, settings, proof (#717) 2024-02-16 21:59:20 +00:00
dante
678a249dcb feat: allow for reduced n srs for verification (#716) 2024-02-16 18:28:54 +00:00
dante
0291eb2d0f fix: reduce verbosity of common operations (#715) 2024-02-15 17:27:33 +00:00
dante
1b637a70b0 refactor: print_proof_hex is redundant with proof file (#713) 2024-02-14 15:25:28 +00:00
dante
abcd5380db feat: programmable buffer capacity (#712) 2024-02-13 15:49:14 +00:00
dante
076b737108 chore: allow for a max circuit area cap (#711) 2024-02-12 14:36:51 +00:00
dante
97d9832591 refactor: calibration for resources and accuracy over same scale range (#710) 2024-02-11 15:03:38 +00:00
dante
e0771683a6 chore: update h2 curves (#709) 2024-02-10 22:54:38 +00:00
dante
319c222307 chore: more descriptive debug logs on forward pass (#708) 2024-02-10 16:10:33 +00:00
dante
85ee6e7f9d refactor: use layout as the forward function (#707) 2024-02-08 21:15:46 +00:00
dante
4c8daf773c refactor: lookup-less layer norm (#706) 2024-02-07 21:19:17 +00:00
dante
80041ac523 refactor: equals argument without lookups (#705) 2024-02-07 14:20:13 +00:00
dante
2a1ee1102c refactor: range check recip (#703) 2024-02-05 14:42:26 +00:00
dante
95d4fd4a70 feat: power of 2 div using type system (#702) 2024-02-04 02:43:38 +00:00
dante
e0d3f4f145 fix: uncomparable values in acc table (#701) 2024-02-02 15:13:29 +00:00
dante
bceac2fab5 ci: make gpu tests single threaded (#700) 2024-01-31 18:19:29 +00:00
dante
04d7b5feaa chore: fold div_rebasing parameter into calibration (#699) 2024-01-31 10:03:12 +00:00
dante
45fd12a04f refactor!: make rebasing multiplicative by default (#698)
BREAKING CHANGE: adds a `required_range_checks` field to `cs`
2024-01-30 18:37:57 +00:00
dante
bc7c33190f feat: allow for separate vk render on-chain (#697) 2024-01-25 19:48:13 +00:00
dante
df72e01414 feat: make selector compression optional (#696) 2024-01-24 00:09:00 +00:00
Tobin South
172e26c00d fix: link for CLI auto-install (#695) 2024-01-22 13:00:27 +00:00
Jason Morton
11ac120f23 fix: large test numbering(#689)
Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-01-21 21:01:46 +00:00
Jseam
0fdd92e9f3 fix: move install_ezkl_cli.sh into main repo (#694)
Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-01-21 20:59:39 +00:00
Alexander Camuto
31f58056a5 chore: bump py03 2024-01-21 20:58:32 +00:00
dante
ddbcc1d2d8 fix: calibration should only consider local scales (#691) 2024-01-18 23:28:49 +00:00
83 changed files with 11048 additions and 15044 deletions

View File

@@ -6,12 +6,12 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
runs-on: self-hosted
runs-on: kaiju
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: nanoGPT Mock
@@ -23,6 +23,6 @@ jobs:
- name: Self Attention KZG prove and verify large tests
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_0_expects -- --include-ignored
- name: mobilenet Mock
run: cargo test --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
run: cargo test --release --verbose tests::large_mock_::large_tests_3_expects -- --include-ignored
- name: mobilenet KZG prove and verify large tests
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_2_expects -- --include-ignored
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_3_expects -- --include-ignored

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Checkout repo

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -198,10 +198,8 @@ jobs:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Install wasm runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -214,7 +212,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -231,13 +229,15 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
- name: kzg inputs
@@ -286,7 +286,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -312,7 +312,9 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -343,18 +345,15 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Install wasm-server-runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -414,32 +413,32 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
fuzz-tests:
runs-on: ubuntu-latest-32-cores
@@ -448,7 +447,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -458,7 +457,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: fuzz tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
# - name: fuzz tests
@@ -471,7 +470,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -489,7 +488,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -501,12 +500,12 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -514,16 +513,16 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 8 -- --include-ignored
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -533,7 +532,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -544,7 +543,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -566,7 +565,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- name: Install solc
@@ -574,7 +573,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Run pytest
@@ -590,7 +589,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -601,6 +600,8 @@ jobs:
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
@@ -619,7 +620,7 @@ jobs:
python-version: "3.9"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -629,7 +630,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
@@ -643,14 +644,13 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

View File

@@ -22,18 +22,15 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Install wasm-server-runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e

687
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,9 +15,9 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
halo2curves = { version = "0.1.0" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
@@ -28,16 +28,19 @@ thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "ac/lookup-modularity" }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
instant = { version = "0.1" }
@@ -51,9 +54,9 @@ plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
@@ -158,6 +161,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidi
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']

View File

@@ -64,8 +64,8 @@ More notebook tutorials can be found within `examples/notebooks`.
#### CLI
Install the CLI
```bash
curl https://hub.ezkl.xyz/install_ezkl_cli.sh | bash
``` shell
curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh | bash
```
https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-902e-20738ce0e9f0.mp4

View File

@@ -121,13 +121,16 @@ fn runcnvrl(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -90,13 +90,13 @@ fn rundot(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -94,13 +94,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -1,4 +1,5 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::table::Range;
use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
@@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;
const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;
@@ -111,13 +112,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -3,6 +3,7 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;
const BITS: (i128, i128) = (-8180, 8180);
const BITS: Range = (-8180, 8180);
static mut LEN: usize = 4;
static mut K: usize = 16;
@@ -114,13 +115,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {

View File

@@ -86,13 +86,13 @@ fn runsum(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -101,13 +101,16 @@ fn runsumpool(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -84,13 +84,13 @@ fn runadd(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -83,13 +83,13 @@ fn runpow(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -76,13 +76,13 @@ fn runposeidon(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -1,5 +1,6 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
@@ -14,7 +15,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;
@@ -90,13 +91,13 @@ fn runrelu(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -6,6 +6,7 @@ use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -489,6 +490,7 @@ pub fn runconv() {
strategy,
pi_for_real_prover,
&mut transcript,
params.n(),
);
assert!(verify.is_ok());

View File

@@ -309,7 +309,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
"print(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]))"
]
},
{
@@ -325,7 +325,7 @@
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider, utils\n",
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
@@ -338,7 +338,7 @@
"\n",
"def test_on_chain_data(res):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
" # Step 1: Prepare the data\n",
" # Step 2: Prepare and compile the contract.\n",
@@ -648,7 +648,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4
},

View File

@@ -441,7 +441,7 @@
"# Serialize calibration data into file:\n",
"json.dump(data, open(cal_data_path, 'w'))\n",
"\n",
"# Optimize for resources, we cap logrows at 17 to reduce setup and proving time, at the expense of accuracy\n",
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
"# You may want to increase the max logrows if accuracy is a concern\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
]
@@ -695,7 +695,7 @@
"formatted_output = \"[\"\n",
"for i, value in enumerate(proof[\"instances\"]):\n",
" for j, field_element in enumerate(value):\n",
" onchain_input_array.append(ezkl.vecu64_to_felt(field_element))\n",
" onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
" formatted_output += str(onchain_input_array[-1])\n",
" if j != len(value) - 1:\n",
" formatted_output += \", \"\n",
@@ -705,7 +705,7 @@
"# copy them over to remix and see if they verify\n",
"# What happens when you change a value?\n",
"print(\"pubInputs: \", formatted_output)\n",
"print(\"proof: \", \"0x\" + proof[\"proof\"])"
"print(\"proof: \", proof[\"proof\"])"
]
},
{

View File

@@ -300,13 +300,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
"\n",
"def setup(i):\n",
" print(\"Setting up split model \"+str(i))\n",
" # file names\n",
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
@@ -342,12 +343,12 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" \n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
@@ -383,7 +384,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\",\n",
" )\n",
"\n",
@@ -413,7 +413,6 @@
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
" assert res == True\n",
@@ -442,7 +441,7 @@
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],

File diff suppressed because one or more lines are too long

View File

@@ -302,7 +302,7 @@
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" \n",
"\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
@@ -330,14 +330,14 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
" res_1_proof = res[\"proof\"]\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" # Verify the proof\n",
" # # Verify the proof\n",
" if i > 0:\n",
" print(\"swapping commitments\")\n",
" # swap the proof commitments if we are not the first model\n",
@@ -356,12 +356,19 @@
"\n",
" res = ezkl.swap_proof_commitments(proof_path, witness_path)\n",
" print(res)\n",
" \n",
" # load proof and then print \n",
" proof = json.load(open(proof_path, 'r'))\n",
" res_2_proof = proof[\"hex_proof\"]\n",
" # show diff in hex strings\n",
" print(res_1_proof)\n",
" print(res_2_proof)\n",
" assert res_1_proof == res_2_proof\n",
"\n",
" res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
" assert res == True\n",
@@ -439,7 +446,7 @@
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],

View File

@@ -78,7 +78,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
@@ -122,8 +122,8 @@
"# Loop through each element in the y tensor\n",
"for e in y_input:\n",
" # Apply the custom function and append the result to the list\n",
" print(ezkl.float_to_vecu64(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 7)])[0])\n",
" print(ezkl.float_to_felt(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 7)])[0])\n",
"\n",
"y = y.unsqueeze(0)\n",
"y = y.reshape(1, 9)\n",
@@ -343,7 +343,7 @@
"# we force the output to be 0 this corresponds to the set membership test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [[0, 0, 0, 0]]\n",
"witness[\"outputs\"][0] = [\"0000000000000000000000000000000000000000000000000000000000000000\"]\n",
"json.dump(witness, open(witness_path, \"w\"))\n",
"\n",
"witness = json.load(open(witness_path, \"r\"))\n",
@@ -353,7 +353,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" witness_path = witness_path,\n",
" )\n",
"\n",
@@ -520,4 +519,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -210,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "b1c561a8",
"metadata": {},
"outputs": [],

View File

@@ -126,7 +126,7 @@
"# Loop through each element in the y tensor\n",
"for e in user_preimages:\n",
" # Apply the custom function and append the result to the list\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 0)])[0])\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 0)])[0])\n",
"\n",
"users_t = torch.tensor(user_preimages)\n",
"users_t = users_t.reshape(1, 6)\n",
@@ -303,7 +303,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))"
]
},
@@ -417,7 +417,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))\n"
]
},
@@ -510,7 +510,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -664,7 +664,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
")"
]
},

View File

@@ -520,7 +520,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{
@@ -636,7 +636,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -25,17 +25,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"voice_data_dir: .\n"
]
}
],
"outputs": [],
"source": [
"\n",
"import os\n",
@@ -43,7 +35,7 @@
"\n",
"voice_data_dir = os.environ.get('VOICE_DATA_DIR')\n",
"\n",
"# if is none set to \"\" \n",
"# if is none set to \"\"\n",
"if voice_data_dir is None:\n",
" voice_data_dir = \"\"\n",
"\n",
@@ -637,7 +629,7 @@
"source": [
"\n",
"\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\")\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -908,7 +900,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -503,11 +503,11 @@
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"\n",
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][1], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][1], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][3], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][3], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
]
}

View File

@@ -0,0 +1,39 @@
from torch import nn
import torch
import json
class Circuit(nn.Module):
def __init__(self, inplace=False):
super(Circuit, self).__init__()
def forward(self, x):
return x/ 10000
circuit = Circuit()
x = torch.empty(1, 8).random_(0, 2)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

170
install_ezkl_cli.sh Normal file
View File

@@ -0,0 +1,170 @@
#!/usr/bin/env bash
set -e
BASE_DIR=${XDG_CONFIG_HOME:-$HOME}
EZKL_DIR=${EZKL_DIR-"$BASE_DIR/.ezkl"}
# Create the .ezkl bin directory if it doesn't exit
mkdir -p $EZKL_DIR
# Store the correct profile file (i.e. .profile for bash or .zshenv for ZSH).
case $SHELL in
*/zsh)
PROFILE=${ZDOTDIR-"$HOME"}/.zshenv
PREF_SHELL=zsh
;;
*/bash)
PROFILE=$HOME/.bashrc
PREF_SHELL=bash
;;
*/fish)
PROFILE=$HOME/.config/fish/config.fish
PREF_SHELL=fish
;;
*/ash)
PROFILE=$HOME/.profile
PREF_SHELL=ash
;;
*)
echo "NOTICE: Shell could not be detected, you will need to manually add ${EZKL_DIR} to your PATH."
esac
# Check for non standard installation of ezkl
if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; then
echo "ezkl is installed in a non-standard directory, $(which ezkl). To use the automated installer, remove the existing ezkl from path to prevent conflicts"
exit 1
fi
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi
# Install latest ezkl version
# Get the right release URL
if [ -z "$1" ]
then
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/latest"
echo "No version tags provided, installing the latest ezkl version"
else
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/tags/$1"
echo "Installing ezkl version $1"
fi
PLATFORM=""
case "$(uname -s)" in
Darwin*)
PLATFORM="macos"
;;
Linux*Microsoft*)
PLATFORM="linux"
;;
Linux*)
PLATFORM="linux"
;;
CYGWIN*|MINGW*|MINGW32*|MSYS*)
PLATFORM="windows-msvc"
;;
*)
echo "Platform is not supported. If you would need support for the platform please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
;;
esac
# Check arch
ARCHITECTURE="$(uname -m)"
if [ "${ARCHITECTURE}" = "x86_64" ]; then
# Redirect stderr to /dev/null to avoid printing errors if non Rosetta.
if [ "$(sysctl -n sysctl.proc_translated 2>/dev/null)" = "1" ]; then
ARCHITECTURE="arm64" # Rosetta.
else
ARCHITECTURE="amd64" # Intel.
fi
elif [ "${ARCHITECTURE}" = "arm64" ] ||[ "${ARCHITECTURE}" = "aarch64" ]; then
ARCHITECTURE="aarch64" # Arm.
elif [ "${ARCHITECTURE}" = "amd64" ]; then
ARCHITECTURE="amd64" # Amd
else
echo "Architecture is not supported. If you would need support for the architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
# Remove existing ezkl
echo "Removing old ezkl binary if it exists"
[ -e file ] && rm file
# download the release and unpack the right tarball
if [ "$PLATFORM" == "windows-msvc" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-windows-msvc.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
elif [ "$PLATFORM" == "macos" ]; then
if [ "$ARCHITECTURE" == "aarch64" ] || [ "$ARCHITECTURE" == "arm64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos-aarch64.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
else
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
fi
elif [ "$PLATFORM" == "linux" ]; then
if [ "${ARCHITECTURE}" = "amd64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
else
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
else
echo "Platform and Architecture is not supported. If you would need support for the platform and architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
echo && echo "Successfully downloaded ezkl at ${EZKL_DIR}"
echo "We detected that your preferred shell is ${PREF_SHELL} and added ezkl to PATH. Run 'source ${PROFILE}' or start a new terminal session to use ezkl."

View File

@@ -11,8 +11,8 @@ use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{error, info};
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
@@ -25,6 +25,7 @@ use std::error::Error;
pub async fn main() -> Result<(), Box<dyn Error>> {
let args = Cli::parse();
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
@@ -32,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
} else {
info!("Running with CPU");
}
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
let res = run(args.command).await;
match &res {
Ok(_) => info!("succeeded"),
@@ -44,7 +45,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
#[cfg(target_arch = "wasm32")]
pub fn main() {}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
fn banner() {
let ell: Vec<&str> = vec![
"for Neural Networks",

View File

@@ -41,7 +41,7 @@ pub struct KZGChip {
}
impl KZGChip {
/// Returns the number of inputs to the hash function
/// Commit to the message using the KZG commitment scheme
pub fn commit(
message: Vec<Fp>,
degree: u32,
@@ -219,7 +219,7 @@ mod tests {
};
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}
@@ -240,6 +240,6 @@ mod tests {
message: message.into(),
};
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}

View File

@@ -15,7 +15,7 @@ use halo2_proofs::{
Instance, Selector, TableColumn,
},
};
use log::{trace, warn};
use log::{debug, trace};
/// A simple [`FloorPlanner`] that performs minimal optimizations.
#[derive(Debug)]
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
Error::Synthesis
})?;
if !self.regions.contains_key(&index) {
warn!("spawning module {}", index)
debug!("spawning module {}", index)
};
self.current_module = index;
}

View File

@@ -499,7 +499,7 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
#[test]
@@ -518,7 +518,7 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
#[test]
@@ -551,7 +551,7 @@ mod tests {
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}
@@ -573,6 +573,6 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}

View File

@@ -125,8 +125,8 @@ impl BaseOp {
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 1,
BaseOp::IsBoolean => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -16,10 +16,14 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use crate::{
circuit::ops::base::BaseOp,
circuit::{table::Table, utils},
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
@@ -58,6 +62,22 @@ pub enum CheckMode {
UNSAFE,
}
impl std::fmt::Display for CheckMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CheckMode::SAFE => write!(f, "safe"),
CheckMode::UNSAFE => write!(f, "unsafe"),
}
}
}
impl ToFlags for CheckMode {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<String> for CheckMode {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -80,6 +100,19 @@ pub struct Tolerance {
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
@@ -165,6 +198,9 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub lookup_input: VarTensor,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// The VarTensor reserved for dynamic lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub dynamic_lookup_tables: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_output: VarTensor,
@@ -174,8 +210,16 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub dynamic_lookup_selectors: BTreeMap<(usize, usize), Vec<Selector>>,
///
pub dynamic_table_selectors: Vec<Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -191,10 +235,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
lookup_input: dummy_var.clone(),
output: dummy_var.clone(),
lookup_output: dummy_var.clone(),
dynamic_lookup_tables: vec![VarTensor::dummy(col_size, 2), dummy_var.clone()],
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
dynamic_table_selectors: vec![],
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
@@ -267,9 +316,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
let constraints = match base_op {
BaseOp::IsBoolean => {
vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
let output = expected_output[base_op.constraint_idx()].clone();
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
BaseOp::IsZero => vec![qis[1].clone()],
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -325,11 +385,16 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
dynamic_lookup_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
dynamic_table_selectors: vec![],
dynamic_lookup_tables: vec![],
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
check_mode,
_marker: PhantomData,
@@ -344,15 +409,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
input: &VarTensor,
output: &VarTensor,
index: &VarTensor,
lookup_range: (i128, i128),
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
}
@@ -462,10 +525,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((nl.clone(), x, y), multi_col_selector);
self.lookup_selectors
.insert((nl.clone(), x, y), multi_col_selector);
}
}
self.lookup_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
@@ -482,6 +545,187 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_dynamic_lookup(
&mut self,
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 2],
tables: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookup_selectors
.entry((x, y))
.or_default()
.push(s_lookup);
}
}
self.dynamic_table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookup_tables.is_empty() {
debug!("assigning dynamic lookup table");
self.dynamic_lookup_tables = tables.to_vec();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
&mut self,
cs: &mut ConstraintSystem<F>,
input: &VarTensor,
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}
// we borrow mutably twice so we need to do this dance
let range_check =
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
let len = range_check.selector_constructor.degree;
let multi_col_selector = cs.complex_selector();
for (col_idx, input_col) in range_check.inputs.iter().enumerate() {
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(multi_col_selector);
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let default_x = range_check.get_first_element(col_idx);
let col_expr = sel.clone()
* range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
let multiplier = range_check
.selector_constructor
.get_selector_val_at_idx(col_idx);
let not_expr = Expression::Constant(multiplier) - col_expr.clone();
res.extend([(
col_expr.clone() * input_query.clone()
+ not_expr.clone() * Expression::Constant(default_x),
*input_col,
)]);
log::trace!("---------------- col {:?} ------------------", col_idx,);
log::trace!("expr: {:?}", col_expr,);
log::trace!("multiplier: {:?}", multiplier);
log::trace!("not_expr: {:?}", not_expr);
log::trace!("default x: {:?}", default_x);
res
});
}
self.range_check_selectors
.insert((range, x, y), multi_col_selector);
}
}
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
}
if let VarTensor::Empty = self.lookup_index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
}
Ok(())
}
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
@@ -500,6 +744,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}
/// layout_range_checks must be called before layout.
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;
}
}
Ok(())
}
/// Assigns variables to the regions created when calling `configure`.
/// # Arguments
/// * `values` - The explicit values to the operations.

View File

@@ -1,11 +1,11 @@
use super::*;
use crate::{
circuit::{self, layouts, utils, Tolerance},
circuit::{layouts, utils, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
// import run args from model
@@ -13,6 +13,15 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
use_range_check_for_int: bool,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
},
@@ -59,14 +68,6 @@ pub enum HybridOp {
dim: usize,
num_classes: usize,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
@@ -74,7 +75,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::ScatterElements { .. } => vec![0, 2],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
_ => vec![],
}
}
@@ -87,142 +88,42 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let (res, intermediate_lookups) = match &self {
HybridOp::ReduceMax { axes, .. } => {
let res = tensor::ops::max_axes(&x, axes)?;
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::ReduceMin { axes, .. } => {
let res = tensor::ops::min_axes(&x, axes)?;
let min_plus_one =
Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(min(x + 1) - x)
let inter_1 = (min_plus_one - x.clone())?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let inter =
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
inter_equals.extend(inter);
(res.clone(), inter_equals)
}
HybridOp::ReduceArgMin { dim } => {
let res = tensor::ops::argmin_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let inter =
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
inter_equals.extend(inter);
(res.clone(), inter_equals)
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather(&x, idx, *dim)?;
(res.clone(), vec![])
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), inter_equals)
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => {
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
(res.clone(), inter_equals)
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => {
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
let mut inter_equals = x
.clone()
.into_iter()
.flat_map(|elem| {
tensor::ops::equals(&res, &vec![elem].into_iter().into())
.unwrap()
.1
})
.collect::<Vec<_>>();
// sort in descending order and take pairwise differences
inter_equals.push(
x.into_iter()
.sorted()
.tuple_windows()
.map(|(a, b)| b - a)
.into(),
);
(res.clone(), inter_equals)
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather_elements(&x, idx, *dim)?;
(res.clone(), vec![])
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let src = inputs[1].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
(res.clone(), vec![])
} else {
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
let src = inputs[2].clone().map(|x| felt_to_i128(x));
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => {
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
vec![inter_1, inter_2],
)
}
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
@@ -234,10 +135,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
(
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
vec![],
)
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
@@ -264,14 +162,26 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output,
intermediate_lookups,
})
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
match self {
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => format!(
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
input_scale, output_scale, use_range_check_for_int
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::SumPool {
padding,
stride,
@@ -306,8 +216,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::TopK { k, dim, largest } => {
format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
}
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
HybridOp::OneHot { dim, num_classes } => {
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
}
@@ -335,6 +243,55 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
*kernel_shape,
*normalized,
)?,
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
if input_scale.0.fract() == 0.0
&& output_scale.0.fract() == 0.0
&& *use_range_check_for_int
{
layouts::recip(
config,
region,
values[..].try_into()?,
i128_to_felt(input_scale.0 as i128),
i128_to_felt(output_scale.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Recip {
input_scale: *input_scale,
output_scale: *output_scale,
},
)?
}
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Div { denom: *denom },
)?
}
}
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -342,26 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
layouts::gather(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::MaxPool2d {
padding,
stride,
@@ -422,86 +360,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
match self {
HybridOp::ReduceMax { .. }
| HybridOp::ReduceMin { .. }
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
HybridOp::Softmax { scale, .. } => {
vec![
LookupOp::Exp { scale: *scale },
LookupOp::Recip {
scale: scale.0.powf(2.0).into(),
},
]
}
HybridOp::RangeCheck(tol) => {
let mut lookups = vec![];
if tol.val > 0.0 {
let scale_squared = tol.scale.0.powf(2.0);
lookups.extend([
LookupOp::Recip {
scale: scale_squared.into(),
},
LookupOp::GreaterThan {
a: circuit::utils::F32((tol.val * scale_squared) / 100.0),
},
]);
}
lookups
}
HybridOp::Greater { .. } | HybridOp::Less { .. } => {
vec![LookupOp::GreaterThan {
a: circuit::utils::F32(0.),
}]
}
HybridOp::GreaterEqual { .. } | HybridOp::LessEqual { .. } => {
vec![LookupOp::GreaterThanEqual {
a: circuit::utils::F32(0.),
}]
}
HybridOp::TopK { .. } => {
vec![
LookupOp::GreaterThan {
a: circuit::utils::F32(0.),
},
LookupOp::KroneckerDelta,
]
}
HybridOp::Gather {
constant_idx: None, ..
}
| HybridOp::OneHot { .. }
| HybridOp::GatherElements {
constant_idx: None, ..
}
| HybridOp::ScatterElements {
constant_idx: None, ..
}
| HybridOp::Equals { .. } => {
vec![LookupOp::KroneckerDelta]
}
HybridOp::ReduceArgMax { .. } | HybridOp::ReduceArgMin { .. } => {
vec![LookupOp::ReLU, LookupOp::KroneckerDelta]
}
HybridOp::SumPool {
kernel_shape,
normalized: true,
..
} => {
vec![LookupOp::Div {
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
}]
}
_ => vec![],
}
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
}

View File

@@ -19,7 +19,7 @@ use super::{
};
use crate::{
circuit::{ops::base::BaseOp, utils},
fieldutils::i128_to_felt,
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
@@ -51,6 +51,250 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
total_len
}
/// Same as div but splits the division into N parts
pub fn loop_div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
divisor: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if divisor == F::ONE {
return Ok(value[0].clone());
}
// if integer val is divisible by 2, we can use a faster method and div > F::S
let mut divisor = divisor;
let mut num_parts = 1;
while felt_to_i128(divisor) % 2 == 0 && felt_to_i128(divisor) > (2_i128.pow(F::S - 4)) {
divisor = i128_to_felt(felt_to_i128(divisor) / 2);
num_parts += 1;
}
let output = div(config, region, value, divisor)?;
if num_parts == 1 {
return Ok(output);
}
let divisor_int = 2_i128.pow(num_parts - 1);
let divisor_felt = i128_to_felt(divisor_int);
if divisor_int <= 2_i128.pow(F::S - 3) {
div(config, region, &[output], divisor_felt)
} else {
// keep splitting the divisor until it satisfies the condition
loop_div(config, region, &[output], divisor_felt)
}
}
/// Div accumulated layout
pub fn div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
div: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if div == F::ONE {
return Ok(value[0].clone());
}
let input = value[0].clone();
let input_dims = input.dims();
let range_check_bracket = felt_to_i128(div) / 2;
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
divisor.set_visibility(&crate::graph::Visibility::Fixed);
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
region.increment(divisor.len());
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
claimed_output.reshape(input_dims)?;
region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
let product = pairwise(
config,
region,
&[claimed_output.clone(), divisor.clone()],
BaseOp::Mult,
)?;
let diff_with_input = pairwise(
config,
region,
&[product.clone(), input.clone()],
BaseOp::Sub,
)?;
range_check(
config,
region,
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
)?;
Ok(claimed_output)
}
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assert is boolean
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
// get values where input is 0
let zero_mask = equals_zero(config, region, input)?;
let one_minus_zero_mask = pairwise(
config,
region,
&[
zero_mask.clone(),
ValTensor::from(Tensor::from([ValType::Constant(F::ONE)].into_iter())),
],
BaseOp::Sub,
)?;
let zero_inverse_val = pairwise(
config,
region,
&[
zero_mask,
ValTensor::from(Tensor::from(
[ValType::Constant(i128_to_felt(zero_inverse_val))].into_iter(),
)),
],
BaseOp::Mult,
)?;
pairwise(
config,
region,
&[one_minus_zero_mask, zero_inverse_val],
BaseOp::Add,
)
}
/// recip accumulated layout
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if output_scale == F::ONE || output_scale == F::ZERO {
return recip_int(config, region, value);
}
let input = value[0].clone();
let input_dims = input.dims();
let integer_input_scale = felt_to_i128(input_scale);
let integer_output_scale = felt_to_i128(output_scale);
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
let input_scale_ratio =
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
let range_check_bracket = range_check_len / 2;
let is_assigned = !input.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::recip(
&input_evals,
felt_to_i128(input_scale) as f64,
felt_to_i128(output_scale) as f64,
)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.output, &claimed_output)?;
region.increment(claimed_output.len());
// this is now of scale 2 * scale
let product = pairwise(
config,
region,
&[claimed_output.clone(), input.clone()],
BaseOp::Mult,
)?;
// divide by input_scale
let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?;
let zero_inverse_val =
tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0];
let zero_inverse =
Tensor::from([ValType::Constant(i128_to_felt::<F>(zero_inverse_val))].into_iter());
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
let equal_inverse_mask = equals(
config,
region,
&[claimed_output.clone(), zero_inverse.into()],
)?;
// assert the two masks are equal
enforce_equality(
config,
region,
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;
let unit_scale = Tensor::from([ValType::Constant(i128_to_felt(range_check_len))].into_iter());
let unit_mask = pairwise(
config,
region,
&[equal_zero_mask, unit_scale.into()],
BaseOp::Mult,
)?;
// now add the unit mask to the rebased_div
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[rebased_offset_div],
&(range_check_bracket, 3 * range_check_bracket),
)?;
Ok(claimed_output)
}
/// Dot product accumulated layout
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -92,7 +336,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
let mut assigned_len = 0;
for (i, input) in values.iter_mut().enumerate() {
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_with_duplication(
&config.inputs[i],
@@ -648,14 +892,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
let assigned_input = region.assign(&config.inputs[0], &input)?;
// now assert all elems are 0 or 1
let assigned_output = region.assign(&config.inputs[1], &output)?;
if !region.is_dummy() {
for i in 0..assigned_output.len() {
let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
}
}
let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
let sum = sum(config, region, &[assigned_output.clone()])?;
@@ -678,6 +915,78 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
Ok(assigned_output)
}
/// Dynamic lookup
pub fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err("lookups must be same length".into());
}
// if not all inputs same length err
if tables[0].len() != tables[1].len() {
return Err("inputs must be same length".into());
}
// now assert the inputs of a smaller length than the lookups
if tables[0].len() > tables[0].len() {
return Err("inputs must be smaller length than dynamic lookups".into());
}
let (lookup_0, lookup_1) = (lookups[0].clone(), lookups[1].clone());
let (table_0, table_1) = (tables[0].clone(), tables[1].clone());
let table_0 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[0], &table_0)?;
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookup_tables[1], &table_1)?;
let table_len = table_0.len();
let lookup_0 = region.assign(&config.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
if !region.is_dummy() {
(0..table_len)
.map(|i| {
let dynamic_lookup_index = region.dynamic_lookup_index();
let table_selector = config.dynamic_table_selectors[dynamic_lookup_index];
let (_, _, z) = config.dynamic_lookup_tables[0]
.cartesian_coord(region.dynamic_lookup_col_coord() + i);
region.enable(Some(&table_selector), z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if !region.is_dummy() {
// Enable the selectors
(0..lookup_len)
.map(|i| {
let (x, y, z) = config.inputs[0].cartesian_coord(region.linear_coord() + i);
let dynamic_lookup_index = region.dynamic_lookup_index();
let lookup_selector = config
.dynamic_lookup_selectors
.get(&(x, y))
.ok_or("missing selectors")?[dynamic_lookup_index];
region.enable(Some(&lookup_selector), z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment_dynamic_lookup_col_coord(table_len);
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
Ok((lookup_0, lookup_1))
}
/// One hot accumulated layout
pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1027,7 +1336,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1096,7 +1405,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width)?;
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1560,10 +1869,42 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let diff = pairwise(config, region, values, BaseOp::Sub)?;
equals_zero(config, region, &[diff])
}
let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
/// Equality boolean operation
pub fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let values = values[0].clone();
let values_inverse = values.inverse()?;
let product_values_and_invert = pairwise(
config,
region,
&[values.clone(), values_inverse],
BaseOp::Mult,
)?;
Ok(res)
// constant of 1
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
ones.set_visibility(&crate::graph::Visibility::Fixed);
// subtract
let output = pairwise(
config,
region,
&[ones.into(), product_values_and_invert],
BaseOp::Sub,
)?;
// take the product of diff and output
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
Ok(output)
}
/// Xor boolean operation
@@ -1627,21 +1968,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
.into();
// make sure mask is boolean
let assigned_mask = region.assign(&config.inputs[1], mask)?;
// Enable the selectors
if !region.is_dummy() {
(0..assigned_mask.len())
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(assigned_mask.len());
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
@@ -1735,17 +2062,15 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
let shape = &res[0].dims()[2..];
let mut last_elem = res[1..]
.iter()
.fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?;
.try_fold(res[0].clone(), |acc, elem| acc.concat(elem.clone()))?;
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
last_elem = nonlinearity(
last_elem = loop_div(
config,
region,
&[last_elem],
&LookupOp::Div {
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
},
F::from((kernel_shape.0 * kernel_shape.1) as u64),
)?;
}
Ok(last_elem)
@@ -1837,15 +2162,6 @@ pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std:
)));
}
if has_bias {
let bias = &inputs[2];
if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) {
return Err(Box::new(TensorError::DimMismatch(
"deconv bias".to_string(),
)));
}
}
let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]);
let null_val = ValType::Constant(F::ZERO);
@@ -2251,18 +2567,60 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = region.assign(&config.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.output.cartesian_coord(index);
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
Ok(output)
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = region.assign(&config.inputs[1], &values[0])?;
let output = if assign || !values[0].get_const_indices()?.is_empty() {
// get zero constants indices
let output = region.assign(&config.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.output.cartesian_coord(index);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
@@ -2270,7 +2628,6 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(output.len());
Ok(output)
}
@@ -2313,6 +2670,85 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// layout for range check.
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
range: &crate::circuit::table::Range,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_range_check(*range)?;
// time the entire operation
let timer = instant::Instant::now();
let x = values[0].clone();
let w = region.assign(&config.lookup_input, &x)?;
let assigned_len = x.len();
let is_dummy = region.is_dummy();
let table_index: ValTensor<F> = w
.get_inner_tensor()?
.par_enum_map(|_, e| {
Ok::<ValType<F>, TensorError>(if let Some(f) = e.get_felt_eval() {
let col_idx = if !is_dummy {
let table = config
.range_checks
.get(range)
.ok_or(TensorError::TableLookupError)?;
table.get_col_index(f)
} else {
F::ZERO
};
Value::known(col_idx).into()
} else {
Value::<F>::unknown().into()
})
})?
.into();
region.assign(&config.lookup_index, &table_index)?;
if !is_dummy {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config
.lookup_input
.cartesian_coord(region.linear_coord() + i);
let selector = config.range_check_selectors.get(&(*range, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
}
region.increment(assigned_len);
let elapsed = timer.elapsed();
trace!(
"range check {:?} layout took {:?}, row: {:?}",
range,
elapsed,
region.row()
);
Ok(w)
}
/// layout for nonlinearity check.
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2320,6 +2756,8 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
nl: &LookupOp,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_lookup(nl.clone(), values)?;
// time the entire operation
let timer = instant::Instant::now();
@@ -2401,22 +2839,6 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// mean function layout
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let x = &values[0];
let sum_x = sum(config, region, &[x.clone()])?;
let nl = LookupOp::Div {
denom: utils::F32((scale * x.len()) as f32),
};
nonlinearity(config, region, &[sum_x], &nl)
}
/// Argmax
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2529,24 +2951,8 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
)?;
// relu(x - max(x - 1))
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
let len = relu.dims().iter().product();
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
region.assign(&config.inputs[1], &relu)?;
if !region.is_dummy() {
(0..len)
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(len);
// constraining relu(x - max(x - 1)) = 0/1
boolean_identity(config, region, &[relu.clone()], false)?;
// sum(relu(x - max(x - 1)))
let sum_relu = sum(config, region, &[relu])?;
@@ -2557,13 +2963,7 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
// constraining 1 - sum(relu(x - max(x - 1))) = 0
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(relu_one_minus_sum_relu.len());
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
Ok(assigned_max_val)
}
@@ -2608,23 +3008,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
// relu(min(x + 1) - x)
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
let len = relu.dims().iter().product();
region.assign(&config.inputs[1], &relu)?;
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
if !region.is_dummy() {
(0..len)
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(len);
// constraining relu(min(x + 1) - x) = 0/1
boolean_identity(config, region, &[relu.clone()], false)?;
// sum(relu(min(x + 1) - x))
let sum_relu = sum(config, region, &[relu])?;
@@ -2635,14 +3020,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
let relu_one_minus_sum_relu =
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
// constraining product to 0
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(relu_one_minus_sum_relu.len());
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
Ok(assigned_min_val)
}
@@ -2783,15 +3162,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let inv_denom = nonlinearity(
config,
region,
&[denom],
// we set to input scale + output_scale so the output scale is output)scale
&LookupOp::Recip {
scale: scale.0.powf(2.0).into(),
},
)?;
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
@@ -2814,58 +3186,56 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
return enforce_equality(config, region, values);
}
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let mut values = [values[0].clone(), values[1].clone()];
let scale_squared = scale.0.powf(2.0);
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(
values[0] = region.assign(&config.inputs[0], &values[0])?;
values[1] = region.assign(&config.inputs[1], &values[1])?;
let total_assigned_0 = values[0].len();
let total_assigned_1 = values[1].len();
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
region.increment(total_assigned);
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// integer scale
let int_scale = scale.0 as i128;
// felt scale
let felt_scale = i128_to_felt(int_scale);
// range check len capped at 2^(S-3) and make it divisible 2
let range_check_bracket = std::cmp::min(
utils::F32(scale.0),
utils::F32(2_f32.powf((F::S - 5) as f32)),
)
.0;
let range_check_bracket_int = range_check_bracket as i128;
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
let input_scale_ratio = ((scale.0.powf(2.0) / range_check_bracket) * tol) as i128 / 2 * 2;
let recip = recip(
config,
region,
&[values[0].clone()],
&LookupOp::Recip {
scale: scale_squared.into(),
},
felt_scale,
felt_scale * F::from(100),
)?;
log::debug!("recip: {}", recip.show());
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
// Use the greater than look up table to check if the percent error is within the tolerance for upper bound
let tol = tol / 100.0;
let upper_bound = nonlinearity(
log::debug!("product: {}", product.show());
let rebased_product = loop_div(config, region, &[product], i128_to_felt(input_scale_ratio))?;
log::debug!("rebased_product: {}", rebased_product.show());
// check that it is within the tolerance range
range_check(
config,
region,
&[product.clone()],
&LookupOp::GreaterThan {
a: utils::F32(tol * scale_squared),
},
)?;
// Negate the product
let neg_product = neg(config, region, &[product])?;
// Use the greater than look up table to check if the percent error is within the tolerance for lower bound
let lower_bound = nonlinearity(
config,
region,
&[neg_product],
&LookupOp::GreaterThan {
a: utils::F32(tol * scale_squared),
},
)?;
// Add the lower_bound and upper_bound
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
// Assign the sum tensor to the inputs
region.assign(&config.inputs[1], &sum)?;
// Constrain the sum to be all zeros
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(sum.len());
Ok(sum)
&[rebased_product],
&(-range_check_bracket_int, range_check_bracket_int),
)
}

View File

@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::{
circuit::{layouts, utils},
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
graph::{multiplier_to_scale, scale_to_multiplier},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -17,47 +17,117 @@ use halo2curves::ff::PrimeField;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Abs,
Div { denom: utils::F32 },
Cast { scale: utils::F32 },
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
ReLU,
Max { scale: utils::F32, a: utils::F32 },
Min { scale: utils::F32, a: utils::F32 },
Ceil { scale: utils::F32 },
Floor { scale: utils::F32 },
Round { scale: utils::F32 },
RoundHalfToEven { scale: utils::F32 },
Sqrt { scale: utils::F32 },
Rsqrt { scale: utils::F32 },
Recip { scale: utils::F32 },
LeakyReLU { slope: utils::F32 },
Sigmoid { scale: utils::F32 },
Ln { scale: utils::F32 },
Exp { scale: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
ACosh { scale: utils::F32 },
Sin { scale: utils::F32 },
ASin { scale: utils::F32 },
Sinh { scale: utils::F32 },
ASinh { scale: utils::F32 },
Tan { scale: utils::F32 },
ATan { scale: utils::F32 },
Tanh { scale: utils::F32 },
ATanh { scale: utils::F32 },
Erf { scale: utils::F32 },
GreaterThan { a: utils::F32 },
LessThan { a: utils::F32 },
GreaterThanEqual { a: utils::F32 },
LessThanEqual { a: utils::F32 },
Max {
scale: utils::F32,
a: utils::F32,
},
Min {
scale: utils::F32,
a: utils::F32,
},
Ceil {
scale: utils::F32,
},
Floor {
scale: utils::F32,
},
Round {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
Rsqrt {
scale: utils::F32,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Ln {
scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Cos {
scale: utils::F32,
},
ACos {
scale: utils::F32,
},
Cosh {
scale: utils::F32,
},
ACosh {
scale: utils::F32,
},
Sin {
scale: utils::F32,
},
ASin {
scale: utils::F32,
},
Sinh {
scale: utils::F32,
},
ASinh {
scale: utils::F32,
},
Tan {
scale: utils::F32,
},
ATan {
scale: utils::F32,
},
Tanh {
scale: utils::F32,
},
ATanh {
scale: utils::F32,
},
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
Sign,
KroneckerDelta,
Pow { scale: utils::F32, a: utils::F32 },
Pow {
scale: utils::F32,
a: utils::F32,
},
}
impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> (i128, i128) {
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
(-range, range)
@@ -120,7 +190,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
&x,
f32::from(*scale).into(),
)),
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
LookupOp::LeakyReLU { slope: a } => {
@@ -150,10 +227,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output })
}
/// Returns the name of the operation
@@ -169,11 +243,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
LookupOp::LessThan { .. } => "LESS_THAN".into(),
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,
} => format!(
"RECIP(input_scale={}, output_scale={})",
input_scale, output_scale
),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
@@ -220,12 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { scale } => {
let mut out_scale = inputs_scale[0];
out_scale +=
multiplier_to_scale(scale.0 as f64 / scale_to_multiplier(out_scale).powf(2.0));
out_scale
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::Sign
| LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
@@ -237,10 +312,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
Ok(scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
vec![self.clone()]
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
}

View File

@@ -29,7 +29,6 @@ pub mod region;
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub(crate) output: Tensor<F>,
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
@@ -55,11 +54,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
vec![]
}
/// Returns the lookups required by the operation.
fn required_lookups(&self) -> Vec<LookupOp> {
vec![]
}
/// Returns true if the operation is an input.
fn is_input(&self) -> bool {
false
@@ -183,7 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
intermediate_lookups: vec![],
})
}
@@ -206,6 +199,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
config,
region,
values[..].try_into()?,
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
@@ -308,10 +302,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {

View File

@@ -1,5 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -9,6 +10,14 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -33,7 +42,9 @@ pub enum PolyOp {
Sub,
Neg,
Mult,
Identity,
Identity {
out_scale: Option<crate::Scale>,
},
Reshape(Vec<usize>),
MoveAxis {
source: usize,
@@ -79,13 +90,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
PolyOp::Resize { .. } => "RESIZE".into(),
PolyOp::Iff => "IFF".into(),
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
PolyOp::Identity => "IDENTITY".into(),
PolyOp::Identity { out_scale } => {
format!("IDENTITY (out_scale={:?})", out_scale)
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
@@ -135,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity => Ok(inputs[0].clone()),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
@@ -199,14 +214,36 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
}?;
Ok(ForwardResult {
output: res,
intermediate_lookups: vec![],
})
Ok(ForwardResult { output: res })
}
fn layout(
@@ -237,7 +274,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Neg => layouts::neg(config, region, values[..].try_into()?)?,
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
PolyOp::Einsum { equation } => layouts::einsum(config, region, &values, equation)?,
PolyOp::Einsum { equation } => layouts::einsum(config, region, values, equation)?,
PolyOp::Sum { axes } => {
layouts::sum_axes(config, region, values[..].try_into()?, axes)?
}
@@ -247,6 +284,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -264,7 +321,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -290,12 +347,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MultiBroadcastTo { .. } => in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Neg => in_scales[0],
PolyOp::MoveAxis { .. } => in_scales[0],
PolyOp::Downsample { .. } => in_scales[0],
PolyOp::Resize { .. } => in_scales[0],
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {
let mut scale = in_scales[0];
@@ -327,9 +379,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_scale
}
PolyOp::Add => {
let mut scale_a = 0;
let scale_b = in_scales[0];
scale_a += in_scales[1];
let scale_a = in_scales[0];
let scale_b = in_scales[1];
assert_eq!(scale_a, scale_b);
scale_a
}
@@ -339,26 +390,23 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
scale += in_scales[1];
scale
}
PolyOp::Identity => in_scales[0],
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pad(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Pack(_, _) => in_scales[0],
PolyOp::GlobalSumPool => in_scales[0],
PolyOp::Concat { axis: _ } => in_scales[0],
PolyOp::Slice { .. } => in_scales[0],
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(
self,
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
) {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
vec![0, 2]
} else {
vec![]
}

View File

@@ -1,4 +1,7 @@
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
@@ -7,9 +10,43 @@ use halo2curves::ff::PrimeField;
use std::{
cell::RefCell,
collections::HashSet,
sync::atomic::{AtomicUsize, Ordering},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
lookup_index: usize,
col_coord: usize,
}
impl DynamicLookupIndex {
/// Create a new dynamic lookup index
pub fn new(lookup_index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex {
lookup_index,
col_coord,
}
}
/// Get the lookup index
pub fn lookup_index(&self) -> usize {
self.lookup_index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -56,6 +93,13 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -64,6 +108,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants += n;
}
///
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
self.dynamic_lookup_index.lookup_index += n;
}
///
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
self.dynamic_lookup_index.col_coord += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
@@ -75,6 +134,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context from a wrapped region
@@ -82,6 +148,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
@@ -90,11 +157,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -104,6 +182,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -111,8 +196,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
pub fn new_dummy_with_constants(
row: usize,
linear_coord: usize,
constants: usize,
total_constants: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -120,7 +209,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: constants,
total_constants,
dynamic_lookup_index,
used_lookups,
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -160,6 +256,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Create a new region context per loop iteration
/// hacky but it works
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
@@ -170,6 +267,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -177,12 +279,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
DynamicLookupIndex::default(),
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -195,6 +303,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.lookup_index += local_reg.dynamic_lookup_index.lookup_index;
dynamic_lookup_index.col_coord += local_reg.dynamic_lookup_index.col_coord;
res
})
.map_err(|e| {
@@ -203,7 +323,71 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
})?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
})?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
}
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
Ok(())
}
@@ -212,15 +396,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.region.is_none()
}
/// duplicate_dummy
pub fn duplicate_dummy(&self) -> Self {
Self {
region: None,
linear_coord: self.linear_coord,
num_inner_cols: self.num_inner_cols,
row: self.row,
total_constants: self.total_constants,
}
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
self.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
self.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
/// Get the offset
@@ -238,6 +427,41 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants
}
/// Get the dynamic lookup index
pub fn dynamic_lookup_index(&self) -> usize {
self.dynamic_lookup_index.lookup_index
}
/// Get the dynamic lookup column coordinate
pub fn dynamic_lookup_col_coord(&self) -> usize {
self.dynamic_lookup_index.col_coord
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
}
/// get used range checks
pub fn used_range_checks(&self) -> HashSet<Range> {
self.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
self.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i128 {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
@@ -262,6 +486,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
/// Assign a valtensor to a vartensor
pub fn assign_dynamic_lookup(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.dynamic_lookup_col_coord(),
values,
)
} else {
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,

View File

@@ -19,6 +19,9 @@ use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The safety factor offset for the number of rows in the lookup table.
@@ -91,7 +94,7 @@ pub struct Table<F: PrimeField> {
/// Flags if table has been previously assigned to.
pub is_assigned: bool,
/// Number of bits used in lookup table.
pub range: (i128, i128),
pub range: Range,
_marker: PhantomData<F>,
}
@@ -127,21 +130,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
///
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
range: (i128, i128),
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: Option<Vec<TableColumn>>,
@@ -149,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = Self::num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
@@ -257,3 +258,139 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
Ok(())
}
}
/// Halo2 range check column
#[derive(Clone, Debug)]
pub struct RangeCheck<F: PrimeField> {
/// Input to table.
pub inputs: Vec<TableColumn>,
/// col size
pub col_size: usize,
/// selector cn
pub selector_constructor: SelectorConstructor<F>,
/// Flags if table has been previously assigned to.
pub is_assigned: bool,
/// Number of bits used in lookup table.
pub range: Range,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
}
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
i128_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
let inputs = {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
}
cols
};
let num_cols = inputs.len();
if num_cols > 1 {
warn!("Using {} columns for range-check.", num_cols);
}
RangeCheck {
inputs,
col_size,
is_assigned: false,
selector_constructor: SelectorConstructor::new(num_cols),
range,
_marker: PhantomData,
}
}
/// Take a linear coordinate and output the (column, row) position in the storage block.
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) {
let x = linear_coord / self.col_size;
let y = linear_coord % self.col_size;
(x, y)
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
}
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
let col_multipliers: Vec<F> = (0..chunked_inputs.len())
.map(|x| self.selector_constructor.get_selector_val_at_idx(x))
.collect();
let _ = chunked_inputs
.enumerate()
.map(|(chunk_idx, inputs)| {
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(mut row_offset, input)| {
let col_multiplier = col_multipliers[chunk_idx];
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.inputs[x],
y,
|| Value::known(*input * col_multiplier),
)?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
},
)
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
}
}

View File

@@ -1,4 +1,3 @@
use crate::circuit::ops::hybrid::HybridOp;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
@@ -90,7 +89,7 @@ mod matmul {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -165,7 +164,7 @@ mod matmul_col_overflow_double_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -239,14 +238,14 @@ mod matmul_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -327,7 +326,7 @@ mod matmul_col_ultra_overflow_double_col {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -349,8 +348,13 @@ mod matmul_col_ultra_overflow_double_col {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -361,7 +365,7 @@ mod matmul_col_ultra_overflow_double_col {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -441,7 +445,7 @@ mod matmul_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -463,8 +467,13 @@ mod matmul_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -543,7 +552,7 @@ mod dot {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -620,7 +629,7 @@ mod dot_col_overflow_triple_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -693,7 +702,7 @@ mod dot_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -762,7 +771,7 @@ mod sum {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -832,7 +841,7 @@ mod sum_col_overflow_double_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -901,7 +910,7 @@ mod sum_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -994,7 +1003,7 @@ mod composition {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1095,7 +1104,7 @@ mod conv {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
#[test]
@@ -1133,14 +1142,14 @@ mod conv {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -1240,7 +1249,7 @@ mod conv_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -1262,8 +1271,13 @@ mod conv_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -1275,7 +1289,7 @@ mod conv_col_ultra_overflow {
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use super::*;
@@ -1390,7 +1404,7 @@ mod conv_relu_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -1412,8 +1426,13 @@ mod conv_relu_col_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
@@ -1484,7 +1503,7 @@ mod add_w_shape_casting {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1551,7 +1570,143 @@ mod add {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
#[cfg(test)]
mod dynamic_lookup {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
tables: [[ValTensor<F>; 2]; NUM_LOOP],
lookups: [[ValTensor<F>; 2]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
for _ in 0..NUM_LOOP {
config
.configure_dynamic_lookup(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
}
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.dynamic_lookup_col_coord(),
NUM_LOOP * self.tables[0][0].len()
);
assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn dynamiclookupcircuit() {
// parameters
let tables = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.clone().try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..2).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from((0..2).map(|_| Value::known(F::from(10000))))),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
@@ -1618,7 +1773,7 @@ mod add_with_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1727,7 +1882,7 @@ mod add_with_overflow_and_poseidon {
let prover =
MockProver::run(K as u32, &circuit, vec![vec![commitment_a, commitment_b]]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
#[test]
@@ -1822,7 +1977,7 @@ mod sub {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1889,7 +2044,7 @@ mod mult {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1954,7 +2109,7 @@ mod pow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2023,7 +2178,7 @@ mod pack {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2116,145 +2271,7 @@ mod matmul_relu {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
}
}
#[cfg(test)]
mod rangecheckpercent {
use crate::circuit::Tolerance;
use crate::{circuit, tensor::Tensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const RANGE: f32 = 1.0; // 1 percent error tolerance
const K: usize = 18;
const LEN: usize = 1;
const SCALE: usize = i128::pow(2, 7) as usize;
use super::*;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
input: ValTensor<F>,
output: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let scale = utils::F32(SCALE.pow(2) as f32);
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// set up a new GreaterThan and Recip tables
let nl = &LookupOp::GreaterThan {
a: circuit::utils::F32((RANGE * scale.0) / 100.0),
};
config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
.unwrap();
config
.configure_lookup(
cs,
&b,
&output,
&a,
(-32768, 32768),
K,
&LookupOp::Recip { scale },
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&[self.output.clone(), self.input.clone()],
Box::new(HybridOp::RangeCheck(Tolerance {
val: RANGE,
scale: SCALE.into(),
})),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
#[allow(clippy::assertions_on_constants)]
fn test_range_check_percent() {
// Successful cases
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(101_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
}
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(199_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
}
// Unsuccessful case
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(102_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
match prover.verify() {
Ok(_) => {
assert!(false)
}
Err(_) => {
assert!(true)
}
}
}
prover.assert_satisfied();
}
}
@@ -2328,7 +2345,7 @@ mod relu {
};
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2339,7 +2356,7 @@ mod lookup_ultra_overflow {
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
poly::commitment::ParamsProver,
poly::commitment::{Params, ParamsProver},
};
#[derive(Clone)]
@@ -2421,7 +2438,7 @@ mod lookup_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ReLUCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -2443,120 +2460,16 @@ mod lookup_ultra_overflow {
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
mod softmax {
use super::*;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const K: usize = 18;
const LEN: usize = 3;
const SCALE: f32 = 128.0;
#[derive(Clone)]
struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for SoftmaxCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Exp {
scale: SCALE.into(),
},
)
.unwrap();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Recip {
scale: SCALE.powf(2.0).into(),
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let _output = config
.layout(
&mut region,
&[self.input.clone()],
Box::new(HybridOp::Softmax {
scale: SCALE.into(),
axes: vec![0],
}),
)
.unwrap();
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn softmax_circuit() {
let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = SoftmaxCircuit::<F> {
input: ValTensor::from(input),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
}
}

View File

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum};
use clap::{Parser, Subcommand};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
#[cfg(feature = "python-bindings")]
@@ -9,8 +9,9 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, RunArgs};
@@ -59,6 +60,8 @@ pub const DEFAULT_SOL_CODE_DA: &str = "evm_deploy_da.sol";
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
/// Default contract address for data attestation
pub const DEFAULT_CONTRACT_ADDRESS_DA: &str = "contract_da.address";
/// Default contract address for vk
pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
/// Default check mode
pub const DEFAULT_CHECKMODE: &str = "safe";
/// Default calibration target
@@ -73,15 +76,21 @@ pub const DEFAULT_FUZZ_RUNS: &str = "10";
pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
/// Default render vk seperately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl IntoPy<PyObject> for TranscriptType {
@@ -126,17 +135,27 @@ impl Default for CalibrationTarget {
}
}
impl ToString for CalibrationTarget {
fn to_string(&self) -> String {
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
impl std::fmt::Display for CalibrationTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
)
}
}
impl ToFlags for CalibrationTarget {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
@@ -157,6 +176,36 @@ impl From<&str> for CalibrationTarget {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
inner: H160,
}
#[cfg(not(target_arch = "wasm32"))]
impl From<H160Flag> for H160 {
fn from(val: H160Flag) -> H160 {
val.inner
}
}
#[cfg(not(target_arch = "wasm32"))]
impl ToFlags for H160Flag {
fn to_flags(&self) -> Vec<String> {
vec![format!("{:#x}", self.inner)]
}
}
#[cfg(not(target_arch = "wasm32"))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
inner: H160::from_str(s).unwrap(),
}
}
}
#[cfg(feature = "python-bindings")]
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
impl IntoPy<PyObject> for CalibrationTarget {
@@ -189,7 +238,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
@@ -230,7 +279,7 @@ impl Cli {
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
#[cfg(feature = "empty-cmd")]
/// Creates an empty buffer
@@ -313,9 +362,20 @@ pub enum Commands {
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
scales: Option<Vec<crate::Scale>>,
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
#[arg(
long,
value_delimiter = ',',
allow_hyphen_values = true,
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
)]
scale_rebase_multiplier: Vec<u32>,
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
},
/// Generates a dummy SRS
@@ -389,6 +449,9 @@ pub enum Commands {
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
/// Aggregates proofs :)
Aggregate {
@@ -408,7 +471,7 @@ pub enum Commands {
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -451,6 +514,9 @@ pub enum Commands {
/// The graph witness (optional - used to override fixed values in the circuit)
#[arg(short = 'W', long)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -460,24 +526,27 @@ pub enum Commands {
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
/// number of fuzz iterations
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEVMData {
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
@@ -505,7 +574,7 @@ pub enum Commands {
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
addr: H160,
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
@@ -555,9 +624,9 @@ pub enum Commands {
check_mode: CheckMode,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEVMVerifier {
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -573,11 +642,36 @@ pub enum Commands {
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
abi_path: PathBuf,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI)]
abi_path: PathBuf,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEVMDataAttestation {
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
@@ -597,9 +691,9 @@ pub enum Commands {
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for an aggregate proof
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEVMVerifierAggr {
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -618,6 +712,11 @@ pub enum Commands {
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
},
/// Verifies a proof, returning accept or reject
Verify {
@@ -633,6 +732,9 @@ pub enum Commands {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
@@ -669,6 +771,25 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
/// The path to output the contract address
addr_path: PathBuf,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
@@ -695,28 +816,23 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Verifies a proof using a local EVM executor, returning accept or reject
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEVM {
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
addr_verifier: H160,
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
/// does the verifier use data attestation ?
#[arg(long)]
addr_da: Option<H160>,
},
/// Print the proof in hexadecimal
#[command(name = "print-proof-hex")]
PrintProofHex {
/// The path to the proof file
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
addr_da: Option<H160Flag>,
// is the vk rendered seperately, if so specify an address
#[arg(long)]
addr_vk: Option<H160Flag>,
},
}

View File

@@ -101,17 +101,18 @@ pub async fn setup_eth_backend(
}
///
pub async fn deploy_verifier_via_solidity(
pub async fn deploy_contract_via_solidity(
sol_code_path: PathBuf,
rpc_url: Option<&str>,
runs: usize,
private_key: Option<&str>,
contract_name: &str,
) -> Result<ethers::types::Address, Box<dyn Error>> {
// anvil instance must be alive at least until the factory completes the deploy
let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?;
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "Halo2Verifier", runs)?;
get_contract_artifacts(sol_code_path, contract_name, runs)?;
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?;
let contract = factory.deploy(())?.send().await?;
@@ -335,11 +336,16 @@ pub async fn update_account_calls(
pub async fn verify_proof_via_solidity(
proof: Snark<Fr, G1Affine>,
addr: ethers::types::Address,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
let flattened_instances = proof.instances.into_iter().flatten();
let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
let encoded = encode_calldata(
addr_vk.as_ref().map(|x| x.0),
&proof.proof,
&flattened_instances.collect::<Vec<_>>(),
);
info!("encoded: {:#?}", hex::encode(&encoded));
let (anvil, client) = setup_eth_backend(rpc_url, None).await?;
@@ -439,6 +445,7 @@ pub async fn verify_proof_with_data_attestation(
proof: Snark<Fr, G1Affine>,
addr_verifier: ethers::types::Address,
addr_da: ethers::types::Address,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
use ethers::abi::{Function, Param, ParamType, StateMutability, Token};
@@ -452,8 +459,11 @@ pub async fn verify_proof_with_data_attestation(
public_inputs.push(u);
}
let encoded_verifier =
encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
let encoded_verifier = encode_calldata(
addr_vk.as_ref().map(|x| x.0),
&proof.proof,
&flattened_instances.collect::<Vec<_>>(),
);
info!("encoded: {:#?}", hex::encode(&encoded_verifier));

View File

@@ -3,7 +3,9 @@ use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::commands::Commands;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
use crate::commands::H160Flag;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
#[allow(unused_imports)]
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
@@ -21,8 +23,6 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -140,8 +140,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compiled_circuit,
transcript,
num_runs,
} => fuzz(compiled_circuit, witness, transcript, num_runs),
compress_selectors,
} => fuzz(
compiled_circuit,
witness,
transcript,
num_runs,
compress_selectors,
),
Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetSrs {
@@ -170,7 +176,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
only_range_check_rebase,
} => calibrate(
model,
data,
@@ -178,6 +186,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -192,28 +202,44 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMVerifier {
Commands::CreateEvmVerifier {
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
} => create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path),
render_vk_seperately,
} => create_evm_verifier(
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
),
Commands::CreateEvmVK {
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMDataAttestation {
Commands::CreateEvmDataAttestation {
settings_path,
sol_code_path,
abi_path,
data,
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMVerifierAggr {
Commands::CreateEvmVerifierAggr {
vk_path,
srs_path,
sol_code_path,
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
} => create_evm_aggregate_verifier(
vk_path,
srs_path,
@@ -221,6 +247,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
),
Commands::CompileCircuit {
model,
@@ -233,9 +260,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
pk_path,
witness,
} => setup(compiled_circuit, srs_path, vk_path, pk_path, witness),
compress_selectors,
} => setup(
compiled_circuit,
srs_path,
vk_path,
pk_path,
witness,
compress_selectors,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEVMData {
Commands::SetupTestEvmData {
data,
compiled_circuit,
test_data,
@@ -296,6 +331,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
} => setup_aggregate(
sample_snarks,
vk_path,
@@ -303,6 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
),
Commands::Aggregate {
proof_path,
@@ -329,7 +366,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path,
vk_path,
srs_path,
} => verify(proof_path, settings_path, vk_path, srs_path)
reduced_srs,
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
proof_path,
@@ -352,6 +390,25 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
addr_path,
optimizer_runs,
private_key,
"Halo2Verifier",
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::DeployEvmVK {
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
} => {
deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2VerifyingKey",
)
.await
}
@@ -377,13 +434,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::VerifyEVM {
Commands::VerifyEvm {
proof_path,
addr_verifier,
rpc_url,
addr_da,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
addr_vk,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
}
}
@@ -431,8 +488,14 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let hash = sha256::digest(&std::fs::read(path.clone())?);
let file = std::fs::File::open(path.clone())?;
let mut buffer = vec![];
let bytes_read = std::io::BufReader::new(file).read_to_end(&mut buffer)?;
debug!("read {} bytes from SRS file", bytes_read);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
@@ -440,7 +503,7 @@ fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
if hash != predefined_hash.to_string() {
if hash != *predefined_hash {
// delete file
warn!("removing SRS file at {}", path.display());
std::fs::remove_file(path)?;
@@ -555,7 +618,7 @@ pub(crate) async fn gen_witness(
let start_time = Instant::now();
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?;
let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?;
// print each variable tuple (symbol, value) as symbol=value
trace!(
@@ -571,8 +634,12 @@ pub(crate) async fn gen_witness(
);
if let Some(output_path) = output {
serde_json::to_writer(&File::create(output_path)?, &witness)?;
witness.save(output_path)?;
}
// print the witness in debug
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);
Ok(witness)
}
@@ -662,15 +729,14 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error =
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error.into_iter());
abs_errors.extend(abs_error.into_iter());
squared_errors.extend(squared_error.into_iter());
percentage_errors.extend(percentage_error.into_iter());
abs_percentage_errors.extend(abs_percentage_error.into_iter());
errors.extend(error);
abs_errors.extend(abs_error);
squared_errors.extend(squared_error);
percentage_errors.extend(percentage_error);
abs_percentage_errors.extend(abs_percentage_error);
}
let mean_percent_error =
@@ -679,29 +745,25 @@ impl AccuracyResults {
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
let median_error = errors[errors.len() / 2];
let max_error = errors
let max_error = *errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_error = errors
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_error = *errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
let median_abs_error = abs_errors[abs_errors.len() / 2];
let max_abs_error = abs_errors
let max_abs_error = *abs_errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_abs_error = abs_errors
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_abs_error = *abs_errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
@@ -724,6 +786,7 @@ impl AccuracyResults {
/// Calibrate the circuit parameters to a given a dataset
#[cfg(not(target_arch = "wasm32"))]
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn calibrate(
model_path: PathBuf,
data: PathBuf,
@@ -731,6 +794,8 @@ pub(crate) fn calibrate(
target: CalibrationTarget,
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -743,16 +808,7 @@ pub(crate) fn calibrate(
// we load the model to get the input and output shapes
// check if gag already exists
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
let model = Model::from_run_args(&settings.run_args, &model_path)?;
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
@@ -768,16 +824,17 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
match target {
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
}
(11..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if only_range_check_rebase {
vec![false]
} else {
vec![true, false]
};
let mut found_params: Vec<GraphSettings> = vec![];
let scale_rebase_multiplier = [1, 2, 10];
// 2 x 2 grid
let range_grid = range
.iter()
@@ -813,28 +870,23 @@ pub(crate) fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
let mut forward_pass_res = HashMap::new();
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
input_scale, param_scale, scale_rebase_multiplier
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
};
#[cfg(unix)]
let _q = match Gag::stderr() {
Ok(r) => Some(r),
Err(_) => None,
};
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
@@ -842,15 +894,19 @@ pub(crate) fn calibrate(
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
..settings.run_args.clone()
};
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(_) => return Err(format!("failed to create circuit from run args").into()),
Err(e) => {
debug!("circuit creation from run args failed: {:?}", e);
continue;
}
};
chunks
let forward_res = chunks
.iter()
.map(|chunk| {
let chunk = chunk.clone();
@@ -860,7 +916,7 @@ pub(crate) fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.forward(&mut data.clone(), None, None)
.forward(&mut data.clone(), None, None, true)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
@@ -871,41 +927,55 @@ pub(crate) fn calibrate(
Ok(()) as Result<(), String>
})
.collect::<Result<Vec<()>, String>>()?;
.collect::<Result<Vec<()>, String>>();
match forward_res {
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
continue;
}
}
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.1.iter().map(|x| x.min_lookup_inputs))
.flatten()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.1.iter().map(|x| x.max_lookup_inputs))
.flatten()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
let res = circuit.calibrate_from_min_max(
min_lookup_range,
max_lookup_range,
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let res = circuit.calc_min_logrows(
(min_lookup_range, max_lookup_range),
max_range_size,
max_logrows,
lookup_safety_margin,
);
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
if res.is_ok() {
let new_settings = circuit.settings().clone();
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
@@ -915,6 +985,7 @@ pub(crate) fn calibrate(
let found_settings = GraphSettings {
run_args: found_run_args,
required_lookups: new_settings.required_lookups,
required_range_checks: new_settings.required_range_checks,
model_output_scales: new_settings.model_output_scales,
model_input_scales: new_settings.model_input_scales,
num_rows: new_settings.num_rows,
@@ -930,7 +1001,7 @@ pub(crate) fn calibrate(
found_settings.as_json()?.to_colored_json_auto()?
);
} else {
debug!("calibration failed");
debug!("calibration failed {}", res.err().unwrap());
}
pb.inc(1);
@@ -1023,7 +1094,7 @@ pub(crate) fn calibrate(
let tear_sheet_table = Table::new(vec![accuracy_res]);
println!(
warn!(
"\n\n <------------- Numerical Fidelity Report (input_scale: {}, param_scale: {}, scale_input_multiplier: {}) ------------->\n\n{}\n\n",
best_params.run_args.input_scale,
best_params.run_args.param_scale,
@@ -1087,21 +1158,11 @@ pub(crate) fn mock(
)
.map_err(Box::<dyn Error>::from)?;
prover
.verify_par()
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
Ok(String::new())
}
pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<String, Box<dyn Error>> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
for instance in proof.instances {
println!("{:?}", instance);
}
let hex_str = hex::encode(proof.proof);
info!("0x{}", hex_str);
Ok(format!("0x{}", hex_str))
}
#[cfg(feature = "render")]
pub(crate) fn render(
model: PathBuf,
@@ -1132,6 +1193,7 @@ pub(crate) fn create_evm_verifier(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
@@ -1149,7 +1211,11 @@ pub(crate) fn create_evm_verifier(
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
);
let verifier_solidity = generator.render()?;
let verifier_solidity = if render_vk_seperately {
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
} else {
generator.render()?
};
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
@@ -1161,6 +1227,43 @@ pub(crate) fn create_evm_verifier(
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_vk(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
let num_instance = circuit_settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
&params,
&vk,
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
);
let vk_solidity = generator.render_separately()?.1;
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_data_attestation(
settings_path: PathBuf,
@@ -1256,13 +1359,15 @@ pub(crate) async fn deploy_evm(
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
contract_name: &str,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let contract_address = deploy_verifier_via_solidity(
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
)
.await?;
@@ -1276,9 +1381,10 @@ pub(crate) async fn deploy_evm(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160,
addr_verifier: H160Flag,
rpc_url: Option<String>,
addr_da: Option<H160>,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::verify_proof_with_data_attestation;
check_solc_requirement();
@@ -1288,13 +1394,20 @@ pub(crate) async fn verify_evm(
let result = if let Some(addr_da) = addr_da {
verify_proof_with_data_attestation(
proof.clone(),
addr_verifier,
addr_da,
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
} else {
verify_proof_via_solidity(proof.clone(), addr_verifier, rpc_url.as_deref()).await?
verify_proof_via_solidity(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
};
info!("Solidity verification result: {}", result);
@@ -1314,6 +1427,7 @@ pub(crate) fn create_evm_aggregate_verifier(
abi_path: PathBuf,
circuit_settings: Vec<PathBuf>,
logrows: u32,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let srs_path = get_srs_path(logrows, srs_path);
@@ -1352,7 +1466,11 @@ pub(crate) fn create_evm_aggregate_verifier(
generator = generator.set_acc_encoding(Some(acc_encoding));
let verifier_solidity = generator.render()?;
let verifier_solidity = if render_vk_seperately {
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
} else {
generator.render()?
};
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
@@ -1381,6 +1499,7 @@ pub(crate) fn setup(
vk_path: PathBuf,
pk_path: PathBuf,
witness: Option<PathBuf>,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1391,8 +1510,12 @@ pub(crate) fn setup(
let params = load_params_cmd(srs_path, circuit.settings().run_args.logrows)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(Box::<dyn Error>::from)?;
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_path, pk.get_vk())?;
save_pk::<KZGCommitmentScheme<Bn256>>(&pk_path, &pk)?;
@@ -1440,14 +1563,14 @@ pub(crate) async fn setup_test_evm_witness(
use crate::pfsys::ProofType;
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn test_update_account_calls(
addr: H160,
addr: H160Flag,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::update_account_calls;
check_solc_requirement();
update_account_calls(addr, data, rpc_url.as_deref()).await?;
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
Ok(String::new())
}
@@ -1531,6 +1654,7 @@ pub(crate) fn fuzz(
data_path: PathBuf,
transcript: TranscriptType,
num_runs: usize,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let passed = AtomicBool::new(true);
@@ -1546,8 +1670,12 @@ pub(crate) fn fuzz(
let data = GraphWitness::from_path(data_path)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(Box::<dyn Error>::from)?;
circuit.load_graph_witness(&data)?;
@@ -1563,9 +1691,12 @@ pub(crate) fn fuzz(
let fuzz_pk = || {
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
let bad_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
.map_err(|_| ())?;
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
)
.map_err(|_| ())?;
let bad_proof = create_proof_circuit_kzg(
circuit.clone(),
@@ -1584,6 +1715,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1614,6 +1746,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1636,9 +1769,12 @@ pub(crate) fn fuzz(
let fuzz_vk = || {
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
let bad_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
.map_err(|_| ())?;
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
)
.map_err(|_| ())?;
let bad_vk = bad_pk.get_vk();
@@ -1647,6 +1783,7 @@ pub(crate) fn fuzz(
proof.clone(),
bad_vk,
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1678,6 +1815,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1713,6 +1851,7 @@ pub(crate) fn fuzz(
bad_proof,
pk.get_vk(),
strategy.clone(),
params.n(),
)
.map_err(|_| ())
};
@@ -1798,7 +1937,7 @@ pub(crate) fn mock_aggregate(
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
.map_err(Box::<dyn Error>::from)?;
prover
.verify_par()
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("Done.");
@@ -1812,6 +1951,7 @@ pub(crate) fn setup_aggregate(
srs_path: Option<PathBuf>,
logrows: u32,
split_proofs: bool,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
// the K used for the aggregation circuit
let params = load_params_cmd(srs_path, logrows)?;
@@ -1822,8 +1962,11 @@ pub(crate) fn setup_aggregate(
}
let agg_circuit = AggregationCircuit::new(&params.get_g()[0].into(), snarks, split_proofs)?;
let agg_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(&agg_circuit, &params)?;
let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
&agg_circuit,
&params,
compress_selectors,
)?;
let agg_vk = agg_pk.get_vk();
@@ -1894,15 +2037,29 @@ pub(crate) fn verify(
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
let vk =
load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
let result = verify_proof_circuit_kzg(
params.verifier_params(),
proof,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
let elapsed = now.elapsed();
info!(
"verify took {}.{}",
@@ -1926,7 +2083,7 @@ pub(crate) fn verify_aggr(
let strategy = AccumulatorStrategy::new(params.verifier_params());
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(vk_path, ())?;
let now = Instant::now();
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy);
let result = verify_proof_circuit_kzg(&params, proof, &vk, strategy, 1 << logrows);
let elapsed = now.elapsed();
info!(

View File

@@ -4,6 +4,7 @@ use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
@@ -15,6 +16,8 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
@@ -490,16 +493,20 @@ impl GraphData {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to open input at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let reader = std::fs::File::open(path)?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, self)?;
Ok(())
}
///
@@ -617,13 +624,13 @@ impl ToPyObject for DataSource {
}
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_vecu64_montgomery;
use crate::pfsys::field_to_string;
#[cfg(feature = "python-bindings")]
impl ToPyObject for FileSourceInner {
fn to_object(&self, py: Python) -> PyObject {
match self {
FileSourceInner::Field(data) => field_to_vecu64_montgomery(data).to_object(py),
FileSourceInner::Field(data) => field_to_string(data).to_object(py),
FileSourceInner::Bool(data) => data.to_object(py),
FileSourceInner::Float(data) => data.to_object(py),
}

View File

@@ -16,6 +16,7 @@ use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
@@ -23,19 +24,22 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::RunArgs;
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use log::{debug, error, info, trace, warn};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
@@ -46,20 +50,36 @@ use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_vecu64_montgomery;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
/// Max circuit area
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
if let Ok(max_circuit_area) = std::env::var("EZKL_MAX_CIRCUIT_AREA") {
Some(max_circuit_area.parse().unwrap_or(0))
} else {
None
};
}
#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
/// circuit related errors.
#[derive(Debug, Error)]
@@ -117,15 +137,16 @@ pub enum GraphError {
MissingResults,
}
const ASSUMED_BLINDING_FACTORS: usize = 5;
///
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
pub const MIN_LOGROWS: u32 = 6;
/// 26
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
/// Lookup deg
pub const LOOKUP_DEG: usize = 5;
///
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
use std::cell::RefCell;
@@ -154,6 +175,8 @@ pub struct GraphWitness {
pub max_lookup_inputs: i128,
/// max lookup input
pub min_lookup_inputs: i128,
/// max range check size
pub max_range_size: i128,
}
impl GraphWitness {
@@ -181,6 +204,7 @@ impl GraphWitness {
processed_outputs: None,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
}
}
@@ -191,42 +215,41 @@ impl GraphWitness {
output_scales: Vec<crate::Scale>,
visibility: VarVisibility,
) {
let mut pretty_elements = PrettyElements::default();
pretty_elements.rescaled_inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.inputs = self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
pretty_elements.rescaled_outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.outputs = self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
let mut pretty_elements = PrettyElements {
rescaled_inputs: self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
inputs: self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
rescaled_outputs: self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
outputs: self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
..Default::default()
};
if let Some(processed_inputs) = self.processed_inputs.clone() {
pretty_elements.processed_inputs = processed_inputs
@@ -292,16 +315,20 @@ impl GraphWitness {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load model at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
///
@@ -332,22 +359,26 @@ impl ToPyObject for GraphWitness {
let dict_params = PyDict::new(py);
let dict_outputs = PyDict::new(py);
let inputs: Vec<Vec<[u64; 4]>> = self
let inputs: Vec<Vec<String>> = self
.inputs
.iter()
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
let outputs: Vec<Vec<[u64; 4]>> = self
let outputs: Vec<Vec<String>> = self
.outputs
.iter()
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
dict.set_item("inputs", inputs).unwrap();
dict.set_item("outputs", outputs).unwrap();
dict.set_item("max_lookup_inputs", self.max_lookup_inputs)
.unwrap();
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
.unwrap();
dict.set_item("max_range_size", self.max_range_size)
.unwrap();
if let Some(processed_inputs) = &self.processed_inputs {
//poseidon_hash
@@ -389,10 +420,7 @@ impl ToPyObject for GraphWitness {
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
let poseidon_hash: Vec<[u64; 4]> = poseidon_hash
.iter()
.map(field_to_vecu64_montgomery)
.collect();
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
Ok(())
@@ -421,6 +449,10 @@ pub struct GraphSettings {
pub total_assignments: usize,
/// total const size
pub total_const_size: usize,
/// total dynamic column size
pub total_dynamic_col_size: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -431,6 +463,8 @@ pub struct GraphSettings {
pub module_sizes: ModuleSizes,
/// required_lookups
pub required_lookups: Vec<LookupOp>,
/// required range_checks
pub required_range_checks: Vec<Range>,
/// check mode
pub check_mode: CheckMode,
/// ezkl version used
@@ -442,6 +476,24 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64).log2().ceil() as u32
}
fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64).log2().ceil() as u32
}
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self
@@ -454,22 +506,33 @@ impl GraphSettings {
instances
}
/// calculate the log2 of the total number of instances
pub fn log2_total_instances(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>();
// max between 1 and the log2 of the sums
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
let encoded = serde_json::to_string(&self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(encoded.as_bytes())
// buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// load params from file
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
let mut file = std::fs::File::open(path).map_err(|e| {
error!("failed to open settings file at {}", e);
e
})?;
let mut data = String::new();
file.read_to_string(&mut data)?;
let res = serde_json::from_str(&data)?;
Ok(res)
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// Export the ezkl configuration as json
@@ -515,6 +578,11 @@ impl GraphSettings {
|| self.run_args.param_visibility.is_hashed()
}
/// requires dynamic lookup
pub fn requires_dynamic_lookup(&self) -> bool {
self.num_dynamic_lookups > 0
}
/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
@@ -528,6 +596,7 @@ impl GraphSettings {
pub struct GraphConfig {
model_config: ModelConfig,
module_configs: ModuleConfigs,
circuit_size: CircuitSize,
}
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
@@ -564,7 +633,7 @@ impl GraphCircuit {
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::new(f);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
@@ -572,11 +641,10 @@ impl GraphCircuit {
///
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
// read bytes from file
let mut f = std::fs::File::open(&path)?;
let metadata = std::fs::metadata(&path)?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
let result = bincode::deserialize(&buffer)?;
let f = std::fs::File::open(path)?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
Ok(result)
}
}
@@ -591,6 +659,17 @@ pub enum TestDataSource {
OnChain,
}
impl std::fmt::Display for TestDataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestDataSource::File => write!(f, "file"),
TestDataSource::OnChain => write!(f, "on-chain"),
}
}
}
impl ToFlags for TestDataSource {}
impl From<String> for TestDataSource {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -639,7 +718,7 @@ impl GraphCircuit {
}
// dummy module settings, must load from GraphData after
let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?;
let mut settings = model.gen_params(run_args, run_args.check_mode)?;
let mut num_params = 0;
if !model.const_shapes().is_empty() {
@@ -763,18 +842,18 @@ impl GraphCircuit {
if self.settings().run_args.input_visibility.is_public() {
public_inputs.rescaled_inputs = elements.rescaled_inputs.clone();
public_inputs.inputs = elements.inputs.clone();
} else if let Some(_) = &data.processed_inputs {
} else if data.processed_inputs.is_some() {
public_inputs.processed_inputs = elements.processed_inputs.clone();
}
if let Some(_) = &data.processed_params {
if data.processed_params.is_some() {
public_inputs.processed_params = elements.processed_params.clone();
}
if self.settings().run_args.output_visibility.is_public() {
public_inputs.rescaled_outputs = elements.rescaled_outputs.clone();
public_inputs.outputs = elements.outputs.clone();
} else if let Some(_) = &data.processed_outputs {
} else if data.processed_outputs.is_some() {
public_inputs.processed_outputs = elements.processed_outputs.clone();
}
@@ -807,7 +886,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
info!("input scales: {:?}", scales);
debug!("input scales: {:?}", scales);
match &data.input_data {
DataSource::File(file_data) => {
@@ -826,7 +905,7 @@ impl GraphCircuit {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
info!("input scales: {:?}", scales);
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
@@ -952,38 +1031,47 @@ impl GraphCircuit {
Ok(data)
}
fn reserved_blinding_rows() -> f64 {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(
min_lookup_inputs: i128,
max_lookup_inputs: i128,
lookup_safety_margin: i128,
) -> (i128, i128) {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
lookup_safety_margin * min_lookup_inputs,
lookup_safety_margin * max_lookup_inputs,
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
if lookup_safety_margin == 1 {
margin.0 -= 1;
margin.1 += 1;
margin.0 += 4;
margin.1 += 4;
}
margin
}
fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
);
Table::<Fp>::num_cols_required(safe_range, max_col_size)
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
fn calc_min_logrows(
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick the range with the largest absolute size safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
max_range_size,
);
let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
.log2()
.ceil() as u32;
Ok(min_bits)
}
/// calculate the minimum logrows required for the circuit
pub fn calc_min_logrows(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
min_max_lookup: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -993,126 +1081,91 @@ impl GraphCircuit {
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
let mut min_logrows = MIN_LOGROWS;
let reserved_blinding_rows = Self::reserved_blinding_rows();
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if has overflowed max lookup input
if max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
{
let err_string = format!("max lookup input ({}) is too large", max_lookup_inputs);
error!("{}", err_string);
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
}
let safe_range = Self::calc_safe_lookup_range(
min_lookup_inputs,
max_lookup_inputs,
lookup_safety_margin,
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
}
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_logrows();
min_logrows = std::cmp::max(
min_logrows,
// max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
*[
instance_logrows,
module_constraint_logrows,
dynamic_lookup_logrows,
]
.iter()
.max()
.unwrap(),
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
*[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap(),
);
// we now have a min and max logrows
max_logrows = std::cmp::max(min_logrows, max_logrows);
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
min_logrows,
Self::calc_num_cols(safe_range, min_logrows),
)
{
min_logrows += 1;
}
if !self
.extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
{
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
error!("{}", err_string);
return Err(err_string.into());
}
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
max_logrows,
Self::calc_num_cols(safe_range, max_logrows),
)
&& !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
{
max_logrows -= 1;
}
if !self
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
{
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
error!("{}", err_string);
debug!("{}", err_string);
return Err(err_string.into());
}
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
.log2()
.ceil() as usize;
let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as usize;
let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);
// if public input then public inputs col will have public inputs len
if self.settings().run_args.input_visibility.is_public()
|| self.settings().run_args.output_visibility.is_public()
{
let mut max_instance_len = self
.model()
.instance_shapes()?
.iter()
.fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
as f64
+ reserved_blinding_rows;
// if there are modules then we need to add the max module size
if self.settings().uses_modules() {
max_instance_len += self
.settings()
.module_sizes
.num_instances()
.iter()
.sum::<usize>() as f64;
}
let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
logrows = std::cmp::max(logrows, instance_len_logrows);
// this is for fixed const columns
}
// ensure logrows is at least 4
logrows = std::cmp::max(logrows, min_logrows as usize);
logrows = std::cmp::min(logrows, max_logrows as usize);
let logrows = max_logrows;
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_range;
settings_mut.run_args.logrows = logrows as u32;
settings_mut.run_args.lookup_range = safe_lookup_range;
settings_mut.run_args.logrows = logrows;
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
.settings()
.clone();
// recalculate the total const size give nthe new logrows
let total_const_len = settings_mut.total_const_size;
let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
// recalculate the total number of constraints given the new logrows
let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);
// recalculate the logrows if there has been overflow on the constants
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.constants_logrows(),
);
// recalculate the logrows if there has been overflow for the model constraints
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.model_constraint_logrows(),
);
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
info!(
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
self.settings().run_args.lookup_range,
self.settings().run_args.logrows
@@ -1121,12 +1174,29 @@ impl GraphCircuit {
Ok(())
}
fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
let max_degree = self.settings().run_args.num_inner_cols + 2;
let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
fn extended_k_is_small_enough(
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
return false;
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
return false;
}
let max_degree = std::cmp::max(max_degree, max_lookup_degree);
let mut settings = self.settings().clone();
settings.run_args.lookup_range = safe_lookup_range;
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
Self::configure_with_params(&mut cs, settings);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
let max_degree = cs.degree();
let quotient_poly_degree = (max_degree - 1) as u64;
// n = 2^k
let n = 1u64 << k;
@@ -1134,36 +1204,20 @@ impl GraphCircuit {
while (1 << extended_k) < (n * quotient_poly_degree) {
extended_k += 1;
if !(extended_k <= bn256::Fr::S) {
if extended_k > bn256::Fr::S {
return false;
}
}
true
}
/// Calibrate the circuit to the supplied data.
pub fn calibrate_from_min_max(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_lookup_inputs,
max_lookup_inputs,
max_logrows,
lookup_safety_margin,
)?;
Ok(())
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
pub fn forward(
&self,
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
throw_range_check_error: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1204,7 +1258,9 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs)?;
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1247,6 +1303,7 @@ impl GraphCircuit {
processed_outputs,
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_size: model_results.max_range_size,
};
witness.generate_rescaled_elements(
@@ -1364,7 +1421,6 @@ impl GraphCircuit {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
num_instances: usize,
@@ -1372,20 +1428,22 @@ struct CircuitSize {
num_fixed: usize,
num_challenges: usize,
num_selectors: usize,
logrows: u32,
}
#[cfg(not(target_arch = "wasm32"))]
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>) -> Self {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
num_fixed: cs.num_fixed_columns(),
num_challenges: cs.num_challenges(),
num_selectors: cs.num_selectors(),
logrows,
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
@@ -1396,6 +1454,25 @@ impl CircuitSize {
};
Ok(serialized)
}
/// number of columns
pub fn num_columns(&self) -> usize {
self.num_instances + self.num_advice_columns + self.num_fixed
}
/// area of the circuit
pub fn area(&self) -> usize {
self.num_columns() * (1 << self.logrows)
}
/// area less than max
pub fn area_less_than_max(&self) -> bool {
if EZKL_MAX_CIRCUIT_AREA.is_some() {
self.area() < EZKL_MAX_CIRCUIT_AREA.unwrap()
} else {
true
}
}
}
impl Circuit<Fp> for GraphCircuit {
@@ -1433,33 +1510,18 @@ impl Circuit<Fp> for GraphCircuit {
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(
cs,
params.run_args.logrows as usize,
params.total_assignments,
params.run_args.num_inner_cols,
params.total_const_size,
params.module_requires_fixed(),
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
vars.instantiate_instance(
cs,
params.model_instance_shapes,
params.model_instance_shapes.clone(),
params.run_args.input_scale,
module_configs.instance,
);
let base = Model::configure(
cs,
&vars,
params.run_args.lookup_range,
params.run_args.logrows as usize,
params.required_lookups,
params.check_mode,
)
.unwrap();
let base = Model::configure(cs, &vars, &params).unwrap();
let model_config = ModelConfig { base, vars };
@@ -1469,10 +1531,12 @@ impl Circuit<Fp> for GraphCircuit {
(cs.degree() as f32).log2().ceil()
);
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"circuit size: \n {}",
CircuitSize::from_cs(cs)
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
@@ -1482,6 +1546,7 @@ impl Circuit<Fp> for GraphCircuit {
GraphConfig {
model_config,
module_configs,
circuit_size,
}
}
@@ -1494,6 +1559,16 @@ impl Circuit<Fp> for GraphCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fp>,
) -> Result<(), PlonkError> {
// check if the circuit area is less than the max
if !config.circuit_size.area_less_than_max() {
error!(
"circuit area {} is larger than the max allowed area {}",
config.circuit_size.area(),
EZKL_MAX_CIRCUIT_AREA.unwrap()
);
return Err(PlonkError::Synthesis);
}
trace!("Setting input in synthesize");
let input_vis = &self.settings().run_args.input_visibility;
let output_vis = &self.settings().run_args.output_visibility;

View File

@@ -6,10 +6,10 @@ use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -56,6 +56,8 @@ use unzip_n::unzip_n;
unzip_n!(pub 3);
#[cfg(not(target_arch = "wasm32"))]
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
/// The result of a forward pass.
#[derive(Clone, Debug)]
pub struct ForwardResult {
@@ -65,6 +67,19 @@ pub struct ForwardResult {
pub max_lookup_inputs: i128,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
/// The max range check size
pub max_range_size: i128,
}
impl From<DummyPassRes> for ForwardResult {
fn from(res: DummyPassRes) -> Self {
Self {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
max_range_size: res.max_range_size,
}
}
}
/// A circuit configuration for the entirety of a model loaded from an Onnx file.
@@ -79,6 +94,33 @@ pub struct ModelConfig {
/// Representation of execution graph
pub type NodeGraph = BTreeMap<usize, NodeType>;
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct DummyPassRes {
/// number of rows use
pub num_rows: usize,
/// num dynamic lookups
pub num_dynamic_lookups: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
pub total_const_size: usize,
/// lookup ops
pub lookup_ops: HashSet<LookupOp>,
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// min range check
pub max_range_size: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
@@ -233,13 +275,7 @@ impl NodeType {
NodeType::SubGraph { out_dims, .. } => out_dims.clone(),
}
}
/// Returns the lookups required by a graph
pub fn required_lookups(&self) -> Vec<LookupOp> {
match self {
NodeType::Node(n) => n.opkind.required_lookups(),
NodeType::SubGraph { model, .. } => model.required_lookups(),
}
}
/// Returns the scales of the node's output.
pub fn out_scales(&self) -> Vec<crate::Scale> {
match self {
@@ -424,14 +460,6 @@ impl ParsedNodes {
}
impl Model {
fn required_lookups(&self) -> Vec<LookupOp> {
self.graph
.nodes
.values()
.flat_map(|n| n.required_lookups())
.collect_vec()
}
/// Creates a `Model` from a specified path to an Onnx file.
/// # Arguments
/// * `reader` - A reader for an Onnx file.
@@ -476,7 +504,7 @@ impl Model {
) -> Result<GraphSettings, Box<dyn Error>> {
let instance_shapes = self.instance_shapes()?;
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {}",
"model has".blue(),
instance_shapes.len().to_string().blue(),
@@ -484,36 +512,41 @@ impl Model {
);
// this is the total number of variables we will need to allocate
// for the circuit
let (num_rows, linear_coord, total_const_size) =
self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
// extract the requisite lookup ops from the model
let mut lookup_ops: Vec<LookupOp> = self.required_lookups();
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
if run_args.tolerance.val > 0.0 {
for scale in self.graph.get_output_scales()? {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(scale).into();
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
lookup_ops.extend(opkind.required_lookups());
}
}
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
lookup_ops.extend(set.into_iter().sorted());
Ok(GraphSettings {
run_args: run_args.clone(),
model_instance_shapes: instance_shapes,
module_sizes: crate::graph::modules::ModuleSizes::default(),
num_rows,
total_assignments: linear_coord,
required_lookups: lookup_ops,
num_rows: res.num_rows,
total_assignments: res.linear_coord,
required_lookups: res.lookup_ops.into_iter().collect(),
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
total_const_size,
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
num_blinding_factors: None,
@@ -534,205 +567,18 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
let mut max_lookup_inputs = 0;
let mut min_lookup_inputs = 0;
let input_shapes = self.graph.input_shapes()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
let mut input = model_inputs[i].clone();
input.reshape(&input_shapes[i])?;
results.insert(input_idx, vec![input]);
}
for (idx, n) in self.graph.nodes.iter() {
let mut inputs = vec![];
if n.is_input() {
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
inputs.push(t);
} else {
for (idx, outlet) in n.inputs().iter() {
match results.get(&idx) {
Some(value) => inputs.push(value[*outlet].clone()),
None => return Err(Box::new(GraphError::MissingNode(*idx))),
}
}
};
debug!("executing {}: {}", idx, n.as_str());
debug!("dims: {:?}", n.out_dims());
debug!(
"input_dims: {:?}",
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);
if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
max = max.max(
i.iter()
.map(|x| felt_to_i128(*x))
.max()
.ok_or("missing max")?,
);
min = min.min(
i.iter()
.map(|x| felt_to_i128(*x))
.min()
.ok_or("missing min")?,
);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("max lookup inputs: {}", max);
debug!("min lookup inputs: {}", min);
}
match n {
NodeType::Node(n) => {
// execute the op
let start = instant::Instant::now();
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
res.output.reshape(&n.out_dims)?;
let elapsed = start.elapsed();
trace!("op took: {:?}", elapsed);
// see if any of the intermediate lookup calcs are the max
if !res.intermediate_lookups.is_empty() {
let (mut min, mut max) = (0, 0);
for i in &res.intermediate_lookups {
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("intermediate max lookup inputs: {}", max);
debug!("intermediate min lookup inputs: {}", min);
}
debug!(
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
idx,
res.output.map(crate::fieldutils::felt_to_i32).show(),
res.output
.map(|x| crate::fieldutils::felt_to_f64(x)
/ scale_to_multiplier(n.out_scale))
.show(),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
n.out_scale
);
results.insert(idx, vec![res.output]);
}
NodeType::SubGraph {
model,
output_mappings,
input_mappings,
inputs: input_tuple,
..
} => {
let orig_inputs = inputs.clone();
let input_mappings = input_mappings.clone();
let input_dims = inputs.iter().map(|inp| inp.dims());
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, input_tuple, model.graph.inputs
);
debug!("input_mappings: {:?}", input_mappings);
let mut full_results: Vec<Tensor<Fp>> = vec![];
for i in 0..num_iter {
// replace the Stacked input with the current chunk iter
for ((mapping, inp), og_input) in
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
{
if let InputMapping::Stacked { axis, chunk } = mapping {
let start = i * chunk;
let end = (i + 1) * chunk;
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
*inp = t;
}
}
let res = model.forward(&inputs)?;
// recursively get the max lookup inputs for subgraphs
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);
let mut outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
for mapping in mappings {
match mapping {
OutputMapping::Single { outlet, .. } => {
outlets.insert(outlet, outlet_res.clone());
}
OutputMapping::Stacked { outlet, axis, .. } => {
if !full_results.is_empty() {
let stacked_res = crate::tensor::ops::concat(
&[&full_results[*outlet], &outlet_res],
*axis,
)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
}
full_results = outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
inputs[*input_idx] = full_results[output_idx].clone();
}
}
trace!(
"------------ output subgraph node {}: {:?}",
idx,
full_results
.iter()
.map(|x|
// convert to tensor i32
x.map(crate::fieldutils::felt_to_i32).show())
.collect_vec()
);
results.insert(idx, full_results);
}
}
}
let output_nodes = self.graph.outputs.iter();
debug!(
"model outputs are nodes: {:?}",
output_nodes.clone().collect_vec()
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
let res = ForwardResult {
outputs,
max_lookup_inputs,
min_lookup_inputs,
};
Ok(res)
pub fn forward(
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
Ok(res.into())
}
/// Loads an Onnx model from a specified path.
@@ -744,7 +590,7 @@ impl Model {
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
) -> Result<TractResult, Box<dyn Error>> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
@@ -783,7 +629,7 @@ impl Model {
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
info!("set {} to {}", symbol, value);
debug!("set {} to {}", symbol, value);
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
@@ -1042,6 +888,8 @@ impl Model {
&run_args.param_visibility,
i,
symbol_values,
run_args.div_rebasing,
run_args.rebase_frac_zero_constants,
)?;
if let Some(ref scales) = override_input_scales {
if let Some(inp) = n.opkind.get_input() {
@@ -1058,9 +906,20 @@ impl Model {
if scales.contains_key(&i) {
let scale_diff = n.out_scale - scales[&i];
n.opkind = if scale_diff > 0 {
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
RebaseScale::rebase(
n.opkind,
scales[&i],
n.out_scale,
1,
run_args.div_rebasing,
)
} else {
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
RebaseScale::rebase_up(
n.opkind,
scales[&i],
n.out_scale,
run_args.div_rebasing,
)
};
n.out_scale = scales[&i];
}
@@ -1150,32 +1009,46 @@ impl Model {
/// # Arguments
/// * `meta` - The constraint system.
/// * `vars` - The variables for the circuit.
/// * `run_args` - [RunArgs]
/// * `required_lookups` - The required lookup operations for the circuit.
/// * `settings` - [GraphSettings]
pub fn configure(
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
lookup_range: (i128, i128),
logrows: usize,
required_lookups: Vec<LookupOp>,
check_mode: CheckMode,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
info!("configuring model");
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_dynamic_lookups = settings.num_dynamic_lookups;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
&vars.advices[2],
check_mode,
settings.check_mode,
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
let output = &vars.advices[1];
let index = &vars.advices[2];
let output = &vars.advices[2];
let index = &vars.advices[1];
for op in required_lookups {
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
}
for range in required_range_checks {
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
for _ in 0..num_dynamic_lookups {
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}
Ok(base_gate)
}
@@ -1216,6 +1089,7 @@ impl Model {
let instance_idx = vars.get_instance_idx();
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
@@ -1285,7 +1159,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
@@ -1327,18 +1201,29 @@ impl Model {
};
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants()
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1481,28 +1366,14 @@ impl Model {
pub fn dummy_layout(
&self,
run_args: &RunArgs,
input_shapes: &[Vec<usize>],
) -> Result<(usize, usize, usize), Box<dyn Error>> {
info!("calculating num of constraints using dummy model layout...");
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = input_shapes
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
results.insert(*input_idx, vec![inputs[i].clone()]);
@@ -1515,7 +1386,7 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
@@ -1526,27 +1397,26 @@ impl Model {
ValType::Constant(Fp::ONE)
};
let comparator = outputs
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|x| {
let mut v: ValTensor<Fp> =
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
v.reshape(x.dims())?;
Ok(v)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
let _ = outputs
.into_iter()
.zip(comparator)
.map(|(o, c)| {
dummy_config.layout(
&mut region,
&[o, c],
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<Vec<_>, _>>();
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
@@ -1558,7 +1428,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
@@ -1567,11 +1437,29 @@ impl Model {
region.total_constants().to_string().red()
);
Ok((
region.row(),
region.linear_coord(),
region.total_constants(),
))
let outputs = outputs
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
max_range_size: region.max_range_size(),
num_dynamic_lookups: region.dynamic_lookup_index(),
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
outputs,
};
Ok(res)
}
/// Retrieves all constants from the model.

View File

@@ -12,16 +12,12 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::trace;
use serde::Deserialize;
use serde::Serialize;
@@ -94,10 +90,6 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}
fn required_lookups(&self) -> Vec<LookupOp> {
self.inner.required_lookups()
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -126,12 +118,14 @@ impl Op<Fp> for Rescaled {
pub struct RebaseScale {
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// the multiplier applied to the node output
pub multiplier: f64,
/// rebase op
pub rebase_op: HybridOp,
/// scale being rebased to
pub target_scale: i32,
/// The original scale of the operation's inputs.
pub original_scale: i32,
/// multiplier
pub multiplier: f64,
}
impl RebaseScale {
@@ -141,6 +135,7 @@ impl RebaseScale {
global_scale: crate::Scale,
op_out_scale: crate::Scale,
scale_rebase_multiplier: u32,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
@@ -149,10 +144,15 @@ impl RebaseScale {
let multiplier =
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
})
} else {
@@ -160,6 +160,10 @@ impl RebaseScale {
inner: Box::new(inner),
target_scale: global_scale * scale_rebase_multiplier as i32,
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
})
}
@@ -173,15 +177,21 @@ impl RebaseScale {
inner: SupportedOp,
target_scale: crate::Scale,
op_out_scale: crate::Scale,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier,
original_scale: op.original_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
SupportedOp::RebaseScale(RebaseScale {
@@ -189,6 +199,10 @@ impl RebaseScale {
target_scale,
multiplier,
original_scale: op_out_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
} else {
@@ -203,19 +217,17 @@ impl Op<Fp> for RebaseScale {
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let ri = res.output.map(felt_to_i128);
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
res.output = rescaled.map(i128_to_felt);
res.intermediate_lookups.push(ri);
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
Ok(res)
}
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}) ({})",
"REBASED (div={:?}, rebasing_op={}) ({})",
self.multiplier,
<HybridOp as Op<Fp>>::as_string(&self.rebase_op),
self.inner.as_string()
)
}
@@ -224,14 +236,6 @@ impl Op<Fp> for RebaseScale {
Ok(self.target_scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
let mut lookups = self.inner.required_lookups();
lookups.push(LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
});
lookups
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -241,16 +245,8 @@ impl Op<Fp> for RebaseScale {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or("no layout")?;
Ok(Some(crate::circuit::layouts::nonlinearity(
config,
region,
&[original_res],
&LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
},
)?))
.ok_or("no inner layout")?;
self.rebase_op.layout(config, region, &[original_res])
}
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
@@ -433,10 +429,6 @@ impl Op<Fp> for SupportedOp {
self
}
fn required_lookups(&self) -> Vec<LookupOp> {
self.as_op().required_lookups()
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
self.as_op().out_scale(in_scales)
}
@@ -470,14 +462,7 @@ impl Tabled for Node {
fn headers() -> Vec<std::borrow::Cow<'static, str>> {
let mut headers = Vec::with_capacity(Self::LENGTH);
for i in [
"idx",
"opkind",
"out_scale",
"inputs",
"out_dims",
"required_lookups",
] {
for i in ["idx", "opkind", "out_scale", "inputs", "out_dims"] {
headers.push(std::borrow::Cow::Borrowed(i));
}
headers
@@ -490,14 +475,6 @@ impl Tabled for Node {
fields.push(std::borrow::Cow::Owned(self.out_scale.to_string()));
fields.push(std::borrow::Cow::Owned(display_vector(&self.inputs)));
fields.push(std::borrow::Cow::Owned(display_vector(&self.out_dims)));
fields.push(std::borrow::Cow::Owned(format!(
"{:?}",
self.opkind
.required_lookups()
.iter()
.map(<LookupOp as Op<Fp>>::as_string)
.collect_vec()
)));
fields
}
}
@@ -520,6 +497,7 @@ impl Node {
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(not(target_arch = "wasm32"))]
#[allow(clippy::too_many_arguments)]
pub fn new(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,
@@ -527,9 +505,9 @@ impl Node {
param_visibility: &Visibility,
idx: usize,
symbol_values: &SymbolValues,
div_rebasing: bool,
rebase_frac_zero_constants: bool,
) -> Result<Self, Box<dyn Error>> {
use log::warn;
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -567,6 +545,7 @@ impl Node {
node.clone(),
&mut inputs,
symbol_values,
rebase_frac_zero_constants,
)?; // parses the op name
// we can only take the inputs as mutable once -- so we need to collect them first
@@ -622,8 +601,6 @@ impl Node {
input_node.bump_scale(out_scale);
in_scales[input] = out_scale;
}
} else {
warn!("input {} not found for rescaling, skipping ...", input);
}
}
@@ -631,7 +608,13 @@ impl Node {
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
opkind = RebaseScale::rebase(
opkind,
global_scale,
out_scale,
scales.rebase_multiplier,
div_rebasing,
);
out_scale = opkind.out_scale(in_scales)?;

View File

@@ -71,8 +71,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
let int_rep = crate::fieldutils::felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier - shift;
float_rep
int_rep as f64 / multiplier - shift
}
/// Converts a scale (log base 2) to a fixed point multiplier.
@@ -244,6 +243,7 @@ pub fn new_op_from_onnx(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
inputs: &mut [super::NodeType],
symbol_values: &SymbolValues,
rebase_frac_zero_constants: bool,
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
use crate::circuit::InputType;
@@ -262,7 +262,9 @@ pub fn new_op_from_onnx(
inputs[index].bump_scale(scale);
c.rebase_scale(scale)?;
inputs[index].replace_opkind(SupportedOp::Constant(c.clone()));
Ok(SupportedOp::Linear(PolyOp::Identity))
Ok(SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(scale),
}))
} else {
Ok(default_op)
}
@@ -283,8 +285,8 @@ pub fn new_op_from_onnx(
"shift left".to_string(),
)));
}
SupportedOp::Nonlinear(LookupOp::Div {
denom: crate::circuit::utils::F32(1.0 / 2.0f32.powf(raw_values[0])),
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
@@ -305,8 +307,8 @@ pub fn new_op_from_onnx(
"shift right".to_string(),
)));
}
SupportedOp::Nonlinear(LookupOp::Div {
denom: crate::circuit::utils::F32(2.0f32.powf(raw_values[0])),
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
@@ -437,17 +439,16 @@ pub fn new_op_from_onnx(
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: None,
});
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -476,17 +477,16 @@ pub fn new_op_from_onnx(
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: None,
});
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// Raw values are always f32
let raw_value = extract_tensor_value(op.0)?;
// If bool or a tensor dimension then don't scale
let constant_scale = match dt {
let mut constant_scale = match dt {
DatumType::Bool
| DatumType::TDim
| DatumType::I64
@@ -560,6 +560,12 @@ pub fn new_op_from_onnx(
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
};
// if all raw_values are round then set scale to 0
let all_round = raw_value.iter().all(|x| (x).fract() == 0.0);
if all_round && rebase_frac_zero_constants {
constant_scale = 0;
}
// Quantize the raw value
let quantized_value =
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
@@ -666,8 +672,10 @@ pub fn new_op_from_onnx(
if unit == 0. {
SupportedOp::Nonlinear(LookupOp::ReLU)
} else {
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Max {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
}
@@ -708,8 +716,11 @@ pub fn new_op_from_onnx(
deleted_indices.push(const_idx);
}
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Min {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
} else {
@@ -717,17 +728,13 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
// Extract the slope layer hyperparams
let in_scale = inputs[0].out_scales()[0];
let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0];
let additional_scale = if scale_diff > 0 {
scale_to_multiplier(scale_diff)
} else {
1.0
};
SupportedOp::Nonlinear(LookupOp::Recip {
scale: (scale_to_multiplier(in_scale).powf(2.0) * additional_scale).into(),
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
use_range_check_for_int: true,
})
}
@@ -752,7 +759,9 @@ pub fn new_op_from_onnx(
"Scan" => {
return Err("scan should never be analyzed explicitly".into());
}
"QuantizeLinearU8" | "DequantizeLinearF32" => SupportedOp::Linear(PolyOp::Identity),
"QuantizeLinearU8" | "DequantizeLinearF32" => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
@@ -857,11 +866,11 @@ pub fn new_op_from_onnx(
}),
)?
} else {
SupportedOp::Linear(PolyOp::Identity)
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
}
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
SupportedOp::Linear(PolyOp::Identity)
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
}
@@ -886,12 +895,15 @@ pub fn new_op_from_onnx(
let const_idx = const_idx[0];
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
op = SupportedOp::Nonlinear(LookupOp::Div {
// we invert the constant for division
denom: crate::circuit::utils::F32(1. / c.raw_values[0]),
})
// if not divisible by 2 then we need to add a range check
let raw_values = 1.0 / c.raw_values[0];
if raw_values.log2().fract() == 0.0 {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
op = SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
});
}
}
}
}

View File

@@ -1,4 +1,5 @@
use std::error::Error;
use std::fmt::Display;
use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
@@ -14,6 +15,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use super::*;
@@ -40,6 +42,33 @@ pub enum Visibility {
Fixed,
}
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed {
hash_is_public,
outlets,
} => {
if *hash_is_public {
write!(f, "hashed/public")
} else {
write!(f, "hashed/private/{}", outlets.iter().join(","))
}
}
}
}
}
impl ToFlags for Visibility {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl<'a> From<&'a str> for Visibility {
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
@@ -202,17 +231,6 @@ impl Visibility {
vec![]
}
}
impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed { .. } => write!(f, "hashed"),
}
}
}
/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
@@ -237,6 +255,11 @@ impl VarScales {
std::cmp::max(self.input, self.params)
}
///
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
Ok(Self {
@@ -397,20 +420,31 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
}
/// Allocate all columns that will be assigned to by a model.
pub fn new(
cs: &mut ConstraintSystem<F>,
logrows: usize,
var_len: usize,
num_inner_cols: usize,
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
info!("number of blinding factors: {}", cs.blinding_factors());
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
let advices = (0..3)
let logrows = params.run_args.logrows as usize;
let var_len = params.total_assignments;
let num_inner_cols = params.run_args.num_inner_cols;
let num_constants = params.total_const_size;
let module_requires_fixed = params.module_requires_fixed();
let requires_dynamic_lookup = params.requires_dynamic_lookup();
let dynamic_lookup_size = params.total_dynamic_col_size;
let mut advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
.collect_vec();
if requires_dynamic_lookup {
for _ in 0..2 {
let dynamic_lookup = VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup should only have one block");
};
advices.push(dynamic_lookup);
}
}
debug!(
"model uses {} advice blocks (size={})",
advices.iter().map(|v| v.num_blocks()).sum::<usize>(),

View File

@@ -7,7 +7,6 @@
overflowing_literals,
path_statements,
patterns_in_fns_without_body,
private_in_public,
unconditional_recursion,
unused,
unused_allocation,
@@ -29,10 +28,11 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use circuit::Tolerance;
use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use graph::Visibility;
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
pub mod circuit;
@@ -71,11 +71,34 @@ pub mod tensor;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
/// The denominator in the fixed point representation used when quantizing inputs
pub type Scale = i32;
#[cfg(not(target_arch = "wasm32"))]
// Buf writer capacity
lazy_static! {
/// The capacity of the buffer used for writing to disk
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(target_arch = "wasm32")]
const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(target_arch = "wasm32")]
const EZKL_BUF_CAPACITY: &usize = &8000;
/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
@@ -90,8 +113,8 @@ pub struct RunArgs {
#[arg(long, default_value = "1")]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
pub lookup_range: (i128, i128),
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
pub logrows: u32,
@@ -99,7 +122,7 @@ pub struct RunArgs {
#[arg(short = 'N', long, default_value = "2")]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, hashed
#[arg(long, default_value = "private")]
@@ -110,6 +133,15 @@ pub struct RunArgs {
/// Flags whether params are public, private, hashed
#[arg(long, default_value = "private")]
pub param_visibility: Visibility,
#[arg(long, default_value = "false")]
/// Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
/// Should constants with 0.0 fraction be rebased to scale 0
#[arg(long, default_value = "false")]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
#[arg(long, default_value = "unsafe")]
pub check_mode: CheckMode,
}
impl Default for RunArgs {
@@ -126,6 +158,9 @@ impl Default for RunArgs {
input_visibility: Visibility::Private,
output_visibility: Visibility::Public,
param_visibility: Visibility::Private,
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
}
}
}
@@ -145,6 +180,9 @@ impl RunArgs {
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
Ok(())
}
@@ -169,34 +207,15 @@ fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T: std::str::FromStr + std::fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr,
U: std::str::FromStr + std::fmt::Debug,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
/// Parse a tuple
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr + Clone,
T::Err: std::error::Error + Send + Sync + 'static,
{
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');
let res = res
.map(|x| {
// remove blank space
let x = x.trim();
x.parse::<T>()
})
.collect::<Result<Vec<_>, _>>()?;
if res.len() != 2 {
return Err("invalid tuple".into());
}
Ok((res[0].clone(), res[1].clone()))
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
}

View File

@@ -8,10 +8,11 @@ use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation::PoseidonTranscript;
use crate::tensor::TensorType;
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey,
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG};
@@ -39,9 +40,19 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
use tosubcommand::ToFlags;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
_ => panic!("invalid serde format"),
}
}
#[allow(missing_docs)]
#[derive(
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
@@ -52,6 +63,25 @@ pub enum ProofType {
ForAggr,
}
impl std::fmt::Display for ProofType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ProofType::Single => "single",
ProofType::ForAggr => "for-aggr",
}
)
}
}
impl ToFlags for ProofType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<ProofType> for TranscriptType {
fn from(val: ProofType) -> Self {
match val {
@@ -154,6 +184,25 @@ pub enum TranscriptType {
EVM,
}
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
}
)
}
}
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {
fn to_object(&self, py: Python) -> PyObject {
@@ -167,8 +216,8 @@ impl ToPyObject for TranscriptType {
#[cfg(feature = "python-bindings")]
///
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
let g1affine_x = field_to_vecu64_montgomery(&g1affine.x);
let g1affine_y = field_to_vecu64_montgomery(&g1affine.y);
let g1affine_x = field_to_string(&g1affine.x);
let g1affine_y = field_to_string(&g1affine.y);
g1affine_dict.set_item("x", g1affine_x).unwrap();
g1affine_dict.set_item("y", g1affine_y).unwrap();
}
@@ -178,24 +227,24 @@ use halo2curves::bn256::G1;
#[cfg(feature = "python-bindings")]
///
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
let g1_x = field_to_vecu64_montgomery(&g1.x);
let g1_y = field_to_vecu64_montgomery(&g1.y);
let g1_z = field_to_vecu64_montgomery(&g1.z);
let g1_x = field_to_string(&g1.x);
let g1_y = field_to_string(&g1.y);
let g1_z = field_to_string(&g1.z);
g1_dict.set_item("x", g1_x).unwrap();
g1_dict.set_item("y", g1_y).unwrap();
g1_dict.set_item("z", g1_z).unwrap();
}
/// converts fp into `Vec<u64>` in Montgomery form
pub fn field_to_vecu64_montgomery<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> [u64; 4] {
/// converts fp into a little endian Hex string
pub fn field_to_string<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
let repr = serde_json::to_string(&fp).unwrap();
let b: [u64; 4] = serde_json::from_str(&repr).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
b
}
/// converts `Vec<u64>` in Montgomery form into fp
pub fn vecu64_to_field_montgomery<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
b: &[u64; 4],
/// converts a little endian Hex string into a field element
pub fn string_to_field<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
b: &String,
) -> F {
let repr = serde_json::to_string(&b).unwrap();
let fp: F = serde_json::from_str(&repr).unwrap();
@@ -256,10 +305,10 @@ where
{
fn to_object(&self, py: Python) -> PyObject {
let dict = PyDict::new(py);
let field_elems: Vec<Vec<[u64; 4]>> = self
let field_elems: Vec<Vec<String>> = self
.instances
.iter()
.map(|x| x.iter().map(|fp| field_to_vecu64_montgomery(fp)).collect())
.map(|x| x.iter().map(|fp| field_to_string(fp)).collect())
.collect::<Vec<_>>();
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
@@ -306,10 +355,16 @@ where
}
}
/// create hex proof from proof
pub fn create_hex_proof(&mut self) {
let hex_proof = hex::encode(&self.proof);
self.hex_proof = Some(format!("0x{}", hex_proof));
}
/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let file = std::fs::File::create(proof_path)?;
let mut writer = BufWriter::new(file);
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(&mut writer, &self)?;
Ok(())
}
@@ -322,8 +377,10 @@ where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let data = std::fs::read_to_string(proof_path)?;
serde_json::from_str(&data).map_err(|e| e.into())
let file = std::fs::File::open(proof_path)?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
Ok(proof)
}
}
@@ -427,6 +484,7 @@ where
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
@@ -438,7 +496,7 @@ where
// Initialize verifying key
let now = Instant::now();
trace!("preparing VK");
let vk = keygen_vk(params, &empty_circuit)?;
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
let elapsed = now.elapsed();
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());
@@ -548,6 +606,7 @@ where
verifier_params,
pk.get_vk(),
strategy,
verifier_params.n(),
)?;
}
let elapsed = now.elapsed();
@@ -594,6 +653,7 @@ where
let mut snark_new = snark.clone();
// swap the proof bytes for the new ones
snark_new.proof[..proof_first_bytes.len()].copy_from_slice(&proof_first_bytes);
snark_new.create_hex_proof();
Ok(snark_new)
}
@@ -634,6 +694,7 @@ pub fn verify_proof_circuit<
params: &'params Scheme::ParamsVerifier,
vk: &VerifyingKey<Scheme::Curve>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
Scheme::Scalar: SerdeObject
@@ -654,7 +715,7 @@ where
trace!("instances {:?}", instances);
let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
}
/// Loads a [VerifyingKey] at `path`.
@@ -670,13 +731,14 @@ where
info!("loading verification key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
let mut reader = BufReader::new(f);
VerifyingKey::<Scheme::Curve>::read::<_, C>(
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading verification key ✅");
Ok(vk)
}
/// Loads a [ProvingKey] at `path`.
@@ -692,19 +754,20 @@ where
info!("loading proving key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
let mut reader = BufReader::new(f);
ProvingKey::<Scheme::Curve>::read::<_, C>(
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading proving key ✅");
Ok(pk)
}
/// Saves a [ProvingKey] to `path`.
pub fn save_pk<Scheme: CommitmentScheme>(
path: &PathBuf,
vk: &ProvingKey<Scheme::Curve>,
pk: &ProvingKey<Scheme::Curve>,
) -> Result<(), io::Error>
where
Scheme::Curve: SerdeObject + CurveAffine,
@@ -712,9 +775,10 @@ where
{
info!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving proving key ✅");
Ok(())
}
@@ -729,9 +793,10 @@ where
{
info!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving verification key ✅");
Ok(())
}
@@ -742,7 +807,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
) -> Result<(), io::Error> {
info!("saving parameters 💾");
let f = File::create(path)?;
let mut writer = BufWriter::new(f);
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
params.write(&mut writer)?;
writer.flush()?;
Ok(())
@@ -832,6 +897,7 @@ pub(crate) fn verify_proof_circuit_kzg<
proof: Snark<Fr, G1Affine>,
vk: &VerifyingKey<G1Affine>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
@@ -841,7 +907,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, params, vk, strategy),
>(&proof, params, vk, strategy, orig_n),
TranscriptType::Poseidon => verify_proof_circuit::<
Fr,
VerifierSHPLONK<'_, Bn256>,
@@ -849,7 +915,7 @@ pub(crate) fn verify_proof_circuit_kzg<
_,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, params, vk, strategy),
>(&proof, params, vk, strategy, orig_n),
}
}

View File

@@ -15,10 +15,9 @@ use crate::graph::{
use crate::pfsys::evm::aggregation::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
Snark, TranscriptType,
TranscriptType,
};
use crate::RunArgs;
use ethers::types::H160;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -26,11 +25,10 @@ use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use snark_verifier::util::arithmetic::PrimeField;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
use tokio::runtime::Runtime;
type PyFelt = [u64; 4];
type PyFelt = String;
#[pyclass]
#[derive(Debug, Clone)]
@@ -65,9 +63,9 @@ struct PyG1 {
impl From<G1> for PyG1 {
fn from(g1: G1) -> Self {
PyG1 {
x: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.y),
z: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.z),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
z: crate::pfsys::field_to_string::<Fq>(&g1.z),
}
}
}
@@ -75,9 +73,9 @@ impl From<G1> for PyG1 {
impl From<PyG1> for G1 {
fn from(val: PyG1) -> Self {
G1 {
x: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.y),
z: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.z),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
z: crate::pfsys::string_to_field::<Fq>(&val.z),
}
}
}
@@ -108,8 +106,8 @@ pub struct PyG1Affine {
impl From<G1Affine> for PyG1Affine {
fn from(g1: G1Affine) -> Self {
PyG1Affine {
x: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.y),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
}
}
}
@@ -117,8 +115,8 @@ impl From<G1Affine> for PyG1Affine {
impl From<PyG1Affine> for G1Affine {
fn from(val: PyG1Affine) -> Self {
G1Affine {
x: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.y),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
}
}
}
@@ -146,7 +144,7 @@ struct PyRunArgs {
#[pyo3(get, set)]
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
pub lookup_range: (i128, i128),
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
pub logrows: u32,
#[pyo3(get, set)]
@@ -159,6 +157,12 @@ struct PyRunArgs {
pub param_visibility: Visibility,
#[pyo3(get, set)]
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
pub div_rebasing: bool,
#[pyo3(get, set)]
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
pub check_mode: CheckMode,
}
/// default instantiation of PyRunArgs
@@ -185,6 +189,9 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
}
}
}
@@ -203,57 +210,58 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
}
}
}
/// Converts 4 u64s to a field element
/// Converts a felt to big endian
#[pyfunction(signature = (
array,
felt,
))]
fn vecu64_to_felt(array: PyFelt) -> PyResult<String> {
Ok(format!(
"{:?}",
crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array)
))
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
}
/// Converts 4 u64s representing a field element directly to an integer
/// Converts a field element hex string to an integer
#[pyfunction(signature = (
array,
))]
fn vecu64_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array);
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts a field eleement hex string to a floating point number
#[pyfunction(signature = (
array,
scale
))]
fn vecu64_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array);
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point element to a field element hex string
#[pyfunction(signature = (
input,
scale
))]
fn float_to_vecu64(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = i128_to_felt(int_rep);
Ok(crate::pfsys::field_to_vecu64_montgomery::<Fr>(&felt))
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
/// Converts a buffer to vector of field elements
#[pyfunction(signature = (
buffer
))]
@@ -306,7 +314,10 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
let field_elements: Vec<String> = field_elements.iter().map(|x| format!("{:?}", x)).collect();
let field_elements: Vec<String> = field_elements
.iter()
.map(|x| crate::pfsys::field_to_string::<Fr>(x))
.collect();
Ok(field_elements)
}
@@ -318,7 +329,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::vecu64_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let output =
@@ -329,7 +340,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let hash = output[0]
.iter()
.map(crate::pfsys::field_to_vecu64_montgomery::<Fr>)
.map(crate::pfsys::field_to_string::<Fr>)
.collect::<Vec<_>>();
Ok(hash)
}
@@ -337,8 +348,8 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Generate a kzg commitment.
#[pyfunction(signature = (
message,
vk_path,
settings_path,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
fn kzg_commit(
@@ -349,7 +360,7 @@ fn kzg_commit(
) -> PyResult<Vec<PyG1Affine>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::vecu64_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let settings = GraphSettings::load(&settings_path)
@@ -387,9 +398,9 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
/// Generates a vk from a pk for a model circuit and saves it to a file
#[pyfunction(signature = (
path_to_pk,
circuit_settings_path,
vk_output_path
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
@@ -413,8 +424,8 @@ fn gen_vk_from_pk_single(
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
#[pyfunction(signature = (
path_to_pk,
vk_output_path
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(path_to_pk, ())
@@ -511,7 +522,9 @@ fn gen_settings(
target = CalibrationTarget::default(), // default is "resources
lookup_safety_margin = DEFAULT_LOOKUP_SAFETY_MARGIN.parse().unwrap(),
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
data: PathBuf,
@@ -520,7 +533,9 @@ fn calibrate_settings(
target: CalibrationTarget,
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -529,6 +544,8 @@ fn calibrate_settings(
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.map_err(|e| {
@@ -543,7 +560,7 @@ fn calibrate_settings(
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
output=None,
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
srs_path=None,
))]
@@ -604,7 +621,8 @@ fn mock_aggregate(
vk_path=PathBuf::from(DEFAULT_VK),
pk_path=PathBuf::from(DEFAULT_PK),
srs_path=None,
witness_path = None
witness_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
))]
fn setup(
model: PathBuf,
@@ -612,8 +630,17 @@ fn setup(
pk_path: PathBuf,
srs_path: Option<PathBuf>,
witness_path: Option<PathBuf>,
compress_selectors: bool,
) -> Result<bool, PyErr> {
crate::execute::setup(model, srs_path, vk_path, pk_path, witness_path).map_err(|e| {
crate::execute::setup(
model,
srs_path,
vk_path,
pk_path,
witness_path,
compress_selectors,
)
.map_err(|e| {
let err_str = format!("Failed to run setup: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -661,14 +688,23 @@ fn prove(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_path=PathBuf::from(DEFAULT_VK),
srs_path=None,
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
non_reduced_srs: bool,
) -> Result<bool, PyErr> {
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
crate::execute::verify(
proof_path,
settings_path,
vk_path,
srs_path,
non_reduced_srs,
)
.map_err(|e| {
let err_str = format!("Failed to run verify: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -682,7 +718,8 @@ fn verify(
pk_path=PathBuf::from(DEFAULT_PK_AGGREGATED),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
srs_path = None
srs_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
))]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
@@ -691,6 +728,7 @@ fn setup_aggregate(
logrows: u32,
split_proofs: bool,
srs_path: Option<PathBuf>,
compress_selectors: bool,
) -> Result<bool, PyErr> {
crate::execute::setup_aggregate(
sample_snarks,
@@ -699,6 +737,7 @@ fn setup_aggregate(
srs_path,
logrows,
split_proofs,
compress_selectors,
)
.map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
@@ -794,6 +833,7 @@ fn verify_aggr(
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
))]
fn create_evm_verifier(
vk_path: PathBuf,
@@ -801,12 +841,20 @@ fn create_evm_verifier(
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
) -> Result<bool, PyErr> {
crate::execute::create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
PyRuntimeError::new_err(err_str)
})?;
crate::execute::create_evm_verifier(
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
)
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
@@ -872,7 +920,7 @@ fn setup_test_evm_witness(
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
private_key=None,
))]
fn deploy_evm(
addr_path: PathBuf,
@@ -889,6 +937,39 @@ fn deploy_evm(
addr_path,
optimizer_runs,
private_key,
"Halo2Verifier",
))
.map_err(|e| {
let err_str = format!("Failed to run deploy_evm: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
fn deploy_vk_evm(
addr_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> Result<bool, PyErr> {
Runtime::new()
.unwrap()
.block_on(crate::execute::deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2VerifyingKey",
))
.map_err(|e| {
let err_str = format!("Failed to run deploy_evm: {}", e);
@@ -940,26 +1021,28 @@ fn deploy_da_evm(
proof_path=PathBuf::from(DEFAULT_PROOF),
rpc_url=None,
addr_da = None,
addr_vk = None,
))]
fn verify_evm(
addr_verifier: &str,
proof_path: PathBuf,
rpc_url: Option<String>,
addr_da: Option<&str>,
addr_vk: Option<&str>,
) -> Result<bool, PyErr> {
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_verifier = H160Flag::from(addr_verifier);
let addr_da = if let Some(addr_da) = addr_da {
let addr_da = H160::from_str(addr_da).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_da = H160Flag::from(addr_da);
Some(addr_da)
} else {
None
};
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
};
Runtime::new()
.unwrap()
@@ -968,6 +1051,7 @@ fn verify_evm(
addr_verifier,
rpc_url,
addr_da,
addr_vk,
))
.map_err(|e| {
let err_str = format!("Failed to run verify_evm: {}", e);
@@ -985,6 +1069,7 @@ fn verify_evm(
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
))]
fn create_evm_verifier_aggr(
aggregation_settings: Vec<PathBuf>,
@@ -993,6 +1078,7 @@ fn create_evm_verifier_aggr(
abi_path: PathBuf,
logrows: u32,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
) -> Result<bool, PyErr> {
crate::execute::create_evm_aggregate_verifier(
vk_path,
@@ -1001,6 +1087,7 @@ fn create_evm_verifier_aggr(
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
)
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
@@ -1009,32 +1096,21 @@ fn create_evm_verifier_aggr(
Ok(true)
}
/// print hex representation of a proof
#[pyfunction(signature = (proof_path))]
fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// NOTE: DeployVerifierEVM and SendProofEVM will be implemented in python in pyezkl
pyo3_log::init();
m.add_class::<PyRunArgs>()?;
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_function(wrap_pyfunction!(vecu64_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?;
m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?;
m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;
@@ -1055,9 +1131,9 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;

View File

@@ -30,11 +30,11 @@ use halo2_proofs::{
poly::Rotation,
};
use itertools::Itertools;
use std::cmp::max;
use std::error::Error;
use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
use thiserror::Error;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
let mut inner = self.inner.clone();
let remainder = self.len() % n;
if remainder != 0 {
inner.resize(self.len() + n - remainder, T::zero().unwrap());
inner.resize(self.len() + n - remainder, pad);
}
Tensor::new(Some(&inner), &[inner.len()])
}
@@ -1452,6 +1452,43 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
}
}
// implement remainder
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise remainder of a tensor with another tensor.
/// # Arguments
/// * `self` - Tensor
/// * `rhs` - Tensor
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use std::ops::Rem;
/// let x = Tensor::<i32>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.rem(y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn rem(self, rhs: Self) -> Self::Output {
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() % r;
});
Ok(lhs)
}
}
/// Returns the broadcasted shape of two tensors
/// ```
/// use ezkl::tensor::get_broadcasted_shape;

View File

@@ -243,7 +243,7 @@ pub fn and<
/// Some(&[1, 0, 1, 0, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = equals(&a, &b).unwrap().0;
/// let result = equals(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -260,7 +260,7 @@ pub fn equals<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let a = a.clone();
let b = b.clone();
@@ -268,7 +268,7 @@ pub fn equals<
let result = nonlinearities::kronecker_delta(&diff);
Ok((result, vec![diff]))
Ok(result)
}
/// Greater than operation.
@@ -289,7 +289,7 @@ pub fn equals<
/// ).unwrap();
/// let result = greater(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
pub fn greater<
T: TensorType
@@ -302,7 +302,7 @@ pub fn greater<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -311,7 +311,7 @@ pub fn greater<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok((mask, vec![mask_inter]))
Ok(mask)
}
/// Greater equals than operation.
@@ -332,7 +332,7 @@ pub fn greater<
/// ).unwrap();
/// let result = greater_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
pub fn greater_equal<
T: TensorType
@@ -345,7 +345,7 @@ pub fn greater_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -354,7 +354,7 @@ pub fn greater_equal<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok((mask, vec![mask_inter]))
Ok(mask)
}
/// Less than to operation.
@@ -375,7 +375,7 @@ pub fn greater_equal<
/// ).unwrap();
/// let result = less(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
///
pub fn less<
@@ -389,7 +389,7 @@ pub fn less<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
// a < b <=> b > a
greater(b, a)
}
@@ -412,7 +412,7 @@ pub fn less<
/// ).unwrap();
/// let result = less_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result.0, expected);
/// assert_eq!(result, expected);
/// ```
///
pub fn less_equal<
@@ -426,7 +426,7 @@ pub fn less_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
) -> Result<Tensor<T>, TensorError> {
// a < b <=> b > a
greater_equal(b, a)
}
@@ -950,8 +950,7 @@ pub fn neg<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sy
/// Elementwise multiplies multiple tensors.
/// # Arguments
///
/// * `a` - Tensor
/// * `b` - Tensor
/// * `t` - Tensors
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
@@ -2301,12 +2300,12 @@ pub fn deconv<
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
///
/// // This time with normalization
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -2316,7 +2315,7 @@ pub fn sumpool(
stride: (usize, usize),
kernel_shape: (usize, usize),
normalize: bool,
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
) -> Result<Tensor<i128>, TensorError> {
let image_dims = image.dims();
let batch_size = image_dims[0];
let image_channels = image_dims[1];
@@ -2346,15 +2345,12 @@ pub fn sumpool(
let mut combined = res.combine()?;
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;
let mut inter = vec![];
if normalize {
inter.push(combined.clone());
let norm = kernel.len();
combined = nonlinearities::const_div(&combined, norm as f64);
}
Ok((combined, inter))
Ok(combined)
}
/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
@@ -3049,11 +3045,7 @@ pub mod nonlinearities {
}
/// softmax layout
pub fn softmax_axes(
a: &Tensor<i128>,
scale: f64,
axes: &[usize],
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();
@@ -3061,8 +3053,6 @@ pub mod nonlinearities {
return softmax(a, scale);
}
let mut intermediate_values = vec![];
let cartesian_coord = dims[..dims.len() - 1]
.iter()
.map(|x| 0..*x)
@@ -3085,8 +3075,7 @@ pub mod nonlinearities {
let res = softmax(&softmax_input, scale);
outputs.push(res.0);
intermediate_values.extend(res.1);
outputs.push(res);
}
let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
@@ -3094,7 +3083,7 @@ pub mod nonlinearities {
.combine()
.unwrap();
res.reshape(dims).unwrap();
(res, intermediate_values)
res
}
/// Applies softmax
@@ -3111,24 +3100,20 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0).0;
/// let result = softmax(&x, 128.0);
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let mut intermediate_values = vec![];
intermediate_values.push(a.clone());
let exp = exp(a, scale);
let sum = sum(&exp).unwrap();
intermediate_values.push(sum.clone());
let inv_denom = recip(&sum, scale.powf(2.0));
let inv_denom = recip(&sum, scale, scale);
((exp * inv_denom).unwrap(), intermediate_values)
(exp * inv_denom).unwrap()
}
/// Applies range_check_percent
@@ -3163,7 +3148,7 @@ pub mod nonlinearities {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let scale = input_scale * output_scale;
let diff: Tensor<i128> = sub(t).unwrap();
let recip = recip(&t[0], scale as f64);
let recip = recip(&t[0], input_scale as f64, output_scale as f64);
let product = mult(&[diff, recip]).unwrap();
let _tol = ((tol / 100.0) * scale as f32).round() as f64;
let upper_bound = greater_than(&product, _tol);
@@ -3774,14 +3759,39 @@ pub mod nonlinearities {
/// &[2, 3],
/// ).unwrap();
/// let k = 2_f64;
/// let result = recip(&x, k);
/// let result = recip(&x, 1.0, k);
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn recip(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
pub fn recip(a: &Tensor<i128>, input_scale: f64, out_scale: f64) -> Tensor<i128> {
a.par_enum_map(|_, a_i| {
let denom = (1_f64) / (a_i as f64 + f64::EPSILON);
let d_inv_x = scale * denom;
let rescaled = (a_i as f64) / input_scale;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as i128)
})
.unwrap()
}
/// Elementwise inverse.
/// # Arguments
/// * `out_scale` - Single value
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as i128)
})
.unwrap()

View File

@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.pad_to_zero_rem(n)?;
*v = v.pad_to_zero_rem(n, pad)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
if indices.is_empty() {
return Ok(());
} else {
return Err(TensorError::WrongMethod);
}
}
}
Ok(())
@@ -871,3 +875,30 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
/// inverts the inner values
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
let mut cloned_self = self.clone();
match &mut cloned_self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.map(|x| match x {
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
ValType::AssignedValue(v.value_field().invert())
}
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
});
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(cloned_self)
}
}

View File

@@ -33,10 +33,7 @@ pub enum VarTensor {
impl VarTensor {
///
pub fn is_advice(&self) -> bool {
match self {
VarTensor::Advice { .. } => true,
_ => false,
}
matches!(self, VarTensor::Advice { .. })
}
///

View File

@@ -69,19 +69,19 @@ pub fn encodeVerifierCalldata(
Ok(encoded)
}
/// Converts 4 u64s to a field element
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vecU64ToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts 4 u64s representing a field element directly to an integer
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vecU64ToInt(
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
@@ -92,10 +92,10 @@ pub fn vecU64ToInt(
))
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts felts to a floating point element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vecU64ToFloat(
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
@@ -106,26 +106,26 @@ pub fn vecU64ToFloat(
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToVecU64(
pub fn floatToFelt(
input: f64,
scale: crate::Scale,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = i128_to_felt(int_rep);
let vec = crate::pfsys::field_to_vecu64_montgomery::<halo2curves::bn256::Fr>(&felt);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize vecu64_montgomery{}", e)),
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfVecU64(
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice
@@ -211,7 +211,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward(&mut input, None, None)
.forward(&mut input, None, None, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)
@@ -224,6 +224,7 @@ pub fn genWitness(
pub fn genVk(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
compress_selectors: bool,
) -> Result<Vec<u8>, JsError> {
// Read in kzg params
let mut reader = std::io::BufReader::new(&params_ser[..]);
@@ -235,9 +236,13 @@ pub fn genVk(
.map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?;
// Create verifying key
let vk = create_vk_wasm::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn std::error::Error>::from)
.map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?;
let vk = create_vk_wasm::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(Box::<dyn std::error::Error>::from)
.map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
@@ -306,13 +311,19 @@ pub fn verify(
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
circuit_settings.clone(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);
let result = verify_proof_circuit_kzg(
params.verifier_params(),
snark,
&vk,
strategy,
1 << circuit_settings.run_args.logrows,
);
match result {
Ok(_) => Ok(true),
@@ -382,15 +393,6 @@ pub fn prove(
.into_bytes())
}
/// print hex representation of a proof
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
// VALIDATION FUNCTIONS
/// Witness file validation
@@ -497,6 +499,7 @@ pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsErro
pub fn create_vk_wasm<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
@@ -506,7 +509,7 @@ where
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the verifying key
let vk = keygen_vk(params, &empty_circuit)?;
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
Ok(vk)
}
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,7 @@ def get_ezkl_output(witness_file, settings_file):
outputs = witness_output['outputs']
with open(settings_file) as f:
settings = json.load(f)
ezkl_outputs = [[ezkl.vecu64_to_float(
ezkl_outputs = [[ezkl.felt_to_float(
outputs[i][j], settings['model_output_scales'][i]) for j in range(len(outputs[i]))] for i in range(len(outputs))]
return ezkl_outputs
@@ -78,16 +78,20 @@ def compare_outputs(zk_output, onnx_output):
zip_object = zip(np.array(zk_output).flatten(),
np.array(onnx_output).flatten())
for list1_i, list2_i in zip_object:
for (i, (list1_i, list2_i)) in enumerate(zip_object):
if list1_i == 0.0 and list2_i == 0.0:
res.append(0)
else:
diff = list1_i - list2_i
res.append(100 * (diff) / (list2_i))
# iterate and print the diffs if they are greater than 0.0
if abs(diff) > 0.0:
print("------- index: ", i)
print("------- diff: ", diff)
print("------- zk_output: ", list1_i)
print("------- onnx_output: ", list2_i)
print("res: ", res)
return np.mean(np.abs(res))
return res
if __name__ == '__main__':
@@ -107,6 +111,9 @@ if __name__ == '__main__':
onnx_output = get_onnx_output(model_file, input_file)
# compare the outputs
percentage_difference = compare_outputs(ezkl_output, onnx_output)
mean_percentage_difference = np.mean(np.abs(percentage_difference))
max_percentage_difference = np.max(np.abs(percentage_difference))
# print the percentage difference
print("mean percent diff: ", percentage_difference)
assert percentage_difference < target, "Percentage difference is too high"
print("mean percent diff: ", mean_percentage_difference)
print("max percent diff: ", max_percentage_difference)
assert mean_percentage_difference < target, "Percentage difference is too high"

View File

@@ -118,38 +118,38 @@ mod py_tests {
}
const TESTS: [&str; 32] = [
"proof_splitting.ipynb", // 0
"variance.ipynb",
"proof_splitting.ipynb",
"mnist_gan_proof_splitting.ipynb",
"mnist_gan.ipynb",
// "mnist_vae.ipynb",
"keras_simple_demo.ipynb",
"hashed_vis.ipynb",
"mnist_gan_proof_splitting.ipynb", // 4
"hashed_vis.ipynb", // 5
"simple_demo_all_public.ipynb",
"data_attest.ipynb",
"little_transformer.ipynb",
"simple_demo_aggregated_proofs.ipynb",
"ezkl_demo.ipynb",
"ezkl_demo.ipynb", // 10
"lstm.ipynb",
"set_membership.ipynb",
"set_membership.ipynb", // 12
"decision_tree.ipynb",
"random_forest.ipynb",
"gradient_boosted_trees.ipynb",
"gradient_boosted_trees.ipynb", // 15
"xgboost.ipynb",
"lightgbm.ipynb",
"svm.ipynb",
"simple_demo_public_input_output.ipynb",
"simple_demo_public_network_output.ipynb",
"simple_demo_public_network_output.ipynb", // 20
"gcn.ipynb",
"linear_regression.ipynb",
"stacked_regression.ipynb",
"data_attest_hashed.ipynb",
"kzg_vis.ipynb",
"kzg_vis.ipynb", // 25
"kmeans.ipynb",
"solvency.ipynb",
"sklearn_mlp.ipynb",
"generalized_inverse.ipynb",
"mnist_classifier.ipynb",
"mnist_classifier.ipynb", // 30
"world_rotation.ipynb",
];

View File

@@ -56,9 +56,9 @@ def test_poseidon_hash():
Test for poseidon_hash
"""
message = [1.0, 2.0, 3.0, 4.0]
message = [ezkl.float_to_vecu64(x, 7) for x in message]
message = [ezkl.float_to_felt(x, 7) for x in message]
res = ezkl.poseidon_hash(message)
assert ezkl.vecu64_to_felt(
assert ezkl.felt_to_big_endian(
res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
@@ -70,14 +70,14 @@ def test_field_serialization():
input = 890
scale = 7
felt = ezkl.float_to_vecu64(input, scale)
roundtrip_input = ezkl.vecu64_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
input = -700
scale = 7
felt = ezkl.float_to_vecu64(input, scale)
roundtrip_input = ezkl.vecu64_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
@@ -88,12 +88,12 @@ def test_buffer_to_felts():
buffer = bytearray("a sample string!", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061"
assert felts == [ref_felt_1]
assert ezkl.felt_to_big_endian(felts[0]) == ref_felt_1
buffer = bytearray("a sample string!"+"high", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968"
assert felts == [ref_felt_1, ref_felt_2]
assert [ezkl.felt_to_big_endian(felts[0]), ezkl.felt_to_big_endian(felts[1])] == [ref_felt_1, ref_felt_2]
def test_gen_srs():
@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
@@ -147,13 +147,13 @@ def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
output_path = os.path.join(
@@ -183,7 +183,7 @@ def test_model_compile():
model_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'network.onnx'
)
compiled_model_path = os.path.join(
@@ -205,7 +205,7 @@ def test_forward():
data_path = os.path.join(
examples_path,
'onnx',
'1l_average',
'1l_relu',
'input.json'
)
model_path = os.path.join(
@@ -392,9 +392,7 @@ def test_prove_evm():
assert res['transcript_type'] == 'EVM'
assert os.path.isfile(proof_path)
res = ezkl.print_proof_hex(proof_path)
# to figure out a better way of testing print_proof_hex
assert type(res) == str
def test_create_evm_verifier():

View File

@@ -8,10 +8,10 @@ mod wasm32 {
use ezkl::graph::GraphWitness;
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfVecU64, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
prove, settingsValidation, srsValidation, u8_array_to_u128_le, vecU64ToFelt, vecU64ToFloat,
vecU64ToInt, verify, vkValidation, witnessValidation,
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, genPk, genVk, genWitness, inputValidation, pkValidation,
poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -26,7 +26,7 @@ mod wasm32 {
pub const NETWORK_COMPILED: &[u8] = include_bytes!("../tests/wasm/model.compiled");
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/test.proof");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
@@ -76,21 +76,21 @@ mod wasm32 {
for i in 0..32 {
let field_element = Fr::from(i);
let serialized = serde_json::to_vec(&field_element).unwrap();
let clamped = wasm_bindgen::Clamped(serialized);
let scale = 2;
let floating_point = vecU64ToFloat(clamped.clone(), scale)
let floating_point = feltToFloat(clamped.clone(), scale)
.map_err(|_| "failed")
.unwrap();
assert_eq!(floating_point, (i as f64) / 4.0);
let integer: i128 = serde_json::from_slice(
&vecU64ToInt(clamped.clone()).map_err(|_| "failed").unwrap(),
)
.unwrap();
let integer: i128 =
serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap())
.unwrap();
assert_eq!(integer, i as i128);
let hex_string = format!("{:?}", field_element);
let returned_string = vecU64ToFelt(clamped).map_err(|_| "failed").unwrap();
let returned_string: String = feltToBigEndian(clamped).map_err(|_| "failed").unwrap();
assert_eq!(hex_string, returned_string);
}
}
@@ -101,7 +101,7 @@ mod wasm32 {
let mut buffer = string_high.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -118,7 +118,7 @@ mod wasm32 {
let buffer = string_sample.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -133,7 +133,7 @@ mod wasm32 {
let buffer = string_concat.into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -186,6 +186,7 @@ mod wasm32 {
let vk = genVk(
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
wasm_bindgen::Clamped(SRS.to_vec()),
true,
)
.map_err(|_| "failed")
.unwrap();
@@ -206,6 +207,7 @@ mod wasm32 {
let vk = genVk(
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
wasm_bindgen::Clamped(SRS.to_vec()),
true,
)
.map_err(|_| "failed")
.unwrap();
@@ -218,6 +220,7 @@ mod wasm32 {
let vk = genVk(
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
wasm_bindgen::Clamped(SRS.to_vec()),
true,
)
.map_err(|_| "failed")
.unwrap();
@@ -255,15 +258,6 @@ mod wasm32 {
assert!(value);
}
#[wasm_bindgen_test]
async fn print_proof_hex_test() {
let proof = printProofHex(wasm_bindgen::Clamped(PROOF.to_vec()))
.map_err(|_| "failed")
.unwrap();
assert!(proof.len() > 0);
}
#[wasm_bindgen_test]
async fn verify_validations() {
// Run witness validation on network (should fail)

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -1 +1,63 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":0,"scale_rebase_multiplier":10,"lookup_range":[-2,0],"logrows":6,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_rows":16,"total_assignments":32,"total_const_size":8,"model_instance_shapes":[[1,4]],"model_output_scales":[0],"model_input_scales":[0],"module_sizes":{"kzg":[],"poseidon":[0,[0]]},"required_lookups":["ReLU"],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1702474230544}
{
"run_args": {
"tolerance": {
"val": 0.0,
"scale": 1.0
},
"input_scale": 0,
"param_scale": 0,
"scale_rebase_multiplier": 10,
"lookup_range": [
-2,
0
],
"logrows": 6,
"num_inner_cols": 2,
"variables": [
[
"batch_size",
1
]
],
"input_visibility": "Private",
"output_visibility": "Public",
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [
[
1,
4
]
],
"model_output_scales": [
0
],
"model_input_scales": [
0
],
"module_sizes": {
"kzg": [],
"poseidon": [
0,
[
0
]
]
},
"required_lookups": [
"ReLU"
],
"required_range_checks": [],
"check_mode": "UNSAFE",
"version": "0.0.0",
"num_blinding_factors": null,
"timestamp": 1702474230544
}

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,6 @@
import * as fs from 'fs/promises';
import * as fsSync from 'fs'
import JSONBig from 'json-bigint';
import { vecU64ToFelt } from '@ezkljs/engine/nodejs'
const solc = require('solc');
// import os module

Binary file not shown.

View File

@@ -1 +1 @@
{"inputs":[[[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574],[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287],[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287]]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[[[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}