mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c0c17c9be | ||
|
|
bf69b16fc1 | ||
|
|
74feb829da | ||
|
|
d429e7edab | ||
|
|
f0e5b82787 | ||
|
|
3f7261f50b | ||
|
|
678a249dcb | ||
|
|
0291eb2d0f | ||
|
|
1b637a70b0 | ||
|
|
abcd5380db | ||
|
|
076b737108 | ||
|
|
97d9832591 | ||
|
|
e0771683a6 | ||
|
|
319c222307 | ||
|
|
85ee6e7f9d | ||
|
|
4c8daf773c | ||
|
|
80041ac523 | ||
|
|
2a1ee1102c | ||
|
|
95d4fd4a70 | ||
|
|
e0d3f4f145 | ||
|
|
bceac2fab5 | ||
|
|
04d7b5feaa | ||
|
|
45fd12a04f | ||
|
|
bc7c33190f | ||
|
|
df72e01414 | ||
|
|
172e26c00d | ||
|
|
11ac120f23 | ||
|
|
0fdd92e9f3 | ||
|
|
31f58056a5 | ||
|
|
ddbcc1d2d8 | ||
|
|
feccc5feed | ||
|
|
db24577c5d | ||
|
|
bb482e3cac |
8
.github/workflows/large-tests.yml
vendored
8
.github/workflows/large-tests.yml
vendored
@@ -6,12 +6,12 @@ on:
|
||||
description: "Test scenario tags"
|
||||
jobs:
|
||||
large-tests:
|
||||
runs-on: self-hosted
|
||||
runs-on: kaiju
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: nanoGPT Mock
|
||||
@@ -23,6 +23,6 @@ jobs:
|
||||
- name: Self Attention KZG prove and verify large tests
|
||||
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_0_expects -- --include-ignored
|
||||
- name: mobilenet Mock
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_3_expects -- --include-ignored
|
||||
- name: mobilenet KZG prove and verify large tests
|
||||
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_2_expects -- --include-ignored
|
||||
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_3_expects -- --include-ignored
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -45,7 +45,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
|
||||
125
.github/workflows/rust.yml
vendored
125
.github/workflows/rust.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Docs
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -106,7 +106,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -198,10 +198,8 @@ jobs:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Install wasm runner
|
||||
run: cargo install wasm-server-runner
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
- name: Run wasm verifier tests
|
||||
# on mac:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
@@ -214,7 +212,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -231,13 +229,15 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
|
||||
- name: kzg inputs
|
||||
@@ -286,7 +286,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -312,7 +312,9 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg inputs)
|
||||
@@ -343,18 +345,15 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
@@ -414,32 +413,32 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 2
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
fuzz-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
@@ -448,7 +447,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -458,20 +457,20 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: fuzz tests (EVM)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
|
||||
# - name: fuzz tests
|
||||
# run: cargo nextest run --release --verbose tests::kzg_fuzz_ --test-threads 6
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
runs-on: self-hosted
|
||||
needs: [build, library-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -489,7 +488,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -506,7 +505,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -523,7 +522,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -533,7 +532,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
@@ -544,7 +543,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -566,7 +565,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install solc
|
||||
@@ -574,7 +573,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
@@ -590,7 +589,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -601,6 +600,8 @@ jobs:
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; maturin develop --features python-bindings --release
|
||||
- name: Div rebase
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
@@ -611,27 +612,7 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
|
||||
python-integration-tests:
|
||||
runs-on:
|
||||
large-self-hosted
|
||||
# Service containers to run with `container-job`
|
||||
services:
|
||||
# Label used to access the service container
|
||||
postgres:
|
||||
# Docker Hub image
|
||||
image: postgres
|
||||
env:
|
||||
POSTGRES_USER: ubuntu
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
# Set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
# Maps tcp port 5432 on service container to the host
|
||||
- 5432:5432
|
||||
# needs: [build, library-tests, docs]
|
||||
runs-on: large-self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
@@ -639,7 +620,7 @@ jobs:
|
||||
python-version: "3.9"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -649,7 +630,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
@@ -663,13 +644,13 @@ jobs:
|
||||
# # now dump the contents of the file into a file called kaggle.json
|
||||
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
|
||||
# - name: Postgres tutorials
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
|
||||
# - name: Postgres tutorials
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
|
||||
|
||||
7
.github/workflows/wasm.yml
vendored
7
.github/workflows/wasm.yml
vendored
@@ -22,18 +22,15 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
run: |
|
||||
set -e
|
||||
|
||||
687
Cargo.lock
generated
687
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
20
Cargo.toml
20
Cargo.toml
@@ -15,9 +15,9 @@ crate-type = ["cdylib", "rlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
|
||||
halo2curves = { version = "0.1.0" }
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.3.3", features = ["derive"]}
|
||||
@@ -28,16 +28,19 @@ thiserror = { version = "1.0.38", default_features = false }
|
||||
hex = { version = "0.4.3", default_features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "ac/lookup-modularity" }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
ark-std = { version = "^0.3.0", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
|
||||
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
|
||||
indicatif = {version = "0.17.5", features = ["rayon"]}
|
||||
gag = { version = "1.0.0", default_features = false}
|
||||
instant = { version = "0.1" }
|
||||
@@ -51,9 +54,9 @@ plotters = { version = "0.3.0", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
|
||||
tokio-util = { version = "0.7.9", features = ["codec"] }
|
||||
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
|
||||
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
|
||||
@@ -158,6 +161,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidi
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
|
||||
@@ -64,8 +64,8 @@ More notebook tutorials can be found within `examples/notebooks`.
|
||||
|
||||
#### CLI
|
||||
Install the CLI
|
||||
```bash
|
||||
curl https://hub.ezkl.xyz/install_ezkl_cli.sh | bash
|
||||
``` shell
|
||||
curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh | bash
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-902e-20738ce0e9f0.mp4
|
||||
|
||||
@@ -121,13 +121,16 @@ fn runcnvrl(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
|
||||
&circuit, ¶ms, true,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
|
||||
@@ -90,13 +90,13 @@ fn rundot(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -94,13 +94,13 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
@@ -111,13 +112,13 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -3,6 +3,7 @@ use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
@@ -16,7 +17,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: (i128, i128) = (-8180, 8180);
|
||||
const BITS: Range = (-8180, 8180);
|
||||
static mut LEN: usize = 4;
|
||||
static mut K: usize = 16;
|
||||
|
||||
@@ -114,13 +115,13 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(k as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(k as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
|
||||
|
||||
@@ -86,13 +86,13 @@ fn runsum(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -101,13 +101,16 @@ fn runsumpool(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
|
||||
&circuit, ¶ms, true,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
|
||||
@@ -84,13 +84,13 @@ fn runadd(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -83,13 +83,13 @@ fn runpow(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -76,13 +76,13 @@ fn runposeidon(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
@@ -14,7 +15,7 @@ use halo2_proofs::{
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::Rng;
|
||||
|
||||
const BITS: (i128, i128) = (-32768, 32768);
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
@@ -90,13 +91,13 @@ fn runrelu(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
|
||||
@@ -6,6 +6,7 @@ use ezkl::fieldutils;
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
@@ -489,6 +490,7 @@ pub fn runconv() {
|
||||
strategy,
|
||||
pi_for_real_prover,
|
||||
&mut transcript,
|
||||
params.n(),
|
||||
);
|
||||
assert!(verify.is_ok());
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
|
||||
"\n",
|
||||
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
|
||||
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elemenets :-D). \n",
|
||||
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like:\n",
|
||||
" \n",
|
||||
"```json\n",
|
||||
|
||||
@@ -309,7 +309,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
|
||||
"print(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -338,7 +338,7 @@
|
||||
"\n",
|
||||
"def test_on_chain_data(res):\n",
|
||||
" # Step 0: Convert the tensor to a flat list\n",
|
||||
" data = [int(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
|
||||
" data = [int(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
|
||||
"\n",
|
||||
" # Step 1: Prepare the data\n",
|
||||
" # Step 2: Prepare and compile the contract.\n",
|
||||
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "gvQ5HL1bTDWF"
|
||||
},
|
||||
@@ -441,7 +441,9 @@
|
||||
"# Serialize calibration data into file:\n",
|
||||
"json.dump(data, open(cal_data_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\") # Optimize for resources"
|
||||
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
|
||||
"# You may want to increase the max logrows if accuracy is a concern\n",
|
||||
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -506,9 +508,8 @@
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
@@ -563,7 +564,6 @@
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@@ -695,7 +695,7 @@
|
||||
"formatted_output = \"[\"\n",
|
||||
"for i, value in enumerate(proof[\"instances\"]):\n",
|
||||
" for j, field_element in enumerate(value):\n",
|
||||
" onchain_input_array.append(ezkl.vecu64_to_felt(field_element))\n",
|
||||
" onchain_input_array.append(ezkl.string_to_felt(field_element))\n",
|
||||
" formatted_output += str(onchain_input_array[-1])\n",
|
||||
" if j != len(value) - 1:\n",
|
||||
" formatted_output += \", \"\n",
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"source": [
|
||||
"# kzg-ezkl\n",
|
||||
"\n",
|
||||
"Here's an example leveraging EZKL whereby the inputs to the model, and the model params themselves, are commited to using kzg-commitments inside a circuit.\n",
|
||||
"Here's an example leveraging EZKL whereby the inputs to the model, and the model params themselves, are committed to using kzg-commitments inside a circuit.\n",
|
||||
"\n",
|
||||
"In this setup:\n",
|
||||
"- the commitments are publicly known to the prover and verifier\n",
|
||||
@@ -166,7 +166,7 @@
|
||||
"Shoutouts: \n",
|
||||
"\n",
|
||||
"- [summa-solvency](https://github.com/summa-dev/summa-solvency) for their help with the poseidon hashing chip. \n",
|
||||
"- [timeofey](https://github.com/timoftime) for providing inspiration in our developement of the el-gamal encryption circuit in Halo2. "
|
||||
"- [timeofey](https://github.com/timoftime) for providing inspiration in our development of the el-gamal encryption circuit in Halo2. "
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -300,13 +300,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
|
||||
"\n",
|
||||
"def setup(i):\n",
|
||||
" print(\"Setting up split model \"+str(i))\n",
|
||||
" # file names\n",
|
||||
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
|
||||
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
|
||||
@@ -342,12 +343,12 @@
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" compress_selectors=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" assert res == True\n",
|
||||
" assert os.path.isfile(vk_path)\n",
|
||||
" assert os.path.isfile(pk_path)\n",
|
||||
" \n",
|
||||
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
|
||||
"\n",
|
||||
@@ -383,7 +384,6 @@
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"for-aggr\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@@ -413,7 +413,6 @@
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" assert res == True\n",
|
||||
@@ -442,7 +441,7 @@
|
||||
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
" proofs.append(proof_path)\n",
|
||||
"\n",
|
||||
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
|
||||
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -780,7 +780,7 @@
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
@@ -845,7 +845,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [5,6])\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -887,11 +887,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"id": "12YIcFr85X9-"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"spawning module 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"quotient_poly_degree 4\n",
|
||||
"n 262144\n",
|
||||
"extended_k 20\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
@@ -971,9 +988,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
||||
1
examples/notebooks/proof_aggr.json
Normal file
1
examples/notebooks/proof_aggr.json
Normal file
File diff suppressed because one or more lines are too long
@@ -302,7 +302,7 @@
|
||||
" assert res == True\n",
|
||||
" assert os.path.isfile(vk_path)\n",
|
||||
" assert os.path.isfile(pk_path)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
|
||||
"\n",
|
||||
@@ -330,14 +330,14 @@
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"for-aggr\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(res)\n",
|
||||
" res_1_proof = res[\"proof\"]\n",
|
||||
" assert os.path.isfile(proof_path)\n",
|
||||
"\n",
|
||||
" # Verify the proof\n",
|
||||
" # # Verify the proof\n",
|
||||
" if i > 0:\n",
|
||||
" print(\"swapping commitments\")\n",
|
||||
" # swap the proof commitments if we are not the first model\n",
|
||||
@@ -356,12 +356,19 @@
|
||||
"\n",
|
||||
" res = ezkl.swap_proof_commitments(proof_path, witness_path)\n",
|
||||
" print(res)\n",
|
||||
" \n",
|
||||
" # load proof and then print \n",
|
||||
" proof = json.load(open(proof_path, 'r'))\n",
|
||||
" res_2_proof = proof[\"hex_proof\"]\n",
|
||||
" # show diff in hex strings\n",
|
||||
" print(res_1_proof)\n",
|
||||
" print(res_2_proof)\n",
|
||||
" assert res_1_proof == res_2_proof\n",
|
||||
"\n",
|
||||
" res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" assert res == True\n",
|
||||
@@ -439,7 +446,7 @@
|
||||
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
" proofs.append(proof_path)\n",
|
||||
"\n",
|
||||
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
|
||||
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -78,7 +78,7 @@
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
@@ -122,8 +122,8 @@
|
||||
"# Loop through each element in the y tensor\n",
|
||||
"for e in y_input:\n",
|
||||
" # Apply the custom function and append the result to the list\n",
|
||||
" print(ezkl.float_to_vecu64(e,7))\n",
|
||||
" result.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 7)])[0])\n",
|
||||
" print(ezkl.float_to_string(e,7))\n",
|
||||
" result.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 7)])[0])\n",
|
||||
"\n",
|
||||
"y = y.unsqueeze(0)\n",
|
||||
"y = y.reshape(1, 9)\n",
|
||||
@@ -343,7 +343,7 @@
|
||||
"# we force the output to be 0 this corresponds to the set membership test being true -- and we set this to a fixed vis output\n",
|
||||
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
|
||||
"witness = json.load(open(witness_path, \"r\"))\n",
|
||||
"witness[\"outputs\"][0] = [[0, 0, 0, 0]]\n",
|
||||
"witness[\"outputs\"][0] = [\"0000000000000000000000000000000000000000000000000000000000000000\"]\n",
|
||||
"json.dump(witness, open(witness_path, \"w\"))\n",
|
||||
"\n",
|
||||
"witness = json.load(open(witness_path, \"r\"))\n",
|
||||
@@ -353,7 +353,6 @@
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" witness_path = witness_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@@ -520,4 +519,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"source": [
|
||||
"## EZKL Jupyter Notebook Demo \n",
|
||||
"\n",
|
||||
"Here we demonstrate how to use the EZKL package to run a publicly known / committted to network on some private data, producing a public output.\n"
|
||||
"Here we demonstrate how to use the EZKL package to run a publicly known / committed to network on some private data, producing a public output.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -210,7 +210,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
||||
@@ -126,7 +126,7 @@
|
||||
"# Loop through each element in the y tensor\n",
|
||||
"for e in user_preimages:\n",
|
||||
" # Apply the custom function and append the result to the list\n",
|
||||
" users.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 0)])[0])\n",
|
||||
" users.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 0)])[0])\n",
|
||||
"\n",
|
||||
"users_t = torch.tensor(user_preimages)\n",
|
||||
"users_t = users_t.reshape(1, 6)\n",
|
||||
@@ -303,7 +303,7 @@
|
||||
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
|
||||
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
|
||||
"witness = json.load(open(witness_path, \"r\"))\n",
|
||||
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
|
||||
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
|
||||
"json.dump(witness, open(witness_path, \"w\"))"
|
||||
]
|
||||
},
|
||||
@@ -417,7 +417,7 @@
|
||||
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
|
||||
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
|
||||
"witness = json.load(open(witness_path, \"r\"))\n",
|
||||
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
|
||||
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
|
||||
"json.dump(witness, open(witness_path, \"w\"))\n"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -633,7 +633,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -154,7 +154,7 @@
|
||||
"source": [
|
||||
"## Create a neural net to verify the execution of the tic tac toe model\n",
|
||||
"\n",
|
||||
"1. Given the data generated above classify whether the tic tac toe games are valid. This approach uses a binary classification as the tic tac toe state space is fairly small. For larger state spaces we will want to use anomaly detection based approachs"
|
||||
"1. Given the data generated above classify whether the tic tac toe games are valid. This approach uses a binary classification as the tic tac toe state space is fairly small. For larger state spaces, we will want to use anomaly detection based approaches."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -520,7 +520,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -636,7 +636,8 @@
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -237,7 +237,7 @@
|
||||
"\n",
|
||||
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"ezkl.calibrate_settings(\n",
|
||||
" input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
|
||||
"res = ezkl.get_srs(settings_filename)\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
|
||||
"\n",
|
||||
@@ -255,7 +255,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
@@ -451,7 +451,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -25,17 +25,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"voice_data_dir: .\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import os\n",
|
||||
@@ -43,7 +35,7 @@
|
||||
"\n",
|
||||
"voice_data_dir = os.environ.get('VOICE_DATA_DIR')\n",
|
||||
"\n",
|
||||
"# if is none set to \"\" \n",
|
||||
"# if is none set to \"\"\n",
|
||||
"if voice_data_dir is None:\n",
|
||||
" voice_data_dir = \"\"\n",
|
||||
"\n",
|
||||
@@ -637,7 +629,7 @@
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\")\n",
|
||||
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")\n"
|
||||
]
|
||||
@@ -908,7 +900,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"import torch\n",
|
||||
"import math\n",
|
||||
"\n",
|
||||
"# these are constatns for the rotation\n",
|
||||
"# these are constants for the rotation\n",
|
||||
"phi = torch.tensor(5 * math.pi / 180)\n",
|
||||
"s = torch.sin(phi)\n",
|
||||
"c = torch.cos(phi)\n",
|
||||
@@ -503,11 +503,11 @@
|
||||
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
|
||||
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
|
||||
"\n",
|
||||
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][0], out_scale)\n",
|
||||
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][1], out_scale)\n",
|
||||
"arrow_x = ezkl.string_to_float(witness['outputs'][0][0], out_scale)\n",
|
||||
"arrow_y = ezkl.string_to_float(witness['outputs'][0][1], out_scale)\n",
|
||||
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
|
||||
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][2], out_scale)\n",
|
||||
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][3], out_scale)\n",
|
||||
"arrow_x = ezkl.string_to_float(witness['outputs'][0][2], out_scale)\n",
|
||||
"arrow_y = ezkl.string_to_float(witness['outputs'][0][3], out_scale)\n",
|
||||
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
|
||||
]
|
||||
}
|
||||
|
||||
39
examples/onnx/1l_tiny_div/gen.py
Normal file
39
examples/onnx/1l_tiny_div/gen.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
|
||||
class Circuit(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Circuit, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x/ 10000
|
||||
|
||||
|
||||
circuit = Circuit()
|
||||
|
||||
|
||||
x = torch.empty(1, 8).random_(0, 2)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/1l_tiny_div/input.json
Normal file
1
examples/onnx/1l_tiny_div/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}
|
||||
BIN
examples/onnx/1l_tiny_div/network.onnx
Normal file
BIN
examples/onnx/1l_tiny_div/network.onnx
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
3074
examples/test_failure_aggr_proof.json
Normal file
3074
examples/test_failure_aggr_proof.json
Normal file
File diff suppressed because one or more lines are too long
3208
examples/test_failure_proof.json
Normal file
3208
examples/test_failure_proof.json
Normal file
File diff suppressed because one or more lines are too long
170
install_ezkl_cli.sh
Normal file
170
install_ezkl_cli.sh
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
|
||||
BASE_DIR=${XDG_CONFIG_HOME:-$HOME}
|
||||
EZKL_DIR=${EZKL_DIR-"$BASE_DIR/.ezkl"}
|
||||
|
||||
# Create the .ezkl bin directory if it doesn't exit
|
||||
mkdir -p $EZKL_DIR
|
||||
|
||||
# Store the correct profile file (i.e. .profile for bash or .zshenv for ZSH).
|
||||
case $SHELL in
|
||||
*/zsh)
|
||||
PROFILE=${ZDOTDIR-"$HOME"}/.zshenv
|
||||
PREF_SHELL=zsh
|
||||
;;
|
||||
*/bash)
|
||||
PROFILE=$HOME/.bashrc
|
||||
PREF_SHELL=bash
|
||||
;;
|
||||
*/fish)
|
||||
PROFILE=$HOME/.config/fish/config.fish
|
||||
PREF_SHELL=fish
|
||||
;;
|
||||
*/ash)
|
||||
PROFILE=$HOME/.profile
|
||||
PREF_SHELL=ash
|
||||
;;
|
||||
*)
|
||||
echo "NOTICE: Shell could not be detected, you will need to manually add ${EZKL_DIR} to your PATH."
|
||||
esac
|
||||
|
||||
# Check for non standard installation of ezkl
|
||||
if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; then
|
||||
echo "ezkl is installed in a non-standard directory, $(which ezkl). To use the automated installer, remove the existing ezkl from path to prevent conflicts"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
|
||||
# Add the ezkl directory to the path and ensure the old PATH variables remain.
|
||||
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
|
||||
fi
|
||||
|
||||
# Install latest ezkl version
|
||||
# Get the right release URL
|
||||
if [ -z "$1" ]
|
||||
then
|
||||
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/latest"
|
||||
echo "No version tags provided, installing the latest ezkl version"
|
||||
else
|
||||
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/tags/$1"
|
||||
echo "Installing ezkl version $1"
|
||||
fi
|
||||
|
||||
PLATFORM=""
|
||||
case "$(uname -s)" in
|
||||
|
||||
Darwin*)
|
||||
PLATFORM="macos"
|
||||
;;
|
||||
|
||||
Linux*Microsoft*)
|
||||
PLATFORM="linux"
|
||||
;;
|
||||
|
||||
Linux*)
|
||||
PLATFORM="linux"
|
||||
;;
|
||||
|
||||
CYGWIN*|MINGW*|MINGW32*|MSYS*)
|
||||
PLATFORM="windows-msvc"
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Platform is not supported. If you would need support for the platform please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
# Check arch
|
||||
ARCHITECTURE="$(uname -m)"
|
||||
if [ "${ARCHITECTURE}" = "x86_64" ]; then
|
||||
# Redirect stderr to /dev/null to avoid printing errors if non Rosetta.
|
||||
if [ "$(sysctl -n sysctl.proc_translated 2>/dev/null)" = "1" ]; then
|
||||
ARCHITECTURE="arm64" # Rosetta.
|
||||
else
|
||||
ARCHITECTURE="amd64" # Intel.
|
||||
fi
|
||||
elif [ "${ARCHITECTURE}" = "arm64" ] ||[ "${ARCHITECTURE}" = "aarch64" ]; then
|
||||
ARCHITECTURE="aarch64" # Arm.
|
||||
elif [ "${ARCHITECTURE}" = "amd64" ]; then
|
||||
ARCHITECTURE="amd64" # Amd
|
||||
else
|
||||
echo "Architecture is not supported. If you would need support for the architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Remove existing ezkl
|
||||
echo "Removing old ezkl binary if it exists"
|
||||
[ -e file ] && rm file
|
||||
|
||||
# download the release and unpack the right tarball
|
||||
if [ "$PLATFORM" == "windows-msvc" ]; then
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-windows-msvc.tar.gz")
|
||||
|
||||
echo "Downloading package"
|
||||
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
|
||||
|
||||
echo "Unpacking package"
|
||||
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz" -C "$EZKL_DIR"
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
|
||||
|
||||
elif [ "$PLATFORM" == "macos" ]; then
|
||||
if [ "$ARCHITECTURE" == "aarch64" ] || [ "$ARCHITECTURE" == "arm64" ]; then
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos-aarch64.tar.gz")
|
||||
|
||||
echo "Downloading package"
|
||||
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
|
||||
|
||||
echo "Unpacking package"
|
||||
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz" -C "$EZKL_DIR"
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
|
||||
|
||||
else
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
|
||||
|
||||
echo "Downloading package"
|
||||
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
|
||||
|
||||
echo "Unpacking package"
|
||||
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz" -C "$EZKL_DIR"
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
|
||||
|
||||
fi
|
||||
|
||||
elif [ "$PLATFORM" == "linux" ]; then
|
||||
if [ "${ARCHITECTURE}" = "amd64" ]; then
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
|
||||
|
||||
echo "Downloading package"
|
||||
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
|
||||
|
||||
echo "Unpacking package"
|
||||
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz" -C "$EZKL_DIR"
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
|
||||
|
||||
else
|
||||
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "Platform and Architecture is not supported. If you would need support for the platform and architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo && echo "Successfully downloaded ezkl at ${EZKL_DIR}"
|
||||
echo "We detected that your preferred shell is ${PREF_SHELL} and added ezkl to PATH. Run 'source ${PROFILE}' or start a new terminal session to use ezkl."
|
||||
@@ -11,8 +11,8 @@ use ezkl::execute::run;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ezkl::logger::init_logger;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{error, info};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{debug, error, info};
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
use rand::prelude::SliceRandom;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(feature = "icicle")]
|
||||
@@ -25,6 +25,7 @@ use std::error::Error;
|
||||
pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let args = Cli::parse();
|
||||
init_logger();
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
banner();
|
||||
#[cfg(feature = "icicle")]
|
||||
if env::var("ENABLE_ICICLE_GPU").is_ok() {
|
||||
@@ -32,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
} else {
|
||||
info!("Running with CPU");
|
||||
}
|
||||
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
let res = run(args.command).await;
|
||||
match &res {
|
||||
Ok(_) => info!("succeeded"),
|
||||
@@ -44,7 +45,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn main() {}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
fn banner() {
|
||||
let ell: Vec<&str> = vec![
|
||||
"for Neural Networks",
|
||||
|
||||
@@ -41,7 +41,7 @@ pub struct KZGChip {
|
||||
}
|
||||
|
||||
impl KZGChip {
|
||||
/// Returns the number of inputs to the hash function
|
||||
/// Commit to the message using the KZG commitment scheme
|
||||
pub fn commit(
|
||||
message: Vec<Fp>,
|
||||
degree: u32,
|
||||
@@ -219,7 +219,7 @@ mod tests {
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,6 +240,6 @@ mod tests {
|
||||
message: message.into(),
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ use halo2_proofs::{
|
||||
Instance, Selector, TableColumn,
|
||||
},
|
||||
};
|
||||
use log::{trace, warn};
|
||||
use log::{debug, trace};
|
||||
|
||||
/// A simple [`FloorPlanner`] that performs minimal optimizations.
|
||||
#[derive(Debug)]
|
||||
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
|
||||
Error::Synthesis
|
||||
})?;
|
||||
if !self.regions.contains_key(&index) {
|
||||
warn!("spawning module {}", index)
|
||||
debug!("spawning module {}", index)
|
||||
};
|
||||
self.current_module = index;
|
||||
}
|
||||
|
||||
@@ -499,7 +499,7 @@ mod tests {
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -518,7 +518,7 @@ mod tests {
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -551,7 +551,7 @@ mod tests {
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -573,6 +573,6 @@ mod tests {
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify_par(), Ok(()))
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,8 +125,8 @@ impl BaseOp {
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::Range { .. } => 1,
|
||||
BaseOp::IsZero => 1,
|
||||
BaseOp::IsBoolean => 1,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,10 +16,14 @@ use pyo3::{
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use crate::{
|
||||
circuit::ops::base::BaseOp,
|
||||
circuit::{table::Table, utils},
|
||||
circuit::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
|
||||
@@ -58,6 +62,22 @@ pub enum CheckMode {
|
||||
UNSAFE,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CheckMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CheckMode::SAFE => write!(f, "safe"),
|
||||
CheckMode::UNSAFE => write!(f, "unsafe"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for CheckMode {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for CheckMode {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
@@ -80,6 +100,19 @@ pub struct Tolerance {
|
||||
pub scale: utils::F32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Tolerance {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:.2}", self.val)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for Tolerance {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Tolerance {
|
||||
type Err = String;
|
||||
|
||||
@@ -176,6 +209,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
|
||||
///
|
||||
pub tables: BTreeMap<LookupOp, Table<F>>,
|
||||
///
|
||||
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
|
||||
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
@@ -194,7 +231,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
lookup_index: dummy_var,
|
||||
selectors: BTreeMap::new(),
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
check_mode: CheckMode::SAFE,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -267,9 +306,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
let constraints = match base_op {
|
||||
BaseOp::IsBoolean => {
|
||||
vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let output = expected_output[base_op.constraint_idx()].clone();
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
vec![expected_output[base_op.constraint_idx()].clone()]
|
||||
}
|
||||
BaseOp::IsZero => vec![qis[1].clone()],
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
@@ -325,11 +375,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Self {
|
||||
selectors,
|
||||
lookup_selectors: BTreeMap::new(),
|
||||
range_check_selectors: BTreeMap::new(),
|
||||
inputs: inputs.to_vec(),
|
||||
lookup_input: VarTensor::Empty,
|
||||
lookup_output: VarTensor::Empty,
|
||||
lookup_index: VarTensor::Empty,
|
||||
tables: BTreeMap::new(),
|
||||
range_checks: BTreeMap::new(),
|
||||
output: output.clone(),
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
@@ -344,7 +396,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
input: &VarTensor,
|
||||
output: &VarTensor,
|
||||
index: &VarTensor,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
nl: &LookupOp,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
@@ -482,6 +534,75 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_range_check(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
input: &VarTensor,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn Error>>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
let mut selectors = BTreeMap::new();
|
||||
|
||||
if !input.is_advice() {
|
||||
return Err("wrong input type for lookup input".into());
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let range_check =
|
||||
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
for x in 0..input.num_blocks() {
|
||||
for y in 0..input.num_inner_cols() {
|
||||
let single_col_sel = cs.complex_selector();
|
||||
|
||||
cs.lookup("", |cs| {
|
||||
let mut res = vec![];
|
||||
let sel = cs.query_selector(single_col_sel);
|
||||
|
||||
let input_query = match &input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let default_x = range_check.get_first_element();
|
||||
|
||||
let not_sel = Expression::Constant(F::ONE) - sel.clone();
|
||||
|
||||
res.extend([(
|
||||
sel.clone() * input_query.clone()
|
||||
+ not_sel.clone() * Expression::Constant(default_x),
|
||||
range_check.input,
|
||||
)]);
|
||||
|
||||
res
|
||||
});
|
||||
selectors.insert((range, x, y), single_col_sel);
|
||||
}
|
||||
}
|
||||
self.range_check_selectors.extend(selectors);
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if let VarTensor::Empty = self.lookup_input {
|
||||
debug!("assigning lookup input");
|
||||
self.lookup_input = input.clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_tables must be called before layout.
|
||||
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
for (i, table) in self.tables.values_mut().enumerate() {
|
||||
@@ -500,6 +621,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// layout_range_checks must be called before layout.
|
||||
pub fn layout_range_checks(
|
||||
&mut self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
for range_check in self.range_checks.values_mut() {
|
||||
if !range_check.is_assigned {
|
||||
debug!("laying out range check for {:?}", range_check.range);
|
||||
range_check.layout(layouter)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Assigns variables to the regions created when calling `configure`.
|
||||
/// # Arguments
|
||||
/// * `values` - The explicit values to the operations.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{self, layouts, utils, Tolerance},
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
// import run args from model
|
||||
|
||||
@@ -13,6 +13,15 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
@@ -59,14 +68,6 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
num_classes: usize,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
@@ -74,7 +75,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
|
||||
HybridOp::ScatterElements { .. } => vec![0, 2],
|
||||
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
@@ -87,142 +88,42 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = inputs[0].clone().map(|x| felt_to_i128(x));
|
||||
|
||||
let (res, intermediate_lookups) = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => {
|
||||
let res = tensor::ops::max_axes(&x, axes)?;
|
||||
let max_minus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(x - max(x - 1)
|
||||
let inter_1 = (x.clone() - max_minus_one)?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
|
||||
(res.clone(), vec![inter_1, inter_2])
|
||||
}
|
||||
HybridOp::ReduceMin { axes, .. } => {
|
||||
let res = tensor::ops::min_axes(&x, axes)?;
|
||||
let min_plus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(min(x + 1) - x)
|
||||
let inter_1 = (min_plus_one - x.clone())?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
(res.clone(), vec![inter_1, inter_2])
|
||||
}
|
||||
HybridOp::ReduceArgMax { dim } => {
|
||||
let res = tensor::ops::argmax_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::ReduceArgMin { dim } => {
|
||||
let res = tensor::ops::argmin_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
let res = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
|
||||
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
|
||||
}
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
..
|
||||
} => crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
),
|
||||
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
|
||||
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let res = tensor::ops::gather(&x, idx, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
tensor::ops::gather(&x, idx, *dim)?
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
|
||||
}
|
||||
HybridOp::TopK { dim, k, largest } => {
|
||||
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
|
||||
|
||||
let mut inter_equals = x
|
||||
.clone()
|
||||
.into_iter()
|
||||
.flat_map(|elem| {
|
||||
tensor::ops::equals(&res, &vec![elem].into_iter().into())
|
||||
.unwrap()
|
||||
.1
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// sort in descending order and take pairwise differences
|
||||
inter_equals.push(
|
||||
x.into_iter()
|
||||
.sorted()
|
||||
.tuple_windows()
|
||||
.map(|(a, b)| b - a)
|
||||
.into(),
|
||||
);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let res = tensor::ops::gather_elements(&x, idx, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
}
|
||||
HybridOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let src = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
|
||||
let src = inputs[2].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
}
|
||||
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
..
|
||||
} => {
|
||||
let max_minus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(x - max(x - 1)
|
||||
let inter_1 = (x.clone() - max_minus_one)?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
(
|
||||
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
vec![inter_1, inter_2],
|
||||
)
|
||||
}
|
||||
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -234,10 +135,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
}
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
(
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
|
||||
vec![],
|
||||
)
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
|
||||
}
|
||||
HybridOp::Greater => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
@@ -264,14 +162,26 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
// convert back to felt
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups,
|
||||
})
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -306,8 +216,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
HybridOp::TopK { k, dim, largest } => {
|
||||
format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
|
||||
}
|
||||
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
|
||||
}
|
||||
@@ -335,6 +243,55 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
*kernel_shape,
|
||||
*normalized,
|
||||
)?,
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(input_scale.0 as i128),
|
||||
i128_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(denom.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Div { denom: *denom },
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
@@ -342,26 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
layouts::gather(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
*dim,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
@@ -422,86 +360,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { .. } => 2 * in_scales[0],
|
||||
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
match self {
|
||||
HybridOp::ReduceMax { .. }
|
||||
| HybridOp::ReduceMin { .. }
|
||||
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
|
||||
HybridOp::Softmax { scale, .. } => {
|
||||
vec![
|
||||
LookupOp::Exp { scale: *scale },
|
||||
LookupOp::Recip {
|
||||
scale: scale.0.powf(2.0).into(),
|
||||
},
|
||||
]
|
||||
}
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let mut lookups = vec![];
|
||||
if tol.val > 0.0 {
|
||||
let scale_squared = tol.scale.0.powf(2.0);
|
||||
lookups.extend([
|
||||
LookupOp::Recip {
|
||||
scale: scale_squared.into(),
|
||||
},
|
||||
LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32((tol.val * scale_squared) / 100.0),
|
||||
},
|
||||
]);
|
||||
}
|
||||
lookups
|
||||
}
|
||||
HybridOp::Greater { .. } | HybridOp::Less { .. } => {
|
||||
vec![LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32(0.),
|
||||
}]
|
||||
}
|
||||
HybridOp::GreaterEqual { .. } | HybridOp::LessEqual { .. } => {
|
||||
vec![LookupOp::GreaterThanEqual {
|
||||
a: circuit::utils::F32(0.),
|
||||
}]
|
||||
}
|
||||
HybridOp::TopK { .. } => {
|
||||
vec![
|
||||
LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32(0.),
|
||||
},
|
||||
LookupOp::KroneckerDelta,
|
||||
]
|
||||
}
|
||||
HybridOp::Gather {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::GatherElements {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::ScatterElements {
|
||||
constant_idx: None, ..
|
||||
}
|
||||
| HybridOp::Equals { .. } => {
|
||||
vec![LookupOp::KroneckerDelta]
|
||||
}
|
||||
HybridOp::ReduceArgMax { .. } | HybridOp::ReduceArgMin { .. } => {
|
||||
vec![LookupOp::ReLU, LookupOp::KroneckerDelta]
|
||||
}
|
||||
HybridOp::SumPool {
|
||||
kernel_shape,
|
||||
normalized: true,
|
||||
..
|
||||
} => {
|
||||
vec![LookupOp::Div {
|
||||
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
|
||||
}]
|
||||
}
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
|
||||
@@ -18,8 +18,11 @@ use super::{
|
||||
region::RegionCtx,
|
||||
};
|
||||
use crate::{
|
||||
circuit::{ops::base::BaseOp, utils},
|
||||
fieldutils::i128_to_felt,
|
||||
circuit::{
|
||||
ops::base::BaseOp,
|
||||
utils::{self},
|
||||
},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
@@ -51,6 +54,125 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
|
||||
total_len
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
pub fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
div: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(div) / 2;
|
||||
|
||||
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
|
||||
divisor.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
|
||||
region.increment(divisor.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), divisor.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), input.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// recip accumulated layout
|
||||
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: F,
|
||||
output_scale: F,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
|
||||
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
tensor::ops::nonlinearities::recip(
|
||||
&input_evals,
|
||||
felt_to_i128(input_scale) as f64,
|
||||
felt_to_i128(output_scale) as f64,
|
||||
)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
&[input.len()],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
|
||||
// this is now of scale 2 * scale
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), input.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
log::debug!("range_check_bracket: {:?}", range_check_bracket);
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[product],
|
||||
&(range_check_bracket, 3 * range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
/// Dot product accumulated layout
|
||||
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -92,7 +214,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let mut assigned_len = 0;
|
||||
for (i, input) in values.iter_mut().enumerate() {
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
let inp = {
|
||||
let (res, len) = region.assign_with_duplication(
|
||||
&config.inputs[i],
|
||||
@@ -648,14 +770,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_input = region.assign(&config.inputs[0], &input)?;
|
||||
|
||||
// now assert all elems are 0 or 1
|
||||
let assigned_output = region.assign(&config.inputs[1], &output)?;
|
||||
if !region.is_dummy() {
|
||||
for i in 0..assigned_output.len() {
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
}
|
||||
}
|
||||
let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
|
||||
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
|
||||
|
||||
let sum = sum(config, region, &[assigned_output.clone()])?;
|
||||
@@ -1027,7 +1142,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
let mut input = values[0].clone();
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
let (res, len) =
|
||||
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
|
||||
assigned_len = len;
|
||||
@@ -1096,7 +1211,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
let mut input = values[0].clone();
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
|
||||
let (res, len) =
|
||||
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
|
||||
assigned_len = len;
|
||||
@@ -1560,10 +1675,28 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let diff_inverse = diff.inverse()?;
|
||||
let product_diff_and_invert =
|
||||
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
|
||||
|
||||
let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
|
||||
// constant of 1
|
||||
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
|
||||
ones.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
|
||||
Ok(res)
|
||||
// subtract
|
||||
let output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[ones.into(), product_diff_and_invert],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// take the product of diff and output
|
||||
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
|
||||
|
||||
is_zero_identity(config, region, &[prod_check], false)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Xor boolean operation
|
||||
@@ -1627,21 +1760,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
||||
.into();
|
||||
|
||||
// make sure mask is boolean
|
||||
let assigned_mask = region.assign(&config.inputs[1], mask)?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..assigned_mask.len())
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_mask.len());
|
||||
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
|
||||
|
||||
let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
|
||||
|
||||
@@ -1735,17 +1854,15 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
let shape = &res[0].dims()[2..];
|
||||
let mut last_elem = res[1..]
|
||||
.iter()
|
||||
.fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?;
|
||||
.try_fold(res[0].clone(), |acc, elem| acc.concat(elem.clone()))?;
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
last_elem = nonlinearity(
|
||||
last_elem = div(
|
||||
config,
|
||||
region,
|
||||
&[last_elem],
|
||||
&LookupOp::Div {
|
||||
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
|
||||
},
|
||||
F::from((kernel_shape.0 * kernel_shape.1) as u64),
|
||||
)?;
|
||||
}
|
||||
Ok(last_elem)
|
||||
@@ -1837,15 +1954,6 @@ pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std:
|
||||
)));
|
||||
}
|
||||
|
||||
if has_bias {
|
||||
let bias = &inputs[2];
|
||||
if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) {
|
||||
return Err(Box::new(TensorError::DimMismatch(
|
||||
"deconv bias".to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]);
|
||||
|
||||
let null_val = ValType::Constant(F::ZERO);
|
||||
@@ -2251,18 +2359,60 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = region.assign(&config.inputs[1], &values[0])?;
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
// get zero constants indices
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
@@ -2270,7 +2420,6 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
region.increment(output.len());
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -2313,6 +2462,52 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// layout for range check.
|
||||
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
range: &crate::circuit::table::Range,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_range_check(*range)?;
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
|
||||
let x = values[0].clone();
|
||||
|
||||
let w = region.assign(&config.lookup_input, &x)?;
|
||||
|
||||
let assigned_len = x.len();
|
||||
|
||||
let is_dummy = region.is_dummy();
|
||||
|
||||
if !is_dummy {
|
||||
(0..assigned_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config
|
||||
.lookup_input
|
||||
.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.range_check_selectors.get(&(*range, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_len);
|
||||
|
||||
let elapsed = timer.elapsed();
|
||||
trace!(
|
||||
"range check {:?} layout took {:?}, row: {:?}",
|
||||
range,
|
||||
elapsed,
|
||||
region.row()
|
||||
);
|
||||
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
/// layout for nonlinearity check.
|
||||
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2320,6 +2515,8 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
nl: &LookupOp,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_lookup(nl.clone(), values)?;
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
|
||||
@@ -2401,22 +2598,6 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// mean function layout
|
||||
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let x = &values[0];
|
||||
|
||||
let sum_x = sum(config, region, &[x.clone()])?;
|
||||
let nl = LookupOp::Div {
|
||||
denom: utils::F32((scale * x.len()) as f32),
|
||||
};
|
||||
nonlinearity(config, region, &[sum_x], &nl)
|
||||
}
|
||||
|
||||
/// Argmax
|
||||
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2529,24 +2710,8 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
)?;
|
||||
// relu(x - max(x - 1))
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
// constraining relu(x - max(x - 1)) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
// sum(relu(x - max(x - 1)))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2557,13 +2722,7 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
// constraining 1 - sum(relu(x - max(x - 1))) = 0
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
|
||||
Ok(assigned_max_val)
|
||||
}
|
||||
@@ -2608,23 +2767,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
// relu(min(x + 1) - x)
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
// constraining relu(min(x + 1) - x) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
// sum(relu(min(x + 1) - x))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2635,14 +2779,8 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
let relu_one_minus_sum_relu =
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
// constraining product to 0
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
|
||||
Ok(assigned_min_val)
|
||||
}
|
||||
@@ -2789,7 +2927,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
&[denom],
|
||||
// we set to input scale + output_scale so the output scale is output)scale
|
||||
&LookupOp::Recip {
|
||||
scale: scale.0.powf(2.0).into(),
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)?;
|
||||
|
||||
@@ -2814,22 +2953,34 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
return enforce_equality(config, region, values);
|
||||
}
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
|
||||
values[0] = region.assign(&config.inputs[0], &values[0])?;
|
||||
values[1] = region.assign(&config.inputs[1], &values[1])?;
|
||||
let total_assigned_0 = values[0].len();
|
||||
let total_assigned_1 = values[1].len();
|
||||
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
|
||||
region.increment(total_assigned);
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
let scale_squared = scale.0.powf(2.0);
|
||||
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
|
||||
let recip = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone()],
|
||||
&LookupOp::Recip {
|
||||
scale: scale_squared.into(),
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)?;
|
||||
|
||||
// Multiply the difference by the recip
|
||||
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
|
||||
|
||||
let scale_squared = scale.0 * scale.0;
|
||||
|
||||
// Use the greater than look up table to check if the percent error is within the tolerance for upper bound
|
||||
let tol = tol / 100.0;
|
||||
let upper_bound = nonlinearity(
|
||||
@@ -2857,15 +3008,8 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
// Add the lower_bound and upper_bound
|
||||
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
|
||||
|
||||
// Assign the sum tensor to the inputs
|
||||
region.assign(&config.inputs[1], &sum)?;
|
||||
|
||||
// Constrain the sum to be all zeros
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(sum.len());
|
||||
is_zero_identity(config, region, &[sum.clone()], false)?;
|
||||
|
||||
Ok(sum)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::{multiplier_to_scale, scale_to_multiplier},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -17,47 +17,117 @@ use halo2curves::ff::PrimeField;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Abs,
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ReLU,
|
||||
Max { scale: utils::F32, a: utils::F32 },
|
||||
Min { scale: utils::F32, a: utils::F32 },
|
||||
Ceil { scale: utils::F32 },
|
||||
Floor { scale: utils::F32 },
|
||||
Round { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
Recip { scale: utils::F32 },
|
||||
LeakyReLU { slope: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
GreaterThan { a: utils::F32 },
|
||||
LessThan { a: utils::F32 },
|
||||
GreaterThanEqual { a: utils::F32 },
|
||||
LessThanEqual { a: utils::F32 },
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
GreaterThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
GreaterThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
Sign,
|
||||
KroneckerDelta,
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> (i128, i128) {
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as i128;
|
||||
(-range, range)
|
||||
@@ -120,7 +190,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
&x,
|
||||
f32::from(*scale).into(),
|
||||
)),
|
||||
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
|
||||
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
@@ -150,10 +227,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
/// Returns the name of the operation
|
||||
@@ -169,11 +243,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Sign => "SIGN".into(),
|
||||
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
|
||||
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
|
||||
LookupOp::LessThan { .. } => "LESS_THAN".into(),
|
||||
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
|
||||
LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale),
|
||||
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
|
||||
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
|
||||
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
@@ -220,12 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { scale } => {
|
||||
let mut out_scale = inputs_scale[0];
|
||||
out_scale +=
|
||||
multiplier_to_scale(scale.0 as f64 / scale_to_multiplier(out_scale).powf(2.0));
|
||||
out_scale
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::Sign
|
||||
| LookupOp::GreaterThan { .. }
|
||||
| LookupOp::LessThan { .. }
|
||||
@@ -237,10 +312,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
vec![self.clone()]
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ pub mod region;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
@@ -55,11 +54,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns the lookups required by the operation.
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
/// Returns true if the operation is an input.
|
||||
fn is_input(&self) -> bool {
|
||||
false
|
||||
@@ -183,7 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Ok(ForwardResult {
|
||||
output: x[0].clone(),
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -206,6 +199,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
true,
|
||||
)?))
|
||||
}
|
||||
_ => Ok(Some(super::layouts::identity(
|
||||
@@ -308,10 +302,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let output = self.quantized_values.clone();
|
||||
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
fieldutils::felt_to_i128,
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -9,6 +10,14 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
MultiBroadcastTo {
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
@@ -33,7 +42,9 @@ pub enum PolyOp {
|
||||
Sub,
|
||||
Neg,
|
||||
Mult,
|
||||
Identity,
|
||||
Identity {
|
||||
out_scale: Option<crate::Scale>,
|
||||
},
|
||||
Reshape(Vec<usize>),
|
||||
MoveAxis {
|
||||
source: usize,
|
||||
@@ -79,13 +90,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
|
||||
PolyOp::Resize { .. } => "RESIZE".into(),
|
||||
PolyOp::Iff => "IFF".into(),
|
||||
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
|
||||
PolyOp::Identity => "IDENTITY".into(),
|
||||
PolyOp::Identity { out_scale } => {
|
||||
format!("IDENTITY (out_scale={:?})", out_scale)
|
||||
}
|
||||
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
|
||||
PolyOp::Flatten(_) => "FLATTEN".into(),
|
||||
PolyOp::Pad(_) => "PAD".into(),
|
||||
@@ -135,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
|
||||
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
|
||||
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
|
||||
PolyOp::Identity => Ok(inputs[0].clone()),
|
||||
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
|
||||
PolyOp::Reshape(new_dims) => {
|
||||
let mut t = inputs[0].clone();
|
||||
t.reshape(new_dims)?;
|
||||
@@ -199,14 +214,36 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("slice inputs".to_string()));
|
||||
}
|
||||
Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
|
||||
tensor::ops::slice(&inputs[0], axis, start, end)
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_elements(&x, &y, *dim)
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter(&x, &idx, &src, *dim)
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult {
|
||||
output: res,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
Ok(ForwardResult { output: res })
|
||||
}
|
||||
|
||||
fn layout(
|
||||
@@ -237,7 +274,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
}
|
||||
PolyOp::Neg => layouts::neg(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Einsum { equation } => layouts::einsum(config, region, &values, equation)?,
|
||||
PolyOp::Einsum { equation } => layouts::einsum(config, region, values, equation)?,
|
||||
PolyOp::Sum { axes } => {
|
||||
layouts::sum_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
@@ -247,6 +284,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Conv { padding, stride } => {
|
||||
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
*dim,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding,
|
||||
@@ -264,7 +321,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Mult => {
|
||||
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
|
||||
}
|
||||
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
|
||||
PolyOp::Pad(p) => {
|
||||
if values.len() != 1 {
|
||||
@@ -290,12 +347,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
let scale = match self {
|
||||
PolyOp::MultiBroadcastTo { .. } => in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Neg => in_scales[0],
|
||||
PolyOp::MoveAxis { .. } => in_scales[0],
|
||||
PolyOp::Downsample { .. } => in_scales[0],
|
||||
PolyOp::Resize { .. } => in_scales[0],
|
||||
PolyOp::Iff => in_scales[1],
|
||||
PolyOp::Einsum { .. } => {
|
||||
let mut scale = in_scales[0];
|
||||
@@ -327,9 +379,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
output_scale
|
||||
}
|
||||
PolyOp::Add => {
|
||||
let mut scale_a = 0;
|
||||
let scale_b = in_scales[0];
|
||||
scale_a += in_scales[1];
|
||||
let scale_a = in_scales[0];
|
||||
let scale_b = in_scales[1];
|
||||
assert_eq!(scale_a, scale_b);
|
||||
scale_a
|
||||
}
|
||||
@@ -339,26 +390,23 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
scale += in_scales[1];
|
||||
scale
|
||||
}
|
||||
PolyOp::Identity => in_scales[0],
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pad(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Pack(_, _) => in_scales[0],
|
||||
PolyOp::GlobalSumPool => in_scales[0],
|
||||
PolyOp::Concat { axis: _ } => in_scales[0],
|
||||
PolyOp::Slice { .. } => in_scales[0],
|
||||
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
if matches!(
|
||||
self,
|
||||
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
|
||||
) {
|
||||
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
|
||||
vec![0, 1]
|
||||
} else if matches!(self, PolyOp::Iff) {
|
||||
vec![1, 2]
|
||||
} else if matches!(self, PolyOp::Concat { .. }) {
|
||||
(0..100).collect()
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. }) {
|
||||
vec![0, 2]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
use halo2_proofs::{
|
||||
circuit::Region,
|
||||
plonk::{Error, Selector},
|
||||
@@ -7,9 +10,16 @@ use halo2curves::ff::PrimeField;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashSet,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use portable_atomic::AtomicI128 as AtomicInt;
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
|
||||
/// Region error
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegionError {
|
||||
@@ -56,6 +66,10 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
total_constants: usize,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i128,
|
||||
min_lookup_inputs: i128,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
@@ -75,6 +89,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
row,
|
||||
linear_coord,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
@@ -90,6 +108,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +126,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,8 +137,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
pub fn new_dummy_with_constants(
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
constants: usize,
|
||||
total_constants: usize,
|
||||
num_inner_cols: usize,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -120,7 +148,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: constants,
|
||||
total_constants,
|
||||
used_lookups,
|
||||
used_range_checks,
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +192,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
|
||||
/// Create a new region context per loop iteration
|
||||
/// hacky but it works
|
||||
|
||||
pub fn dummy_loop<T: TensorType + Send + Sync>(
|
||||
&mut self,
|
||||
output: &mut Tensor<T>,
|
||||
@@ -170,6 +203,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let constants = AtomicUsize::new(self.total_constants());
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
|
||||
*output = output
|
||||
.par_enum_map(|idx, _| {
|
||||
@@ -177,12 +214,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let starting_offset = row.load(Ordering::SeqCst);
|
||||
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
|
||||
let starting_constants = constants.load(Ordering::SeqCst);
|
||||
// get inner value of the locked lookups
|
||||
|
||||
// we need to make sure that the region is not shared between threads
|
||||
let mut local_reg = Self::new_dummy_with_constants(
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
starting_constants,
|
||||
self.num_inner_cols,
|
||||
HashSet::new(),
|
||||
HashSet::new(),
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -195,6 +236,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
local_reg.total_constants() - starting_constants,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
res
|
||||
})
|
||||
.map_err(|e| {
|
||||
@@ -203,7 +252,56 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
})?;
|
||||
self.total_constants = constants.into_inner();
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner();
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner();
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
|
||||
})?;
|
||||
self.used_range_checks = Arc::try_unwrap(range_checks)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_inputs(
|
||||
&mut self,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(
|
||||
&mut self,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if range.0 > range.1 {
|
||||
return Err("update_max_min_lookup_range: invalid range".into());
|
||||
}
|
||||
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -212,15 +310,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.region.is_none()
|
||||
}
|
||||
|
||||
/// duplicate_dummy
|
||||
pub fn duplicate_dummy(&self) -> Self {
|
||||
Self {
|
||||
region: None,
|
||||
linear_coord: self.linear_coord,
|
||||
num_inner_cols: self.num_inner_cols,
|
||||
row: self.row,
|
||||
total_constants: self.total_constants,
|
||||
}
|
||||
/// add used lookup
|
||||
pub fn add_used_lookup(
|
||||
&mut self,
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
/// Get the offset
|
||||
@@ -238,6 +341,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.total_constants
|
||||
}
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.used_lookups.clone()
|
||||
}
|
||||
|
||||
/// get used range checks
|
||||
pub fn used_range_checks(&self) -> HashSet<Range> {
|
||||
self.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> i128 {
|
||||
self.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> i128 {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
|
||||
self.total_constants += 1;
|
||||
|
||||
@@ -19,6 +19,9 @@ use crate::circuit::lookup::LookupOp;
|
||||
|
||||
use super::Op;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i128, i128);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
@@ -91,7 +94,7 @@ pub struct Table<F: PrimeField> {
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: (i128, i128),
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
@@ -129,7 +132,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
|
||||
// double it to be safe
|
||||
let range_len = range.1 - range.0;
|
||||
// number of cols needed to store the range
|
||||
@@ -141,7 +144,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
range: (i128, i128),
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
@@ -257,3 +260,86 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Halo2 range check column
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RangeCheck<F: PrimeField> {
|
||||
/// Input to table.
|
||||
pub input: TableColumn,
|
||||
/// selector cn
|
||||
pub selector_constructor: SelectorConstructor<F>,
|
||||
/// Flags if table has been previously assigned to.
|
||||
pub is_assigned: bool,
|
||||
/// Number of bits used in lookup table.
|
||||
pub range: Range,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self) -> F {
|
||||
i128_to_felt(self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
|
||||
let inputs = cs.lookup_table_column();
|
||||
|
||||
RangeCheck {
|
||||
input: inputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(2),
|
||||
range,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values to the constraints generated when calling `configure`.
|
||||
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
|
||||
if self.is_assigned {
|
||||
return Err(Box::new(CircuitError::TableAlreadyAssigned));
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
layouter.assign_table(
|
||||
|| "range check table",
|
||||
|mut table| {
|
||||
let _ = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(row_offset, input)| {
|
||||
table.assign_cell(
|
||||
|| format!("rc_i_col row {}", row_offset),
|
||||
self.input,
|
||||
row_offset,
|
||||
|| Value::known(*input),
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +90,7 @@ mod matmul {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ mod matmul_col_overflow_double_col {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,14 +239,14 @@ mod matmul_col_overflow {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow_double_col {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -327,7 +327,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
MatmulCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
@@ -349,8 +349,13 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -361,7 +366,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -441,7 +446,7 @@ mod matmul_col_ultra_overflow {
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
MatmulCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
@@ -463,8 +468,13 @@ mod matmul_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -543,7 +553,7 @@ mod dot {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -620,7 +630,7 @@ mod dot_col_overflow_triple_col {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -693,7 +703,7 @@ mod dot_col_overflow {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -762,7 +772,7 @@ mod sum {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -832,7 +842,7 @@ mod sum_col_overflow_double_col {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -901,7 +911,7 @@ mod sum_col_overflow {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -994,7 +1004,7 @@ mod composition {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1095,7 +1105,7 @@ mod conv {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1133,14 +1143,14 @@ mod conv {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1240,7 +1250,7 @@ mod conv_col_ultra_overflow {
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
@@ -1262,8 +1272,13 @@ mod conv_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -1275,7 +1290,7 @@ mod conv_col_ultra_overflow {
|
||||
// not wasm 32 unknown
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_relu_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1390,7 +1405,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
@@ -1412,8 +1427,13 @@ mod conv_relu_col_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -1484,7 +1504,7 @@ mod add_w_shape_casting {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1551,7 +1571,7 @@ mod add {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1618,7 +1638,7 @@ mod add_with_overflow {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1727,7 +1747,7 @@ mod add_with_overflow_and_poseidon {
|
||||
|
||||
let prover =
|
||||
MockProver::run(K as u32, &circuit, vec![vec![commitment_a, commitment_b]]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1822,7 +1842,7 @@ mod sub {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1889,7 +1909,7 @@ mod mult {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1954,7 +1974,7 @@ mod pow {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2023,7 +2043,7 @@ mod pack {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2116,7 +2136,7 @@ mod matmul_relu {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2154,7 +2174,7 @@ mod rangecheckpercent {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let scale = utils::F32(SCALE.pow(2) as f32);
|
||||
let scale = utils::F32(SCALE as f32);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
@@ -2162,11 +2182,12 @@ mod rangecheckpercent {
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// set up a new GreaterThan and Recip tables
|
||||
let nl = &LookupOp::GreaterThan {
|
||||
a: circuit::utils::F32((RANGE * scale.0) / 100.0),
|
||||
a: circuit::utils::F32((RANGE * SCALE.pow(2) as f32) / 100.0),
|
||||
};
|
||||
config
|
||||
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
@@ -2175,7 +2196,10 @@ mod rangecheckpercent {
|
||||
&a,
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Recip { scale },
|
||||
&LookupOp::Recip {
|
||||
input_scale: scale,
|
||||
output_scale: scale,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
@@ -2222,7 +2246,7 @@ mod rangecheckpercent {
|
||||
_marker: PhantomData,
|
||||
};
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
{
|
||||
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
|
||||
@@ -2233,7 +2257,7 @@ mod rangecheckpercent {
|
||||
_marker: PhantomData,
|
||||
};
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
|
||||
// Unsuccessful case
|
||||
@@ -2328,7 +2352,7 @@ mod relu {
|
||||
};
|
||||
|
||||
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2339,7 +2363,7 @@ mod lookup_ultra_overflow {
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
poly::commitment::ParamsProver,
|
||||
poly::commitment::{Params, ParamsProver},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -2421,7 +2445,7 @@ mod lookup_ultra_overflow {
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ReLUCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
@@ -2443,8 +2467,13 @@ mod lookup_ultra_overflow {
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -2511,7 +2540,8 @@ mod softmax {
|
||||
(-32768, 32768),
|
||||
K,
|
||||
&LookupOp::Recip {
|
||||
scale: SCALE.powf(2.0).into(),
|
||||
input_scale: SCALE.into(),
|
||||
output_scale: SCALE.into(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@@ -2557,6 +2587,6 @@ mod softmax {
|
||||
_marker: PhantomData,
|
||||
};
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied_par();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
200
src/commands.rs
200
src/commands.rs
@@ -1,4 +1,4 @@
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use clap::{Parser, Subcommand};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ethers::types::H160;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -9,8 +9,9 @@ use pyo3::{
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::path::PathBuf;
|
||||
use std::{error::Error, str::FromStr};
|
||||
use tosubcommand::{ToFlags, ToSubcommand};
|
||||
|
||||
use crate::{pfsys::ProofType, RunArgs};
|
||||
|
||||
@@ -59,6 +60,8 @@ pub const DEFAULT_SOL_CODE_DA: &str = "evm_deploy_da.sol";
|
||||
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
|
||||
/// Default contract address for data attestation
|
||||
pub const DEFAULT_CONTRACT_ADDRESS_DA: &str = "contract_da.address";
|
||||
/// Default contract address for vk
|
||||
pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
|
||||
/// Default check mode
|
||||
pub const DEFAULT_CHECKMODE: &str = "safe";
|
||||
/// Default calibration target
|
||||
@@ -73,15 +76,21 @@ pub const DEFAULT_FUZZ_RUNS: &str = "10";
|
||||
pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
|
||||
/// Default lookup safety margin
|
||||
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
|
||||
/// Default Compress selectors
|
||||
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
|
||||
/// Default render vk seperately
|
||||
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
|
||||
/// Default VK sol path
|
||||
pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.to_possible_value()
|
||||
.expect("no values are skipped")
|
||||
.get_name()
|
||||
.fmt(f)
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for TranscriptType {
|
||||
@@ -126,17 +135,27 @@ impl Default for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for CalibrationTarget {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_string()
|
||||
impl std::fmt::Display for CalibrationTarget {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_string()
|
||||
}
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_string(),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_string(),
|
||||
}
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_string(),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_string(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for CalibrationTarget {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,6 +176,36 @@ impl From<&str> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
inner: H160,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<H160Flag> for H160 {
|
||||
fn from(val: H160Flag) -> H160 {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl ToFlags for H160Flag {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{:#x}", self.inner)]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<&str> for H160Flag {
|
||||
fn from(s: &str) -> Self {
|
||||
Self {
|
||||
inner: H160::from_str(s).unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
|
||||
impl IntoPy<PyObject> for CalibrationTarget {
|
||||
@@ -189,7 +238,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
@@ -230,7 +279,7 @@ impl Cli {
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
|
||||
pub enum Commands {
|
||||
#[cfg(feature = "empty-cmd")]
|
||||
/// Creates an empty buffer
|
||||
@@ -313,9 +362,20 @@ pub enum Commands {
|
||||
/// Optional scales to specifically try for calibration. Example, --scales 0,4
|
||||
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
|
||||
#[arg(
|
||||
long,
|
||||
value_delimiter = ',',
|
||||
allow_hyphen_values = true,
|
||||
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
|
||||
)]
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
|
||||
only_range_check_rebase: bool,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -389,6 +449,9 @@ pub enum Commands {
|
||||
/// whether the accumulated are segments of a larger proof
|
||||
#[arg(long, default_value = DEFAULT_SPLIT)]
|
||||
split_proofs: bool,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
|
||||
compress_selectors: bool,
|
||||
},
|
||||
/// Aggregates proofs :)
|
||||
Aggregate {
|
||||
@@ -451,6 +514,9 @@ pub enum Commands {
|
||||
/// The graph witness (optional - used to override fixed values in the circuit)
|
||||
#[arg(short = 'W', long)]
|
||||
witness: Option<PathBuf>,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
|
||||
compress_selectors: bool,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -473,11 +539,14 @@ pub enum Commands {
|
||||
/// number of fuzz iterations
|
||||
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
|
||||
num_runs: usize,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
|
||||
compress_selectors: bool,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEVMData {
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
@@ -505,7 +574,7 @@ pub enum Commands {
|
||||
TestUpdateAccountCalls {
|
||||
/// The path to the verifier contract's address
|
||||
#[arg(long)]
|
||||
addr: H160,
|
||||
addr: H160Flag,
|
||||
/// The path to the .json data file.
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
@@ -555,9 +624,9 @@ pub enum Commands {
|
||||
check_mode: CheckMode,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for a single proof
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEVMVerifier {
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -573,11 +642,36 @@ pub enum Commands {
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
|
||||
abi_path: PathBuf,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-vk")]
|
||||
CreateEvmVK {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
/// The path to load the desired verification key file
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
/// The path to output the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VK_ABI)]
|
||||
abi_path: PathBuf,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
#[command(name = "create-evm-da")]
|
||||
CreateEVMDataAttestation {
|
||||
CreateEvmDataAttestation {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
@@ -597,9 +691,9 @@ pub enum Commands {
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an EVM verifier for an aggregate proof
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEVMVerifierAggr {
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -618,6 +712,11 @@ pub enum Commands {
|
||||
// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
},
|
||||
/// Verifies a proof, returning accept or reject
|
||||
Verify {
|
||||
@@ -633,6 +732,9 @@ pub enum Commands {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
|
||||
reduced_srs: bool,
|
||||
},
|
||||
/// Verifies an aggregate proof, returning accept or reject
|
||||
VerifyAggr {
|
||||
@@ -669,6 +771,25 @@ pub enum Commands {
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVK {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
|
||||
/// The path to output the contract address
|
||||
addr_path: PathBuf,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
|
||||
optimizer_runs: usize,
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
@@ -695,28 +816,23 @@ pub enum Commands {
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Verifies a proof using a local EVM executor, returning accept or reject
|
||||
/// Verifies a proof using a local Evm executor, returning accept or reject
|
||||
#[command(name = "verify-evm")]
|
||||
VerifyEVM {
|
||||
VerifyEvm {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
/// The path to verifier contract's address
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
|
||||
addr_verifier: H160,
|
||||
addr_verifier: H160Flag,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
/// does the verifier use data attestation ?
|
||||
#[arg(long)]
|
||||
addr_da: Option<H160>,
|
||||
},
|
||||
|
||||
/// Print the proof in hexadecimal
|
||||
#[command(name = "print-proof-hex")]
|
||||
PrintProofHex {
|
||||
/// The path to the proof file
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
addr_da: Option<H160Flag>,
|
||||
// is the vk rendered seperately, if so specify an address
|
||||
#[arg(long)]
|
||||
addr_vk: Option<H160Flag>,
|
||||
},
|
||||
}
|
||||
|
||||
20
src/eth.rs
20
src/eth.rs
@@ -101,17 +101,18 @@ pub async fn setup_eth_backend(
|
||||
}
|
||||
|
||||
///
|
||||
pub async fn deploy_verifier_via_solidity(
|
||||
pub async fn deploy_contract_via_solidity(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
) -> Result<ethers::types::Address, Box<dyn Error>> {
|
||||
// anvil instance must be alive at least until the factory completes the deploy
|
||||
let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
let (abi, bytecode, runtime_bytecode) =
|
||||
get_contract_artifacts(sol_code_path, "Halo2Verifier", runs)?;
|
||||
get_contract_artifacts(sol_code_path, contract_name, runs)?;
|
||||
|
||||
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?;
|
||||
let contract = factory.deploy(())?.send().await?;
|
||||
@@ -335,11 +336,16 @@ pub async fn update_account_calls(
|
||||
pub async fn verify_proof_via_solidity(
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
addr: ethers::types::Address,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
let flattened_instances = proof.instances.into_iter().flatten();
|
||||
|
||||
let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
|
||||
let encoded = encode_calldata(
|
||||
addr_vk.as_ref().map(|x| x.0),
|
||||
&proof.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
info!("encoded: {:#?}", hex::encode(&encoded));
|
||||
let (anvil, client) = setup_eth_backend(rpc_url, None).await?;
|
||||
@@ -439,6 +445,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
addr_verifier: ethers::types::Address,
|
||||
addr_da: ethers::types::Address,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
use ethers::abi::{Function, Param, ParamType, StateMutability, Token};
|
||||
@@ -452,8 +459,11 @@ pub async fn verify_proof_with_data_attestation(
|
||||
public_inputs.push(u);
|
||||
}
|
||||
|
||||
let encoded_verifier =
|
||||
encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
|
||||
let encoded_verifier = encode_calldata(
|
||||
addr_vk.as_ref().map(|x| x.0),
|
||||
&proof.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
info!("encoded: {:#?}", hex::encode(&encoded_verifier));
|
||||
|
||||
|
||||
485
src/execute.rs
485
src/execute.rs
@@ -3,7 +3,9 @@ use crate::circuit::CheckMode;
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::commands::Commands;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
|
||||
use crate::commands::H160Flag;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(unused_imports)]
|
||||
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
|
||||
@@ -21,8 +23,6 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
|
||||
use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
use crate::RunArgs;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ethers::types::H160;
|
||||
use gag::Gag;
|
||||
use halo2_proofs::dev::VerifyFailure;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
@@ -140,8 +140,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
compiled_circuit,
|
||||
transcript,
|
||||
num_runs,
|
||||
} => fuzz(compiled_circuit, witness, transcript, num_runs),
|
||||
|
||||
compress_selectors,
|
||||
} => fuzz(
|
||||
compiled_circuit,
|
||||
witness,
|
||||
transcript,
|
||||
num_runs,
|
||||
compress_selectors,
|
||||
),
|
||||
Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetSrs {
|
||||
@@ -170,7 +176,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -178,6 +186,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -192,28 +202,44 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(model, witness),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifier {
|
||||
Commands::CreateEvmVerifier {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
render_vk_seperately,
|
||||
} => create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
),
|
||||
Commands::CreateEvmVK {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMDataAttestation {
|
||||
Commands::CreateEvmDataAttestation {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
data,
|
||||
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEVMVerifierAggr {
|
||||
Commands::CreateEvmVerifierAggr {
|
||||
vk_path,
|
||||
srs_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
} => create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -221,6 +247,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
),
|
||||
Commands::CompileCircuit {
|
||||
model,
|
||||
@@ -233,9 +260,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
vk_path,
|
||||
pk_path,
|
||||
witness,
|
||||
} => setup(compiled_circuit, srs_path, vk_path, pk_path, witness),
|
||||
compress_selectors,
|
||||
} => setup(
|
||||
compiled_circuit,
|
||||
srs_path,
|
||||
vk_path,
|
||||
pk_path,
|
||||
witness,
|
||||
compress_selectors,
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SetupTestEVMData {
|
||||
Commands::SetupTestEvmData {
|
||||
data,
|
||||
compiled_circuit,
|
||||
test_data,
|
||||
@@ -296,6 +331,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
logrows,
|
||||
split_proofs,
|
||||
compress_selectors,
|
||||
} => setup_aggregate(
|
||||
sample_snarks,
|
||||
vk_path,
|
||||
@@ -303,6 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
logrows,
|
||||
split_proofs,
|
||||
compress_selectors,
|
||||
),
|
||||
Commands::Aggregate {
|
||||
proof_path,
|
||||
@@ -329,7 +366,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
settings_path,
|
||||
vk_path,
|
||||
srs_path,
|
||||
} => verify(proof_path, settings_path, vk_path, srs_path)
|
||||
reduced_srs,
|
||||
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::VerifyAggr {
|
||||
proof_path,
|
||||
@@ -352,6 +390,25 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVK {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -377,13 +434,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::VerifyEVM {
|
||||
Commands::VerifyEvm {
|
||||
proof_path,
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
|
||||
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
|
||||
addr_vk,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -432,7 +489,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let hash = sha256::digest(&std::fs::read(path.clone())?);
|
||||
let hash = sha256::digest(std::fs::read(path.clone())?);
|
||||
info!("SRS hash: {}", hash);
|
||||
|
||||
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
|
||||
@@ -440,7 +497,7 @@ fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box
|
||||
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
|
||||
};
|
||||
|
||||
if hash != predefined_hash.to_string() {
|
||||
if hash != *predefined_hash {
|
||||
// delete file
|
||||
warn!("removing SRS file at {}", path.display());
|
||||
std::fs::remove_file(path)?;
|
||||
@@ -571,8 +628,12 @@ pub(crate) async fn gen_witness(
|
||||
);
|
||||
|
||||
if let Some(output_path) = output {
|
||||
serde_json::to_writer(&File::create(output_path)?, &witness)?;
|
||||
witness.save(output_path)?;
|
||||
}
|
||||
|
||||
// print the witness in debug
|
||||
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);
|
||||
|
||||
Ok(witness)
|
||||
}
|
||||
|
||||
@@ -662,15 +723,14 @@ impl AccuracyResults {
|
||||
let error = (original.clone() - calibrated.clone())?;
|
||||
let abs_error = error.map(|x| x.abs());
|
||||
let squared_error = error.map(|x| x.powi(2));
|
||||
let percentage_error =
|
||||
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
|
||||
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
|
||||
let abs_percentage_error = percentage_error.map(|x| x.abs());
|
||||
|
||||
errors.extend(error.into_iter());
|
||||
abs_errors.extend(abs_error.into_iter());
|
||||
squared_errors.extend(squared_error.into_iter());
|
||||
percentage_errors.extend(percentage_error.into_iter());
|
||||
abs_percentage_errors.extend(abs_percentage_error.into_iter());
|
||||
errors.extend(error);
|
||||
abs_errors.extend(abs_error);
|
||||
squared_errors.extend(squared_error);
|
||||
percentage_errors.extend(percentage_error);
|
||||
abs_percentage_errors.extend(abs_percentage_error);
|
||||
}
|
||||
|
||||
let mean_percent_error =
|
||||
@@ -679,29 +739,25 @@ impl AccuracyResults {
|
||||
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
|
||||
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
|
||||
let median_error = errors[errors.len() / 2];
|
||||
let max_error = errors
|
||||
let max_error = *errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
let min_error = errors
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
let min_error = *errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
|
||||
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
|
||||
let median_abs_error = abs_errors[abs_errors.len() / 2];
|
||||
let max_abs_error = abs_errors
|
||||
let max_abs_error = *abs_errors
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
let min_abs_error = abs_errors
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
let min_abs_error = *abs_errors
|
||||
.iter()
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap();
|
||||
|
||||
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
|
||||
|
||||
@@ -724,6 +780,7 @@ impl AccuracyResults {
|
||||
/// Calibrate the circuit parameters to a given a dataset
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(trivial_casts)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn calibrate(
|
||||
model_path: PathBuf,
|
||||
data: PathBuf,
|
||||
@@ -731,6 +788,8 @@ pub(crate) fn calibrate(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
use std::collections::HashMap;
|
||||
@@ -768,16 +827,17 @@ pub(crate) fn calibrate(
|
||||
let range = if let Some(scales) = scales {
|
||||
scales
|
||||
} else {
|
||||
match target {
|
||||
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
|
||||
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
|
||||
}
|
||||
(10..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
let scale_rebase_multiplier = [1, 2, 10];
|
||||
|
||||
// 2 x 2 grid
|
||||
let range_grid = range
|
||||
.iter()
|
||||
@@ -813,18 +873,22 @@ pub(crate) fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
pb.set_message("calibrating...");
|
||||
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier
|
||||
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
// vec of settings copied chunks.len() times
|
||||
let run_args_iterable = vec![settings.run_args.clone(); chunks.len()];
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = match Gag::stdout() {
|
||||
@@ -836,41 +900,42 @@ pub(crate) fn calibrate(
|
||||
Ok(r) => Some(r),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let tasks = chunks
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_q);
|
||||
debug!("circuit creation from run args failed: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
chunks
|
||||
.iter()
|
||||
.zip(run_args_iterable)
|
||||
.map(|(chunk, run_args)| {
|
||||
// we need to create a new run args for each chunk
|
||||
// time it
|
||||
.map(|chunk| {
|
||||
let chunk = chunk.clone();
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
..run_args.clone()
|
||||
};
|
||||
|
||||
let original_settings = settings.clone();
|
||||
|
||||
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err(format!("failed to create circuit from run args"))
|
||||
as Result<GraphSettings, String>
|
||||
}
|
||||
};
|
||||
|
||||
let data = circuit
|
||||
.load_graph_from_file_exclusively(&chunk)
|
||||
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
|
||||
|
||||
let forward_res = circuit
|
||||
.calibrate(&data, max_logrows, lookup_safety_margin)
|
||||
.map_err(|e| format!("failed to calibrate: {}", e))?;
|
||||
.forward(&mut data.clone(), None, None)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
// push result to the hashmap
|
||||
forward_pass_res
|
||||
@@ -878,38 +943,32 @@ pub(crate) fn calibrate(
|
||||
.ok_or("key not found")?
|
||||
.push(forward_res);
|
||||
|
||||
let settings = circuit.settings().clone();
|
||||
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: settings.run_args.input_scale,
|
||||
param_scale: settings.run_args.param_scale,
|
||||
lookup_range: settings.run_args.lookup_range,
|
||||
logrows: settings.run_args.logrows,
|
||||
scale_rebase_multiplier: settings.run_args.scale_rebase_multiplier,
|
||||
..run_args.clone()
|
||||
};
|
||||
|
||||
let found_settings = GraphSettings {
|
||||
run_args: found_run_args,
|
||||
required_lookups: settings.required_lookups,
|
||||
model_output_scales: settings.model_output_scales,
|
||||
model_input_scales: settings.model_input_scales,
|
||||
num_rows: settings.num_rows,
|
||||
total_assignments: settings.total_assignments,
|
||||
total_const_size: settings.total_const_size,
|
||||
..original_settings.clone()
|
||||
};
|
||||
|
||||
Ok(found_settings) as Result<GraphSettings, String>
|
||||
Ok(()) as Result<(), String>
|
||||
})
|
||||
.collect::<Vec<Result<GraphSettings, String>>>();
|
||||
.collect::<Result<Vec<()>, String>>()?;
|
||||
|
||||
let mut res: Vec<GraphSettings> = vec![];
|
||||
for task in tasks {
|
||||
if let Ok(task) = task {
|
||||
res.push(task);
|
||||
}
|
||||
}
|
||||
let min_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.min_lookup_inputs)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.max_lookup_inputs)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let res = circuit.calibrate_from_min_max(
|
||||
min_lookup_range,
|
||||
max_lookup_range,
|
||||
max_logrows,
|
||||
lookup_safety_margin,
|
||||
);
|
||||
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
@@ -917,31 +976,39 @@ pub(crate) fn calibrate(
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_q);
|
||||
|
||||
let max_lookup_range = res
|
||||
.iter()
|
||||
.map(|x| x.run_args.lookup_range.1)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let min_lookup_range = res
|
||||
.iter()
|
||||
.map(|x| x.run_args.lookup_range.0)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
if res.is_ok() {
|
||||
let new_settings = circuit.settings().clone();
|
||||
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
let found_settings = GraphSettings {
|
||||
run_args: found_run_args,
|
||||
required_lookups: new_settings.required_lookups,
|
||||
required_range_checks: new_settings.required_range_checks,
|
||||
model_output_scales: new_settings.model_output_scales,
|
||||
model_input_scales: new_settings.model_input_scales,
|
||||
num_rows: new_settings.num_rows,
|
||||
total_assignments: new_settings.total_assignments,
|
||||
total_const_size: new_settings.total_const_size,
|
||||
..settings.clone()
|
||||
};
|
||||
|
||||
found_params.push(found_settings.clone());
|
||||
|
||||
if let Some(mut best) = res.into_iter().max_by_key(|p| {
|
||||
(
|
||||
p.run_args.logrows,
|
||||
p.run_args.input_scale,
|
||||
p.run_args.param_scale,
|
||||
)
|
||||
}) {
|
||||
best.run_args.lookup_range = (min_lookup_range, max_lookup_range);
|
||||
// pick the one with the largest logrows
|
||||
found_params.push(best.clone());
|
||||
debug!(
|
||||
"found settings: \n {}",
|
||||
best.as_json()?.to_colored_json_auto()?
|
||||
found_settings.as_json()?.to_colored_json_auto()?
|
||||
);
|
||||
} else {
|
||||
debug!("calibration failed {}", res.err().unwrap());
|
||||
}
|
||||
|
||||
pb.inc(1);
|
||||
@@ -1034,7 +1101,7 @@ pub(crate) fn calibrate(
|
||||
|
||||
let tear_sheet_table = Table::new(vec![accuracy_res]);
|
||||
|
||||
println!(
|
||||
warn!(
|
||||
"\n\n <------------- Numerical Fidelity Report (input_scale: {}, param_scale: {}, scale_input_multiplier: {}) ------------->\n\n{}\n\n",
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
@@ -1098,21 +1165,11 @@ pub(crate) fn mock(
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
prover
|
||||
.verify_par()
|
||||
.verify()
|
||||
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
for instance in proof.instances {
|
||||
println!("{:?}", instance);
|
||||
}
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
info!("0x{}", hex_str);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
|
||||
#[cfg(feature = "render")]
|
||||
pub(crate) fn render(
|
||||
model: PathBuf,
|
||||
@@ -1143,6 +1200,7 @@ pub(crate) fn create_evm_verifier(
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
@@ -1160,7 +1218,11 @@ pub(crate) fn create_evm_verifier(
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
let verifier_solidity = generator.render()?;
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
@@ -1172,6 +1234,43 @@ pub(crate) fn create_evm_verifier(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_vk(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
|
||||
|
||||
let num_instance = circuit_settings.total_instances();
|
||||
let num_instance: usize = num_instance.iter().sum::<usize>();
|
||||
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
|
||||
trace!("params computed");
|
||||
|
||||
let generator = halo2_solidity_verifier::SolidityGenerator::new(
|
||||
¶ms,
|
||||
&vk,
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
|
||||
let vk_solidity = generator.render_separately()?.1;
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_data_attestation(
|
||||
settings_path: PathBuf,
|
||||
@@ -1267,13 +1366,15 @@ pub(crate) async fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let contract_address = deploy_verifier_via_solidity(
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1287,9 +1388,10 @@ pub(crate) async fn deploy_evm(
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn verify_evm(
|
||||
proof_path: PathBuf,
|
||||
addr_verifier: H160,
|
||||
addr_verifier: H160Flag,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<H160>,
|
||||
addr_da: Option<H160Flag>,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::verify_proof_with_data_attestation;
|
||||
check_solc_requirement();
|
||||
@@ -1299,13 +1401,20 @@ pub(crate) async fn verify_evm(
|
||||
let result = if let Some(addr_da) = addr_da {
|
||||
verify_proof_with_data_attestation(
|
||||
proof.clone(),
|
||||
addr_verifier,
|
||||
addr_da,
|
||||
addr_verifier.into(),
|
||||
addr_da.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
verify_proof_via_solidity(proof.clone(), addr_verifier, rpc_url.as_deref()).await?
|
||||
verify_proof_via_solidity(
|
||||
proof.clone(),
|
||||
addr_verifier.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!("Solidity verification result: {}", result);
|
||||
@@ -1325,6 +1434,7 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
abi_path: PathBuf,
|
||||
circuit_settings: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let srs_path = get_srs_path(logrows, srs_path);
|
||||
@@ -1363,7 +1473,11 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
|
||||
generator = generator.set_acc_encoding(Some(acc_encoding));
|
||||
|
||||
let verifier_solidity = generator.render()?;
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
@@ -1392,6 +1506,7 @@ pub(crate) fn setup(
|
||||
vk_path: PathBuf,
|
||||
pk_path: PathBuf,
|
||||
witness: Option<PathBuf>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// these aren't real values so the sanity checks are mostly meaningless
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit)?;
|
||||
@@ -1402,8 +1517,12 @@ pub(crate) fn setup(
|
||||
|
||||
let params = load_params_cmd(srs_path, circuit.settings().run_args.logrows)?;
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, ¶ms)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
¶ms,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
|
||||
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_path, pk.get_vk())?;
|
||||
save_pk::<KZGCommitmentScheme<Bn256>>(&pk_path, &pk)?;
|
||||
@@ -1451,14 +1570,14 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
use crate::pfsys::ProofType;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn test_update_account_calls(
|
||||
addr: H160,
|
||||
addr: H160Flag,
|
||||
data: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::update_account_calls;
|
||||
|
||||
check_solc_requirement();
|
||||
update_account_calls(addr, data, rpc_url.as_deref()).await?;
|
||||
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
@@ -1542,6 +1661,7 @@ pub(crate) fn fuzz(
|
||||
data_path: PathBuf,
|
||||
transcript: TranscriptType,
|
||||
num_runs: usize,
|
||||
compress_selectors: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let passed = AtomicBool::new(true);
|
||||
@@ -1557,8 +1677,12 @@ pub(crate) fn fuzz(
|
||||
|
||||
let data = GraphWitness::from_path(data_path)?;
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, ¶ms)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
¶ms,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
|
||||
circuit.load_graph_witness(&data)?;
|
||||
|
||||
@@ -1574,9 +1698,12 @@ pub(crate) fn fuzz(
|
||||
let fuzz_pk = || {
|
||||
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
|
||||
let bad_pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
|
||||
.map_err(|_| ())?;
|
||||
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
&new_params,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(|_| ())?;
|
||||
|
||||
let bad_proof = create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
@@ -1595,6 +1722,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1625,6 +1753,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1647,9 +1776,12 @@ pub(crate) fn fuzz(
|
||||
let fuzz_vk = || {
|
||||
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
|
||||
let bad_pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
|
||||
.map_err(|_| ())?;
|
||||
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
&new_params,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(|_| ())?;
|
||||
|
||||
let bad_vk = bad_pk.get_vk();
|
||||
|
||||
@@ -1658,6 +1790,7 @@ pub(crate) fn fuzz(
|
||||
proof.clone(),
|
||||
bad_vk,
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1689,6 +1822,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1724,6 +1858,7 @@ pub(crate) fn fuzz(
|
||||
bad_proof,
|
||||
pk.get_vk(),
|
||||
strategy.clone(),
|
||||
params.n(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
};
|
||||
@@ -1809,7 +1944,7 @@ pub(crate) fn mock_aggregate(
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(Box::<dyn Error>::from)?;
|
||||
prover
|
||||
.verify_par()
|
||||
.verify()
|
||||
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("Done.");
|
||||
@@ -1823,6 +1958,7 @@ pub(crate) fn setup_aggregate(
|
||||
srs_path: Option<PathBuf>,
|
||||
logrows: u32,
|
||||
split_proofs: bool,
|
||||
compress_selectors: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// the K used for the aggregation circuit
|
||||
let params = load_params_cmd(srs_path, logrows)?;
|
||||
@@ -1833,8 +1969,11 @@ pub(crate) fn setup_aggregate(
|
||||
}
|
||||
|
||||
let agg_circuit = AggregationCircuit::new(¶ms.get_g()[0].into(), snarks, split_proofs)?;
|
||||
let agg_pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(&agg_circuit, ¶ms)?;
|
||||
let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
|
||||
&agg_circuit,
|
||||
¶ms,
|
||||
compress_selectors,
|
||||
)?;
|
||||
|
||||
let agg_vk = agg_pk.get_vk();
|
||||
|
||||
@@ -1905,15 +2044,29 @@ pub(crate) fn verify(
|
||||
settings_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
reduced_srs: bool,
|
||||
) -> Result<bool, Box<dyn Error>> {
|
||||
let circuit_settings = GraphSettings::load(&settings_path)?;
|
||||
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
|
||||
|
||||
let params = if reduced_srs {
|
||||
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
|
||||
} else {
|
||||
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
|
||||
};
|
||||
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
|
||||
let vk =
|
||||
load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?;
|
||||
let now = Instant::now();
|
||||
let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
&vk,
|
||||
strategy,
|
||||
1 << circuit_settings.run_args.logrows,
|
||||
);
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
"verify took {}.{}",
|
||||
@@ -1937,7 +2090,7 @@ pub(crate) fn verify_aggr(
|
||||
let strategy = AccumulatorStrategy::new(params.verifier_params());
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(vk_path, ())?;
|
||||
let now = Instant::now();
|
||||
let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy, 1 << logrows);
|
||||
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::circuit::InputType;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::tensor::Tensor;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use postgres::{Client, NoTls};
|
||||
@@ -15,6 +16,8 @@ use pyo3::types::PyDict;
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -490,16 +493,20 @@ impl GraphData {
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to open input at {}", path.display()))?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
let reader = std::fs::File::open(path)?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
|
||||
let mut buf = String::new();
|
||||
reader.read_to_string(&mut buf)?;
|
||||
let graph_input = serde_json::from_str(&buf)?;
|
||||
Ok(graph_input)
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
@@ -617,13 +624,13 @@ impl ToPyObject for DataSource {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use crate::pfsys::field_to_vecu64_montgomery;
|
||||
use crate::pfsys::field_to_string_montgomery;
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for FileSourceInner {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
FileSourceInner::Field(data) => field_to_vecu64_montgomery(data).to_object(py),
|
||||
FileSourceInner::Field(data) => field_to_string_montgomery(data).to_object(py),
|
||||
FileSourceInner::Bool(data) => data.to_object(py),
|
||||
FileSourceInner::Float(data) => data.to_object(py),
|
||||
}
|
||||
|
||||
317
src/graph/mod.rs
317
src/graph/mod.rs
@@ -16,6 +16,7 @@ use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
pub use input::DataSource;
|
||||
use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use self::input::OnChainSource;
|
||||
@@ -23,19 +24,22 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::RunArgs;
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
|
||||
};
|
||||
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
use log::{debug, error, trace, warn};
|
||||
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
@@ -46,14 +50,13 @@ use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{Read, Write};
|
||||
use std::ops::Deref;
|
||||
use thiserror::Error;
|
||||
pub use utilities::*;
|
||||
pub use vars::*;
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use crate::pfsys::field_to_vecu64_montgomery;
|
||||
use crate::pfsys::field_to_string_montgomery;
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
@@ -61,6 +64,20 @@ pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
/// Max circuit area
|
||||
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
|
||||
if let Ok(max_circuit_area) = std::env::var("EZKL_MAX_CIRCUIT_AREA") {
|
||||
Some(max_circuit_area.parse().unwrap_or(0))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
|
||||
|
||||
/// circuit related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GraphError {
|
||||
@@ -191,42 +208,41 @@ impl GraphWitness {
|
||||
output_scales: Vec<crate::Scale>,
|
||||
visibility: VarVisibility,
|
||||
) {
|
||||
let mut pretty_elements = PrettyElements::default();
|
||||
pretty_elements.rescaled_inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = input_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
pretty_elements.inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect();
|
||||
|
||||
pretty_elements.rescaled_outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = output_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
pretty_elements.outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect();
|
||||
let mut pretty_elements = PrettyElements {
|
||||
rescaled_inputs: self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = input_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
inputs: self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect(),
|
||||
rescaled_outputs: self
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
let scale = output_scales[i];
|
||||
t.iter()
|
||||
.map(|x| dequantize(*x, scale, 0.).to_string())
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
outputs: self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(processed_inputs) = self.processed_inputs.clone() {
|
||||
pretty_elements.processed_inputs = processed_inputs
|
||||
@@ -292,16 +308,20 @@ impl GraphWitness {
|
||||
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut file = std::fs::File::open(path.clone())
|
||||
let file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load model at {}", path.display()))?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
|
||||
// use buf writer
|
||||
let writer =
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
|
||||
serde_json::to_writer(writer, &self).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
///
|
||||
@@ -332,16 +352,16 @@ impl ToPyObject for GraphWitness {
|
||||
let dict_params = PyDict::new(py);
|
||||
let dict_outputs = PyDict::new(py);
|
||||
|
||||
let inputs: Vec<Vec<[u64; 4]>> = self
|
||||
let inputs: Vec<Vec<String>> = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
|
||||
.map(|x| x.iter().map(field_to_string_montgomery).collect())
|
||||
.collect();
|
||||
|
||||
let outputs: Vec<Vec<[u64; 4]>> = self
|
||||
let outputs: Vec<Vec<String>> = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
|
||||
.map(|x| x.iter().map(field_to_string_montgomery).collect())
|
||||
.collect();
|
||||
|
||||
dict.set_item("inputs", inputs).unwrap();
|
||||
@@ -389,9 +409,9 @@ impl ToPyObject for GraphWitness {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
|
||||
let poseidon_hash: Vec<[u64; 4]> = poseidon_hash
|
||||
let poseidon_hash: Vec<String> = poseidon_hash
|
||||
.iter()
|
||||
.map(field_to_vecu64_montgomery)
|
||||
.map(field_to_string_montgomery)
|
||||
.collect();
|
||||
pydict.set_item("poseidon_hash", poseidon_hash)?;
|
||||
|
||||
@@ -431,6 +451,8 @@ pub struct GraphSettings {
|
||||
pub module_sizes: ModuleSizes,
|
||||
/// required_lookups
|
||||
pub required_lookups: Vec<LookupOp>,
|
||||
/// required range_checks
|
||||
pub required_range_checks: Vec<Range>,
|
||||
/// check mode
|
||||
pub check_mode: CheckMode,
|
||||
/// ezkl version used
|
||||
@@ -454,22 +476,33 @@ impl GraphSettings {
|
||||
instances
|
||||
}
|
||||
|
||||
/// calculate the log2 of the total number of instances
|
||||
pub fn log2_total_instances(&self) -> u32 {
|
||||
let sum = self.total_instances().iter().sum::<usize>();
|
||||
|
||||
// max between 1 and the log2 of the sums
|
||||
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
|
||||
}
|
||||
|
||||
/// save params to file
|
||||
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
|
||||
let encoded = serde_json::to_string(&self)?;
|
||||
let mut file = std::fs::File::create(path)?;
|
||||
file.write_all(encoded.as_bytes())
|
||||
// buf writer
|
||||
let writer =
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, &self).map_err(|e| {
|
||||
error!("failed to save settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
}
|
||||
/// load params from file
|
||||
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
|
||||
let mut file = std::fs::File::open(path).map_err(|e| {
|
||||
error!("failed to open settings file at {}", e);
|
||||
e
|
||||
})?;
|
||||
let mut data = String::new();
|
||||
file.read_to_string(&mut data)?;
|
||||
let res = serde_json::from_str(&data)?;
|
||||
Ok(res)
|
||||
// buf reader
|
||||
let reader =
|
||||
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
|
||||
serde_json::from_reader(reader).map_err(|e| {
|
||||
error!("failed to load settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
@@ -528,6 +561,7 @@ impl GraphSettings {
|
||||
pub struct GraphConfig {
|
||||
model_config: ModelConfig,
|
||||
module_configs: ModuleConfigs,
|
||||
circuit_size: CircuitSize,
|
||||
}
|
||||
|
||||
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
|
||||
@@ -564,7 +598,7 @@ impl GraphCircuit {
|
||||
///
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let f = std::fs::File::create(path)?;
|
||||
let writer = std::io::BufWriter::new(f);
|
||||
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
bincode::serialize_into(writer, &self)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -572,11 +606,10 @@ impl GraphCircuit {
|
||||
///
|
||||
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
// read bytes from file
|
||||
let mut f = std::fs::File::open(&path)?;
|
||||
let metadata = std::fs::metadata(&path)?;
|
||||
let mut buffer = vec![0; metadata.len() as usize];
|
||||
f.read_exact(&mut buffer)?;
|
||||
let result = bincode::deserialize(&buffer)?;
|
||||
let f = std::fs::File::open(path)?;
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let result: GraphCircuit = bincode::deserialize_from(reader)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -591,6 +624,17 @@ pub enum TestDataSource {
|
||||
OnChain,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestDataSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TestDataSource::File => write!(f, "file"),
|
||||
TestDataSource::OnChain => write!(f, "on-chain"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for TestDataSource {}
|
||||
|
||||
impl From<String> for TestDataSource {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
@@ -639,7 +683,7 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
// dummy module settings, must load from GraphData after
|
||||
let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?;
|
||||
let mut settings = model.gen_params(run_args, run_args.check_mode)?;
|
||||
|
||||
let mut num_params = 0;
|
||||
if !model.const_shapes().is_empty() {
|
||||
@@ -763,18 +807,18 @@ impl GraphCircuit {
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.rescaled_inputs = elements.rescaled_inputs.clone();
|
||||
public_inputs.inputs = elements.inputs.clone();
|
||||
} else if let Some(_) = &data.processed_inputs {
|
||||
} else if data.processed_inputs.is_some() {
|
||||
public_inputs.processed_inputs = elements.processed_inputs.clone();
|
||||
}
|
||||
|
||||
if let Some(_) = &data.processed_params {
|
||||
if data.processed_params.is_some() {
|
||||
public_inputs.processed_params = elements.processed_params.clone();
|
||||
}
|
||||
|
||||
if self.settings().run_args.output_visibility.is_public() {
|
||||
public_inputs.rescaled_outputs = elements.rescaled_outputs.clone();
|
||||
public_inputs.outputs = elements.outputs.clone();
|
||||
} else if let Some(_) = &data.processed_outputs {
|
||||
} else if data.processed_outputs.is_some() {
|
||||
public_inputs.processed_outputs = elements.processed_outputs.clone();
|
||||
}
|
||||
|
||||
@@ -807,7 +851,7 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
info!("input scales: {:?}", scales);
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
match &data.input_data {
|
||||
DataSource::File(file_data) => {
|
||||
@@ -826,7 +870,7 @@ impl GraphCircuit {
|
||||
let shapes = self.model().graph.input_shapes()?;
|
||||
let scales = self.model().graph.get_input_scales();
|
||||
let input_types = self.model().graph.get_input_types()?;
|
||||
info!("input scales: {:?}", scales);
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
.await
|
||||
@@ -956,19 +1000,24 @@ impl GraphCircuit {
|
||||
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
|
||||
}
|
||||
|
||||
fn calc_safe_lookup_range(res: &GraphWitness, lookup_safety_margin: i128) -> (i128, i128) {
|
||||
fn calc_safe_lookup_range(
|
||||
min_lookup_inputs: i128,
|
||||
max_lookup_inputs: i128,
|
||||
lookup_safety_margin: i128,
|
||||
) -> Range {
|
||||
let mut margin = (
|
||||
lookup_safety_margin * res.min_lookup_inputs,
|
||||
lookup_safety_margin * res.max_lookup_inputs,
|
||||
lookup_safety_margin * min_lookup_inputs,
|
||||
lookup_safety_margin * max_lookup_inputs,
|
||||
);
|
||||
if lookup_safety_margin == 1 {
|
||||
margin.0 -= 1;
|
||||
margin.1 += 1;
|
||||
margin.0 += 4;
|
||||
margin.1 += 4;
|
||||
}
|
||||
|
||||
margin
|
||||
}
|
||||
|
||||
fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(
|
||||
max_logrows as usize,
|
||||
Self::reserved_blinding_rows() as usize,
|
||||
@@ -978,7 +1027,8 @@ impl GraphCircuit {
|
||||
|
||||
fn calc_min_logrows(
|
||||
&mut self,
|
||||
res: &GraphWitness,
|
||||
min_lookup_inputs: i128,
|
||||
max_lookup_inputs: i128,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i128,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@@ -986,19 +1036,23 @@ impl GraphCircuit {
|
||||
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
|
||||
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
|
||||
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
|
||||
let mut min_logrows = MIN_LOGROWS;
|
||||
|
||||
let reserved_blinding_rows = Self::reserved_blinding_rows();
|
||||
// check if has overflowed max lookup input
|
||||
if res.max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
|| res.min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
if max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
|| min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
|
||||
{
|
||||
let err_string = format!("max lookup input ({}) is too large", res.max_lookup_inputs);
|
||||
error!("{}", err_string);
|
||||
let err_string = format!("max lookup input ({}) is too large", max_lookup_inputs);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
let safe_range = Self::calc_safe_lookup_range(res, lookup_safety_margin);
|
||||
let mut min_logrows = MIN_LOGROWS;
|
||||
let safe_range = Self::calc_safe_lookup_range(
|
||||
min_lookup_inputs,
|
||||
max_lookup_inputs,
|
||||
lookup_safety_margin,
|
||||
);
|
||||
|
||||
// degrade the max logrows until the extended k is small enough
|
||||
while min_logrows < max_logrows
|
||||
&& !self.extended_k_is_small_enough(
|
||||
@@ -1016,12 +1070,11 @@ impl GraphCircuit {
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
min_logrows
|
||||
);
|
||||
error!("{}", err_string);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
// degrade the max logrows until the extended k is small enough
|
||||
while max_logrows > min_logrows
|
||||
while min_logrows < max_logrows
|
||||
&& !self.extended_k_is_small_enough(
|
||||
max_logrows,
|
||||
Self::calc_num_cols(safe_range, max_logrows),
|
||||
@@ -1030,6 +1083,17 @@ impl GraphCircuit {
|
||||
max_logrows -= 1;
|
||||
}
|
||||
|
||||
if !self
|
||||
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
|
||||
{
|
||||
let err_string = format!(
|
||||
"extended k is too large to accommodate the quotient polynomial with logrows {}",
|
||||
max_logrows
|
||||
);
|
||||
debug!("{}", err_string);
|
||||
return Err(err_string.into());
|
||||
}
|
||||
|
||||
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
|
||||
.log2()
|
||||
.ceil() as usize;
|
||||
@@ -1092,7 +1156,7 @@ impl GraphCircuit {
|
||||
|
||||
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
|
||||
|
||||
info!(
|
||||
debug!(
|
||||
"setting lookup_range to: {:?}, setting logrows to: {}",
|
||||
self.settings().run_args.lookup_range,
|
||||
self.settings().run_args.logrows
|
||||
@@ -1111,22 +1175,31 @@ impl GraphCircuit {
|
||||
// n = 2^k
|
||||
let n = 1u64 << k;
|
||||
let mut extended_k = k;
|
||||
|
||||
while (1 << extended_k) < (n * quotient_poly_degree) {
|
||||
extended_k += 1;
|
||||
if extended_k > bn256::Fr::S {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
extended_k <= bn256::Fr::S
|
||||
true
|
||||
}
|
||||
|
||||
/// Calibrate the circuit to the supplied data.
|
||||
pub fn calibrate(
|
||||
pub fn calibrate_from_min_max(
|
||||
&mut self,
|
||||
input: &[Tensor<Fp>],
|
||||
min_lookup_inputs: i128,
|
||||
max_lookup_inputs: i128,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i128,
|
||||
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
|
||||
let res = self.forward(&mut input.to_vec(), None, None)?;
|
||||
self.calc_min_logrows(&res, max_logrows, lookup_safety_margin)?;
|
||||
Ok(res)
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.calc_min_logrows(
|
||||
min_lookup_inputs,
|
||||
max_lookup_inputs,
|
||||
max_logrows,
|
||||
lookup_safety_margin,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the forward pass of the model / graph of computations and any associated hashing.
|
||||
@@ -1175,7 +1248,7 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
let mut model_results = self.model().forward(inputs)?;
|
||||
let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
@@ -1335,7 +1408,6 @@ impl GraphCircuit {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
struct CircuitSize {
|
||||
num_instances: usize,
|
||||
@@ -1343,20 +1415,22 @@ struct CircuitSize {
|
||||
num_fixed: usize,
|
||||
num_challenges: usize,
|
||||
num_selectors: usize,
|
||||
logrows: u32,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl CircuitSize {
|
||||
pub fn from_cs(cs: &ConstraintSystem<Fp>) -> Self {
|
||||
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
|
||||
CircuitSize {
|
||||
num_instances: cs.num_instance_columns(),
|
||||
num_advice_columns: cs.num_advice_columns(),
|
||||
num_fixed: cs.num_fixed_columns(),
|
||||
num_challenges: cs.num_challenges(),
|
||||
num_selectors: cs.num_selectors(),
|
||||
logrows,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
@@ -1367,6 +1441,25 @@ impl CircuitSize {
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
/// number of columns
|
||||
pub fn num_columns(&self) -> usize {
|
||||
self.num_instances + self.num_advice_columns + self.num_fixed
|
||||
}
|
||||
|
||||
/// area of the circuit
|
||||
pub fn area(&self) -> usize {
|
||||
self.num_columns() * (1 << self.logrows)
|
||||
}
|
||||
|
||||
/// area less than max
|
||||
pub fn area_less_than_max(&self) -> bool {
|
||||
if EZKL_MAX_CIRCUIT_AREA.is_some() {
|
||||
self.area() < EZKL_MAX_CIRCUIT_AREA.unwrap()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Circuit<Fp> for GraphCircuit {
|
||||
@@ -1428,6 +1521,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
params.run_args.lookup_range,
|
||||
params.run_args.logrows as usize,
|
||||
params.required_lookups,
|
||||
params.required_range_checks,
|
||||
params.check_mode,
|
||||
)
|
||||
.unwrap();
|
||||
@@ -1440,10 +1534,12 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
(cs.degree() as f32).log2().ceil()
|
||||
);
|
||||
|
||||
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"circuit size: \n {}",
|
||||
CircuitSize::from_cs(cs)
|
||||
circuit_size
|
||||
.as_json()
|
||||
.unwrap()
|
||||
.to_colored_json_auto()
|
||||
@@ -1453,6 +1549,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
GraphConfig {
|
||||
model_config,
|
||||
module_configs,
|
||||
circuit_size,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1465,6 +1562,16 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
config: Self::Config,
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), PlonkError> {
|
||||
// check if the circuit area is less than the max
|
||||
if !config.circuit_size.area_less_than_max() {
|
||||
error!(
|
||||
"circuit area {} is larger than the max allowed area {}",
|
||||
config.circuit_size.area(),
|
||||
EZKL_MAX_CIRCUIT_AREA.unwrap()
|
||||
);
|
||||
return Err(PlonkError::Synthesis);
|
||||
}
|
||||
|
||||
trace!("Setting input in synthesize");
|
||||
let input_vis = &self.settings().run_args.input_visibility;
|
||||
let output_vis = &self.settings().run_args.output_visibility;
|
||||
|
||||
@@ -6,10 +6,10 @@ use super::GraphError;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::tensor::ValType;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
@@ -56,6 +56,8 @@ use unzip_n::unzip_n;
|
||||
|
||||
unzip_n!(pub 3);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
|
||||
/// The result of a forward pass.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ForwardResult {
|
||||
@@ -67,6 +69,16 @@ pub struct ForwardResult {
|
||||
pub min_lookup_inputs: i128,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
fn from(res: DummyPassRes) -> Self {
|
||||
Self {
|
||||
outputs: res.outputs,
|
||||
max_lookup_inputs: res.max_lookup_inputs,
|
||||
min_lookup_inputs: res.min_lookup_inputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A circuit configuration for the entirety of a model loaded from an Onnx file.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ModelConfig {
|
||||
@@ -79,6 +91,27 @@ pub struct ModelConfig {
|
||||
/// Representation of execution graph
|
||||
pub type NodeGraph = BTreeMap<usize, NodeType>;
|
||||
|
||||
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct DummyPassRes {
|
||||
/// number of rows use
|
||||
pub num_rows: usize,
|
||||
/// linear coordinate
|
||||
pub linear_coord: usize,
|
||||
/// total const size
|
||||
pub total_const_size: usize,
|
||||
/// lookup ops
|
||||
pub lookup_ops: HashSet<LookupOp>,
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
/// max lookup inputs
|
||||
pub max_lookup_inputs: i128,
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: i128,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
|
||||
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Model {
|
||||
@@ -233,13 +266,7 @@ impl NodeType {
|
||||
NodeType::SubGraph { out_dims, .. } => out_dims.clone(),
|
||||
}
|
||||
}
|
||||
/// Returns the lookups required by a graph
|
||||
pub fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
match self {
|
||||
NodeType::Node(n) => n.opkind.required_lookups(),
|
||||
NodeType::SubGraph { model, .. } => model.required_lookups(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scales of the node's output.
|
||||
pub fn out_scales(&self) -> Vec<crate::Scale> {
|
||||
match self {
|
||||
@@ -424,14 +451,6 @@ impl ParsedNodes {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.graph
|
||||
.nodes
|
||||
.values()
|
||||
.flat_map(|n| n.required_lookups())
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
/// Creates a `Model` from a specified path to an Onnx file.
|
||||
/// # Arguments
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
@@ -476,7 +495,7 @@ impl Model {
|
||||
) -> Result<GraphSettings, Box<dyn Error>> {
|
||||
let instance_shapes = self.instance_shapes()?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {}",
|
||||
"model has".blue(),
|
||||
instance_shapes.len().to_string().blue(),
|
||||
@@ -484,36 +503,39 @@ impl Model {
|
||||
);
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let (num_rows, linear_coord, total_const_size) =
|
||||
self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
// extract the requisite lookup ops from the model
|
||||
let mut lookup_ops: Vec<LookupOp> = self.required_lookups();
|
||||
let inputs: Vec<ValTensor<Fp>> = self
|
||||
.graph
|
||||
.input_shapes()?
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
if run_args.tolerance.val > 0.0 {
|
||||
for scale in self.graph.get_output_scales()? {
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(scale).into();
|
||||
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
|
||||
lookup_ops.extend(opkind.required_lookups());
|
||||
}
|
||||
}
|
||||
|
||||
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
|
||||
lookup_ops.extend(set.into_iter().sorted());
|
||||
|
||||
Ok(GraphSettings {
|
||||
run_args: run_args.clone(),
|
||||
model_instance_shapes: instance_shapes,
|
||||
module_sizes: crate::graph::modules::ModuleSizes::default(),
|
||||
num_rows,
|
||||
total_assignments: linear_coord,
|
||||
required_lookups: lookup_ops,
|
||||
num_rows: res.num_rows,
|
||||
total_assignments: res.linear_coord,
|
||||
required_lookups: res.lookup_ops.into_iter().collect(),
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
total_const_size,
|
||||
total_const_size: res.total_const_size,
|
||||
check_mode,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
num_blinding_factors: None,
|
||||
@@ -534,205 +556,17 @@ impl Model {
|
||||
/// * `reader` - A reader for an Onnx file.
|
||||
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
|
||||
/// * `run_args` - [RunArgs]
|
||||
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
|
||||
let mut max_lookup_inputs = 0;
|
||||
let mut min_lookup_inputs = 0;
|
||||
|
||||
let input_shapes = self.graph.input_shapes()?;
|
||||
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
let mut input = model_inputs[i].clone();
|
||||
input.reshape(&input_shapes[i])?;
|
||||
results.insert(input_idx, vec![input]);
|
||||
}
|
||||
|
||||
for (idx, n) in self.graph.nodes.iter() {
|
||||
let mut inputs = vec![];
|
||||
if n.is_input() {
|
||||
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
|
||||
inputs.push(t);
|
||||
} else {
|
||||
for (idx, outlet) in n.inputs().iter() {
|
||||
match results.get(&idx) {
|
||||
Some(value) => inputs.push(value[*outlet].clone()),
|
||||
None => return Err(Box::new(GraphError::MissingNode(*idx))),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!("executing {}: {}", idx, n.as_str());
|
||||
debug!("dims: {:?}", n.out_dims());
|
||||
debug!(
|
||||
"input_dims: {:?}",
|
||||
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
if n.is_lookup() {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in &inputs {
|
||||
max = max.max(
|
||||
i.iter()
|
||||
.map(|x| felt_to_i128(*x))
|
||||
.max()
|
||||
.ok_or("missing max")?,
|
||||
);
|
||||
min = min.min(
|
||||
i.iter()
|
||||
.map(|x| felt_to_i128(*x))
|
||||
.min()
|
||||
.ok_or("missing min")?,
|
||||
);
|
||||
}
|
||||
max_lookup_inputs = max_lookup_inputs.max(max);
|
||||
min_lookup_inputs = min_lookup_inputs.min(min);
|
||||
debug!("max lookup inputs: {}", max);
|
||||
debug!("min lookup inputs: {}", min);
|
||||
}
|
||||
|
||||
match n {
|
||||
NodeType::Node(n) => {
|
||||
// execute the op
|
||||
let start = instant::Instant::now();
|
||||
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
|
||||
res.output.reshape(&n.out_dims)?;
|
||||
let elapsed = start.elapsed();
|
||||
trace!("op took: {:?}", elapsed);
|
||||
// see if any of the intermediate lookup calcs are the max
|
||||
if !res.intermediate_lookups.is_empty() {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in &res.intermediate_lookups {
|
||||
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
|
||||
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
|
||||
}
|
||||
max_lookup_inputs = max_lookup_inputs.max(max);
|
||||
min_lookup_inputs = min_lookup_inputs.min(min);
|
||||
debug!("intermediate max lookup inputs: {}", max);
|
||||
debug!("intermediate min lookup inputs: {}", min);
|
||||
}
|
||||
debug!(
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
|
||||
idx,
|
||||
res.output.map(crate::fieldutils::felt_to_i32).show(),
|
||||
res.output
|
||||
.map(|x| crate::fieldutils::felt_to_f64(x)
|
||||
/ scale_to_multiplier(n.out_scale))
|
||||
.show(),
|
||||
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
|
||||
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
|
||||
n.out_scale
|
||||
);
|
||||
results.insert(idx, vec![res.output]);
|
||||
}
|
||||
NodeType::SubGraph {
|
||||
model,
|
||||
output_mappings,
|
||||
input_mappings,
|
||||
inputs: input_tuple,
|
||||
..
|
||||
} => {
|
||||
let orig_inputs = inputs.clone();
|
||||
let input_mappings = input_mappings.clone();
|
||||
|
||||
let input_dims = inputs.iter().map(|inp| inp.dims());
|
||||
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
|
||||
|
||||
debug!(
|
||||
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
|
||||
num_iter, input_tuple, model.graph.inputs
|
||||
);
|
||||
|
||||
debug!("input_mappings: {:?}", input_mappings);
|
||||
|
||||
let mut full_results: Vec<Tensor<Fp>> = vec![];
|
||||
|
||||
for i in 0..num_iter {
|
||||
// replace the Stacked input with the current chunk iter
|
||||
for ((mapping, inp), og_input) in
|
||||
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
|
||||
{
|
||||
if let InputMapping::Stacked { axis, chunk } = mapping {
|
||||
let start = i * chunk;
|
||||
let end = (i + 1) * chunk;
|
||||
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
|
||||
*inp = t;
|
||||
}
|
||||
}
|
||||
|
||||
let res = model.forward(&inputs)?;
|
||||
// recursively get the max lookup inputs for subgraphs
|
||||
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
|
||||
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);
|
||||
|
||||
let mut outlets = BTreeMap::new();
|
||||
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
|
||||
for mapping in mappings {
|
||||
match mapping {
|
||||
OutputMapping::Single { outlet, .. } => {
|
||||
outlets.insert(outlet, outlet_res.clone());
|
||||
}
|
||||
OutputMapping::Stacked { outlet, axis, .. } => {
|
||||
if !full_results.is_empty() {
|
||||
let stacked_res = crate::tensor::ops::concat(
|
||||
&[&full_results[*outlet], &outlet_res],
|
||||
*axis,
|
||||
)?;
|
||||
|
||||
outlets.insert(outlet, stacked_res);
|
||||
} else {
|
||||
outlets.insert(outlet, outlet_res.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
full_results = outlets.into_values().collect_vec();
|
||||
|
||||
let output_states = output_state_idx(output_mappings);
|
||||
let input_states = input_state_idx(&input_mappings);
|
||||
|
||||
assert_eq!(input_states.len(), output_states.len());
|
||||
|
||||
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
|
||||
inputs[*input_idx] = full_results[output_idx].clone();
|
||||
}
|
||||
}
|
||||
|
||||
trace!(
|
||||
"------------ output subgraph node {}: {:?}",
|
||||
idx,
|
||||
full_results
|
||||
.iter()
|
||||
.map(|x|
|
||||
// convert to tensor i32
|
||||
x.map(crate::fieldutils::felt_to_i32).show())
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
results.insert(idx, full_results);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let output_nodes = self.graph.outputs.iter();
|
||||
debug!(
|
||||
"model outputs are nodes: {:?}",
|
||||
output_nodes.clone().collect_vec()
|
||||
);
|
||||
let outputs = output_nodes
|
||||
.map(|(idx, outlet)| {
|
||||
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = ForwardResult {
|
||||
outputs,
|
||||
max_lookup_inputs,
|
||||
min_lookup_inputs,
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
pub fn forward(
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(&run_args, &valtensor_inputs)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
/// Loads an Onnx model from a specified path.
|
||||
@@ -744,7 +578,7 @@ impl Model {
|
||||
fn load_onnx_using_tract(
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
|
||||
) -> Result<TractResult, Box<dyn Error>> {
|
||||
use tract_onnx::{
|
||||
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
|
||||
};
|
||||
@@ -783,7 +617,7 @@ impl Model {
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
let symbol = model.symbol_table.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
info!("set {} to {}", symbol, value);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
}
|
||||
|
||||
// Note: do not optimize the model, as the layout will depend on underlying hardware
|
||||
@@ -1042,6 +876,8 @@ impl Model {
|
||||
&run_args.param_visibility,
|
||||
i,
|
||||
symbol_values,
|
||||
run_args.div_rebasing,
|
||||
run_args.rebase_frac_zero_constants,
|
||||
)?;
|
||||
if let Some(ref scales) = override_input_scales {
|
||||
if let Some(inp) = n.opkind.get_input() {
|
||||
@@ -1058,9 +894,20 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
} else {
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
@@ -1155,9 +1002,10 @@ impl Model {
|
||||
pub fn configure(
|
||||
meta: &mut ConstraintSystem<Fp>,
|
||||
vars: &ModelVars<Fp>,
|
||||
lookup_range: (i128, i128),
|
||||
lookup_range: Range,
|
||||
logrows: usize,
|
||||
required_lookups: Vec<LookupOp>,
|
||||
required_range_checks: Vec<Range>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
|
||||
info!("configuring model");
|
||||
@@ -1170,12 +1018,16 @@ impl Model {
|
||||
);
|
||||
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
|
||||
let input = &vars.advices[0];
|
||||
let output = &vars.advices[1];
|
||||
let index = &vars.advices[2];
|
||||
let output = &vars.advices[2];
|
||||
let index = &vars.advices[1];
|
||||
for op in required_lookups {
|
||||
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
|
||||
}
|
||||
|
||||
for range in required_range_checks {
|
||||
base_gate.configure_range_check(meta, input, range)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
@@ -1216,6 +1068,7 @@ impl Model {
|
||||
let instance_idx = vars.get_instance_idx();
|
||||
|
||||
config.base.layout_tables(layouter)?;
|
||||
config.base.layout_range_checks(layouter)?;
|
||||
|
||||
let mut num_rows = 0;
|
||||
let mut linear_coord = 0;
|
||||
@@ -1285,7 +1138,7 @@ impl Model {
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
num_rows.to_string().blue(),
|
||||
@@ -1327,18 +1180,29 @@ impl Model {
|
||||
};
|
||||
|
||||
debug!(
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
|
||||
idx,
|
||||
node.as_str(),
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants()
|
||||
region.total_constants(),
|
||||
region.max_lookup_inputs(),
|
||||
region.min_lookup_inputs()
|
||||
);
|
||||
debug!("dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input_dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
@@ -1481,28 +1345,13 @@ impl Model {
|
||||
pub fn dummy_layout(
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
input_shapes: &[Vec<usize>],
|
||||
) -> Result<(usize, usize, usize), Box<dyn Error>> {
|
||||
info!("calculating num of constraints using dummy model layout...");
|
||||
inputs: &[ValTensor<Fp>],
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let inputs: Vec<ValTensor<Fp>> = input_shapes
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
results.insert(*input_idx, vec![inputs[i].clone()]);
|
||||
@@ -1526,27 +1375,26 @@ impl Model {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let comparator = outputs
|
||||
let output_scales = self.graph.get_output_scales()?;
|
||||
let res = outputs
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let mut v: ValTensor<Fp> =
|
||||
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
|
||||
v.reshape(x.dims())?;
|
||||
Ok(v)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let mut comparator: ValTensor<Fp> =
|
||||
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
let _ = outputs
|
||||
.into_iter()
|
||||
.zip(comparator)
|
||||
.map(|(o, c)| {
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[o, c],
|
||||
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
|
||||
&[output.clone(), comparator],
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
res?;
|
||||
} else if !self.visibility.output.is_private() {
|
||||
for output in &outputs {
|
||||
region.increment_total_constants(output.num_constants());
|
||||
@@ -1558,7 +1406,7 @@ impl Model {
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
info!(
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
region.row().to_string().blue(),
|
||||
@@ -1567,11 +1415,26 @@ impl Model {
|
||||
region.total_constants().to_string().red()
|
||||
);
|
||||
|
||||
Ok((
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants(),
|
||||
))
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_felt_evals()
|
||||
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
let res = DummyPassRes {
|
||||
num_rows: region.row(),
|
||||
linear_coord: region.linear_coord(),
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
max_lookup_inputs: region.max_lookup_inputs(),
|
||||
min_lookup_inputs: region.min_lookup_inputs(),
|
||||
outputs,
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Retrieves all constants from the model.
|
||||
|
||||
@@ -12,16 +12,12 @@ use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tensor::TensorError;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use itertools::Itertools;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::trace;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -94,10 +90,6 @@ impl Op<Fp> for Rescaled {
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.inner.required_lookups()
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -126,12 +118,14 @@ impl Op<Fp> for Rescaled {
|
||||
pub struct RebaseScale {
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// the multiplier applied to the node output
|
||||
pub multiplier: f64,
|
||||
/// rebase op
|
||||
pub rebase_op: HybridOp,
|
||||
/// scale being rebased to
|
||||
pub target_scale: i32,
|
||||
/// The original scale of the operation's inputs.
|
||||
pub original_scale: i32,
|
||||
/// multiplier
|
||||
pub multiplier: f64,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
@@ -141,6 +135,7 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -149,10 +144,15 @@ impl RebaseScale {
|
||||
let multiplier =
|
||||
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
|
||||
if let Some(op) = inner.get_rebased() {
|
||||
let multiplier = op.multiplier * multiplier;
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
inner: op.inner.clone(),
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op.original_scale,
|
||||
})
|
||||
} else {
|
||||
@@ -160,6 +160,10 @@ impl RebaseScale {
|
||||
inner: Box::new(inner),
|
||||
target_scale: global_scale * scale_rebase_multiplier as i32,
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op_out_scale,
|
||||
})
|
||||
}
|
||||
@@ -173,15 +177,21 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
if let Some(op) = inner.get_rebased() {
|
||||
let multiplier = op.multiplier * multiplier;
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
inner: op.inner.clone(),
|
||||
target_scale: op.target_scale,
|
||||
multiplier: op.multiplier * multiplier,
|
||||
multiplier,
|
||||
original_scale: op.original_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
@@ -189,6 +199,10 @@ impl RebaseScale {
|
||||
target_scale,
|
||||
multiplier,
|
||||
original_scale: op_out_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
@@ -203,19 +217,17 @@ impl Op<Fp> for RebaseScale {
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
let mut res = Op::<Fp>::f(&*self.inner, x)?;
|
||||
let ri = res.output.map(felt_to_i128);
|
||||
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
|
||||
res.output = rescaled.map(i128_to_felt);
|
||||
|
||||
res.intermediate_lookups.push(ri);
|
||||
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
|
||||
res.output = rebase_res.output;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}) ({})",
|
||||
"REBASED (div={:?}, rebasing_op={}) ({})",
|
||||
self.multiplier,
|
||||
<HybridOp as Op<Fp>>::as_string(&self.rebase_op),
|
||||
self.inner.as_string()
|
||||
)
|
||||
}
|
||||
@@ -224,14 +236,6 @@ impl Op<Fp> for RebaseScale {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
let mut lookups = self.inner.required_lookups();
|
||||
lookups.push(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
});
|
||||
lookups
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -241,16 +245,8 @@ impl Op<Fp> for RebaseScale {
|
||||
let original_res = self
|
||||
.inner
|
||||
.layout(config, region, values)?
|
||||
.ok_or("no layout")?;
|
||||
|
||||
Ok(Some(crate::circuit::layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[original_res],
|
||||
&LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(self.multiplier as f32),
|
||||
},
|
||||
)?))
|
||||
.ok_or("no inner layout")?;
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
@@ -433,10 +429,6 @@ impl Op<Fp> for SupportedOp {
|
||||
self
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
self.as_op().required_lookups()
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
@@ -470,14 +462,7 @@ impl Tabled for Node {
|
||||
|
||||
fn headers() -> Vec<std::borrow::Cow<'static, str>> {
|
||||
let mut headers = Vec::with_capacity(Self::LENGTH);
|
||||
for i in [
|
||||
"idx",
|
||||
"opkind",
|
||||
"out_scale",
|
||||
"inputs",
|
||||
"out_dims",
|
||||
"required_lookups",
|
||||
] {
|
||||
for i in ["idx", "opkind", "out_scale", "inputs", "out_dims"] {
|
||||
headers.push(std::borrow::Cow::Borrowed(i));
|
||||
}
|
||||
headers
|
||||
@@ -490,14 +475,6 @@ impl Tabled for Node {
|
||||
fields.push(std::borrow::Cow::Owned(self.out_scale.to_string()));
|
||||
fields.push(std::borrow::Cow::Owned(display_vector(&self.inputs)));
|
||||
fields.push(std::borrow::Cow::Owned(display_vector(&self.out_dims)));
|
||||
fields.push(std::borrow::Cow::Owned(format!(
|
||||
"{:?}",
|
||||
self.opkind
|
||||
.required_lookups()
|
||||
.iter()
|
||||
.map(<LookupOp as Op<Fp>>::as_string)
|
||||
.collect_vec()
|
||||
)));
|
||||
fields
|
||||
}
|
||||
}
|
||||
@@ -520,6 +497,7 @@ impl Node {
|
||||
/// * `public_params` - flag if parameters of model are public
|
||||
/// * `idx` - The node's unique identifier.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
other_nodes: &mut BTreeMap<usize, super::NodeType>,
|
||||
@@ -527,9 +505,9 @@ impl Node {
|
||||
param_visibility: &Visibility,
|
||||
idx: usize,
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
use log::warn;
|
||||
|
||||
trace!("Create {:?}", node);
|
||||
trace!("Create op {:?}", node.op);
|
||||
|
||||
@@ -567,6 +545,7 @@ impl Node {
|
||||
node.clone(),
|
||||
&mut inputs,
|
||||
symbol_values,
|
||||
rebase_frac_zero_constants,
|
||||
)?; // parses the op name
|
||||
|
||||
// we can only take the inputs as mutable once -- so we need to collect them first
|
||||
@@ -622,8 +601,6 @@ impl Node {
|
||||
input_node.bump_scale(out_scale);
|
||||
in_scales[input] = out_scale;
|
||||
}
|
||||
} else {
|
||||
warn!("input {} not found for rescaling, skipping ...", input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -631,7 +608,13 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -71,8 +71,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_i128(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier - shift;
|
||||
float_rep
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
@@ -244,6 +243,7 @@ pub fn new_op_from_onnx(
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
inputs: &mut [super::NodeType],
|
||||
symbol_values: &SymbolValues,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
|
||||
use crate::circuit::InputType;
|
||||
|
||||
@@ -262,7 +262,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[index].bump_scale(scale);
|
||||
c.rebase_scale(scale)?;
|
||||
inputs[index].replace_opkind(SupportedOp::Constant(c.clone()));
|
||||
Ok(SupportedOp::Linear(PolyOp::Identity))
|
||||
Ok(SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(scale),
|
||||
}))
|
||||
} else {
|
||||
Ok(default_op)
|
||||
}
|
||||
@@ -283,8 +285,8 @@ pub fn new_op_from_onnx(
|
||||
"shift left".to_string(),
|
||||
)));
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(1.0 / 2.0f32.powf(raw_values[0])),
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
@@ -305,8 +307,8 @@ pub fn new_op_from_onnx(
|
||||
"shift right".to_string(),
|
||||
)));
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Div {
|
||||
denom: crate::circuit::utils::F32(2.0f32.powf(raw_values[0])),
|
||||
SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(Box::new(GraphError::OpMismatch(
|
||||
@@ -437,17 +439,16 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
let mut op =
|
||||
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -476,17 +477,16 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
let mut op =
|
||||
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
|
||||
// Raw values are always f32
|
||||
let raw_value = extract_tensor_value(op.0)?;
|
||||
// If bool or a tensor dimension then don't scale
|
||||
let constant_scale = match dt {
|
||||
let mut constant_scale = match dt {
|
||||
DatumType::Bool
|
||||
| DatumType::TDim
|
||||
| DatumType::I64
|
||||
@@ -560,6 +560,12 @@ pub fn new_op_from_onnx(
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
};
|
||||
|
||||
// if all raw_values are round then set scale to 0
|
||||
let all_round = raw_value.iter().all(|x| (x).fract() == 0.0);
|
||||
if all_round && rebase_frac_zero_constants {
|
||||
constant_scale = 0;
|
||||
}
|
||||
|
||||
// Quantize the raw value
|
||||
let quantized_value =
|
||||
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
|
||||
@@ -666,8 +672,10 @@ pub fn new_op_from_onnx(
|
||||
if unit == 0. {
|
||||
SupportedOp::Nonlinear(LookupOp::ReLU)
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
}
|
||||
@@ -708,8 +716,11 @@ pub fn new_op_from_onnx(
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
} else {
|
||||
@@ -717,17 +728,13 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0];
|
||||
let additional_scale = if scale_diff > 0 {
|
||||
scale_to_multiplier(scale_diff)
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Recip {
|
||||
scale: (scale_to_multiplier(in_scale).powf(2.0) * additional_scale).into(),
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: false,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -752,7 +759,9 @@ pub fn new_op_from_onnx(
|
||||
"Scan" => {
|
||||
return Err("scan should never be analyzed explicitly".into());
|
||||
}
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => SupportedOp::Linear(PolyOp::Identity),
|
||||
"QuantizeLinearU8" | "DequantizeLinearF32" => {
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
@@ -857,11 +866,11 @@ pub fn new_op_from_onnx(
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
SupportedOp::Linear(PolyOp::Identity)
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
}
|
||||
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
|
||||
SupportedOp::Linear(PolyOp::Identity)
|
||||
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
|
||||
}
|
||||
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
|
||||
}
|
||||
@@ -886,12 +895,15 @@ pub fn new_op_from_onnx(
|
||||
let const_idx = const_idx[0];
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
op = SupportedOp::Nonlinear(LookupOp::Div {
|
||||
// we invert the constant for division
|
||||
denom: crate::circuit::utils::F32(1. / c.raw_values[0]),
|
||||
})
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
let raw_values = 1.0 / c.raw_values[0];
|
||||
if raw_values.log2().fract() == 0.0 {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
op = SupportedOp::Linear(PolyOp::Identity {
|
||||
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::error::Error;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::tensor::TensorType;
|
||||
use crate::tensor::{ValTensor, VarTensor};
|
||||
@@ -14,6 +15,7 @@ use pyo3::{
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -40,6 +42,33 @@ pub enum Visibility {
|
||||
Fixed,
|
||||
}
|
||||
|
||||
impl Display for Visibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Visibility::KZGCommit => write!(f, "kzgcommit"),
|
||||
Visibility::Private => write!(f, "private"),
|
||||
Visibility::Public => write!(f, "public"),
|
||||
Visibility::Fixed => write!(f, "fixed"),
|
||||
Visibility::Hashed {
|
||||
hash_is_public,
|
||||
outlets,
|
||||
} => {
|
||||
if *hash_is_public {
|
||||
write!(f, "hashed/public")
|
||||
} else {
|
||||
write!(f, "hashed/private/{}", outlets.iter().join(","))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for Visibility {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Visibility {
|
||||
fn from(s: &'a str) -> Self {
|
||||
if s.contains("hashed/private") {
|
||||
@@ -202,17 +231,6 @@ impl Visibility {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
impl std::fmt::Display for Visibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Visibility::KZGCommit => write!(f, "kzgcommit"),
|
||||
Visibility::Private => write!(f, "private"),
|
||||
Visibility::Public => write!(f, "public"),
|
||||
Visibility::Fixed => write!(f, "fixed"),
|
||||
Visibility::Hashed { .. } => write!(f, "hashed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the scale of the model input, model parameters.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
@@ -237,6 +255,11 @@ impl VarScales {
|
||||
std::cmp::max(self.input, self.params)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn get_min(&self) -> crate::Scale {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Place in [VarScales] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
|
||||
Ok(Self {
|
||||
@@ -405,7 +428,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
|
||||
num_constants: usize,
|
||||
module_requires_fixed: bool,
|
||||
) -> Self {
|
||||
info!("number of blinding factors: {}", cs.blinding_factors());
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
|
||||
|
||||
80
src/lib.rs
80
src/lib.rs
@@ -7,7 +7,6 @@
|
||||
overflowing_literals,
|
||||
path_statements,
|
||||
patterns_in_fns_without_body,
|
||||
private_in_public,
|
||||
unconditional_recursion,
|
||||
unused,
|
||||
unused_allocation,
|
||||
@@ -29,10 +28,11 @@
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
use circuit::Tolerance;
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use graph::Visibility;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
|
||||
pub mod circuit;
|
||||
@@ -71,11 +71,34 @@ pub mod tensor;
|
||||
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
|
||||
pub mod wasm;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
/// The denominator in the fixed point representation used when quantizing inputs
|
||||
pub type Scale = i32;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
// Buf writer capacity
|
||||
lazy_static! {
|
||||
/// The capacity of the buffer used for writing to disk
|
||||
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_KEY_FORMAT: &str = "raw-bytes";
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_BUF_CAPACITY: &usize = &8000;
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
|
||||
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
|
||||
pub struct RunArgs {
|
||||
/// The tolerance for error on model outputs
|
||||
#[arg(short = 'T', long, default_value = "0")]
|
||||
@@ -90,8 +113,8 @@ pub struct RunArgs {
|
||||
#[arg(long, default_value = "1")]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
|
||||
pub lookup_range: (i128, i128),
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17")]
|
||||
pub logrows: u32,
|
||||
@@ -99,7 +122,7 @@ pub struct RunArgs {
|
||||
#[arg(short = 'N', long, default_value = "2")]
|
||||
pub num_inner_cols: usize,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Flags whether inputs are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
@@ -110,6 +133,15 @@ pub struct RunArgs {
|
||||
/// Flags whether params are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
pub param_visibility: Visibility,
|
||||
#[arg(long, default_value = "false")]
|
||||
/// Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[arg(long, default_value = "false")]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[arg(long, default_value = "unsafe")]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -126,6 +158,9 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,34 +204,15 @@ fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr,
|
||||
T: std::str::FromStr + std::fmt::Debug,
|
||||
T::Err: std::error::Error + Send + Sync + 'static,
|
||||
U: std::str::FromStr,
|
||||
U: std::str::FromStr + std::fmt::Debug,
|
||||
U::Err: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let pos = s
|
||||
.find('=')
|
||||
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
|
||||
}
|
||||
|
||||
/// Parse a tuple
|
||||
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr + Clone,
|
||||
T::Err: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');
|
||||
|
||||
let res = res
|
||||
.map(|x| {
|
||||
// remove blank space
|
||||
let x = x.trim();
|
||||
x.parse::<T>()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if res.len() != 2 {
|
||||
return Err("invalid tuple".into());
|
||||
}
|
||||
Ok((res[0].clone(), res[1].clone()))
|
||||
.find("->")
|
||||
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
|
||||
let a = s[..pos].parse()?;
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
}
|
||||
|
||||
132
src/pfsys/mod.rs
132
src/pfsys/mod.rs
@@ -8,10 +8,11 @@ use crate::circuit::CheckMode;
|
||||
use crate::graph::GraphWitness;
|
||||
use crate::pfsys::evm::aggregation::PoseidonTranscript;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
};
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG};
|
||||
@@ -39,9 +40,19 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
|
||||
use std::ops::Deref;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error as thisError;
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
match s {
|
||||
"processed" => halo2_proofs::SerdeFormat::Processed,
|
||||
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
|
||||
_ => panic!("invalid serde format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(
|
||||
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
|
||||
@@ -52,6 +63,25 @@ pub enum ProofType {
|
||||
ForAggr,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProofType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
ProofType::Single => "single",
|
||||
ProofType::ForAggr => "for-aggr",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for ProofType {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ProofType> for TranscriptType {
|
||||
fn from(val: ProofType) -> Self {
|
||||
match val {
|
||||
@@ -154,6 +184,21 @@ pub enum TranscriptType {
|
||||
EVM,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TranscriptType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
TranscriptType::Poseidon => "poseidon",
|
||||
TranscriptType::EVM => "evm",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for TranscriptType {}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for TranscriptType {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -167,8 +212,8 @@ impl ToPyObject for TranscriptType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
///
|
||||
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
|
||||
let g1affine_x = field_to_vecu64_montgomery(&g1affine.x);
|
||||
let g1affine_y = field_to_vecu64_montgomery(&g1affine.y);
|
||||
let g1affine_x = field_to_string_montgomery(&g1affine.x);
|
||||
let g1affine_y = field_to_string_montgomery(&g1affine.y);
|
||||
g1affine_dict.set_item("x", g1affine_x).unwrap();
|
||||
g1affine_dict.set_item("y", g1affine_y).unwrap();
|
||||
}
|
||||
@@ -178,24 +223,24 @@ use halo2curves::bn256::G1;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
///
|
||||
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
|
||||
let g1_x = field_to_vecu64_montgomery(&g1.x);
|
||||
let g1_y = field_to_vecu64_montgomery(&g1.y);
|
||||
let g1_z = field_to_vecu64_montgomery(&g1.z);
|
||||
let g1_x = field_to_string_montgomery(&g1.x);
|
||||
let g1_y = field_to_string_montgomery(&g1.y);
|
||||
let g1_z = field_to_string_montgomery(&g1.z);
|
||||
g1_dict.set_item("x", g1_x).unwrap();
|
||||
g1_dict.set_item("y", g1_y).unwrap();
|
||||
g1_dict.set_item("z", g1_z).unwrap();
|
||||
}
|
||||
|
||||
/// converts fp into `Vec<u64>` in Montgomery form
|
||||
pub fn field_to_vecu64_montgomery<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> [u64; 4] {
|
||||
pub fn field_to_string_montgomery<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
|
||||
let repr = serde_json::to_string(&fp).unwrap();
|
||||
let b: [u64; 4] = serde_json::from_str(&repr).unwrap();
|
||||
let b: String = serde_json::from_str(&repr).unwrap();
|
||||
b
|
||||
}
|
||||
|
||||
/// converts `Vec<u64>` in Montgomery form into fp
|
||||
pub fn vecu64_to_field_montgomery<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
|
||||
b: &[u64; 4],
|
||||
pub fn string_to_field_montgomery<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
|
||||
b: &String,
|
||||
) -> F {
|
||||
let repr = serde_json::to_string(&b).unwrap();
|
||||
let fp: F = serde_json::from_str(&repr).unwrap();
|
||||
@@ -256,10 +301,10 @@ where
|
||||
{
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
let field_elems: Vec<Vec<[u64; 4]>> = self
|
||||
let field_elems: Vec<Vec<String>> = self
|
||||
.instances
|
||||
.iter()
|
||||
.map(|x| x.iter().map(|fp| field_to_vecu64_montgomery(fp)).collect())
|
||||
.map(|x| x.iter().map(|fp| field_to_string_montgomery(fp)).collect())
|
||||
.collect::<Vec<_>>();
|
||||
dict.set_item("instances", field_elems).unwrap();
|
||||
let hex_proof = hex::encode(&self.proof);
|
||||
@@ -306,10 +351,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// create hex proof from proof
|
||||
pub fn create_hex_proof(&mut self) {
|
||||
let hex_proof = hex::encode(&self.proof);
|
||||
self.hex_proof = Some(format!("0x{}", hex_proof));
|
||||
}
|
||||
|
||||
/// Saves the Proof to a specified `proof_path`.
|
||||
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
|
||||
let file = std::fs::File::create(proof_path)?;
|
||||
let mut writer = BufWriter::new(file);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(&mut writer, &self)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -322,8 +373,10 @@ where
|
||||
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
|
||||
{
|
||||
trace!("reading proof");
|
||||
let data = std::fs::read_to_string(proof_path)?;
|
||||
serde_json::from_str(&data).map_err(|e| e.into())
|
||||
let file = std::fs::File::open(proof_path)?;
|
||||
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let proof: Self = serde_json::from_reader(reader)?;
|
||||
Ok(proof)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,6 +480,7 @@ where
|
||||
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
compress_selectors: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
@@ -438,7 +492,7 @@ where
|
||||
// Initialize verifying key
|
||||
let now = Instant::now();
|
||||
trace!("preparing VK");
|
||||
let vk = keygen_vk(params, &empty_circuit)?;
|
||||
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
|
||||
let elapsed = now.elapsed();
|
||||
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());
|
||||
|
||||
@@ -548,6 +602,7 @@ where
|
||||
verifier_params,
|
||||
pk.get_vk(),
|
||||
strategy,
|
||||
verifier_params.n(),
|
||||
)?;
|
||||
}
|
||||
let elapsed = now.elapsed();
|
||||
@@ -594,6 +649,7 @@ where
|
||||
let mut snark_new = snark.clone();
|
||||
// swap the proof bytes for the new ones
|
||||
snark_new.proof[..proof_first_bytes.len()].copy_from_slice(&proof_first_bytes);
|
||||
snark_new.create_hex_proof();
|
||||
|
||||
Ok(snark_new)
|
||||
}
|
||||
@@ -634,6 +690,7 @@ pub fn verify_proof_circuit<
|
||||
params: &'params Scheme::ParamsVerifier,
|
||||
vk: &VerifyingKey<Scheme::Curve>,
|
||||
strategy: Strategy,
|
||||
orig_n: u64,
|
||||
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
|
||||
where
|
||||
Scheme::Scalar: SerdeObject
|
||||
@@ -654,7 +711,7 @@ where
|
||||
trace!("instances {:?}", instances);
|
||||
|
||||
let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
|
||||
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
|
||||
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
|
||||
}
|
||||
|
||||
/// Loads a [VerifyingKey] at `path`.
|
||||
@@ -670,13 +727,14 @@ where
|
||||
info!("loading verification key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)
|
||||
)?;
|
||||
info!("done loading verification key ✅");
|
||||
Ok(vk)
|
||||
}
|
||||
|
||||
/// Loads a [ProvingKey] at `path`.
|
||||
@@ -692,19 +750,20 @@ where
|
||||
info!("loading proving key from {:?}", path);
|
||||
let f =
|
||||
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
serde_format_from_str(&EZKL_KEY_FORMAT),
|
||||
params,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from)
|
||||
)?;
|
||||
info!("done loading proving key ✅");
|
||||
Ok(pk)
|
||||
}
|
||||
|
||||
/// Saves a [ProvingKey] to `path`.
|
||||
pub fn save_pk<Scheme: CommitmentScheme>(
|
||||
path: &PathBuf,
|
||||
vk: &ProvingKey<Scheme::Curve>,
|
||||
pk: &ProvingKey<Scheme::Curve>,
|
||||
) -> Result<(), io::Error>
|
||||
where
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
@@ -712,9 +771,10 @@ where
|
||||
{
|
||||
info!("saving proving key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
writer.flush()?;
|
||||
info!("done saving proving key ✅");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -729,9 +789,10 @@ where
|
||||
{
|
||||
info!("saving verification key 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
|
||||
writer.flush()?;
|
||||
info!("done saving verification key ✅");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -742,7 +803,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
|
||||
) -> Result<(), io::Error> {
|
||||
info!("saving parameters 💾");
|
||||
let f = File::create(path)?;
|
||||
let mut writer = BufWriter::new(f);
|
||||
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
params.write(&mut writer)?;
|
||||
writer.flush()?;
|
||||
Ok(())
|
||||
@@ -832,6 +893,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
vk: &VerifyingKey<G1Affine>,
|
||||
strategy: Strategy,
|
||||
orig_n: u64,
|
||||
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
@@ -841,7 +903,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
_,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, params, vk, strategy),
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => verify_proof_circuit::<
|
||||
Fr,
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
@@ -849,7 +911,7 @@ pub(crate) fn verify_proof_circuit_kzg<
|
||||
_,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, params, vk, strategy),
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
211
src/python.rs
211
src/python.rs
@@ -15,10 +15,9 @@ use crate::graph::{
|
||||
use crate::pfsys::evm::aggregation::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
|
||||
Snark, TranscriptType,
|
||||
TranscriptType,
|
||||
};
|
||||
use crate::RunArgs;
|
||||
use ethers::types::H160;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
@@ -26,11 +25,10 @@ use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
type PyFelt = [u64; 4];
|
||||
type PyFelt = String;
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -65,9 +63,9 @@ struct PyG1 {
|
||||
impl From<G1> for PyG1 {
|
||||
fn from(g1: G1) -> Self {
|
||||
PyG1 {
|
||||
x: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.x),
|
||||
y: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.y),
|
||||
z: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.z),
|
||||
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
|
||||
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
|
||||
z: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.z),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -75,9 +73,9 @@ impl From<G1> for PyG1 {
|
||||
impl From<PyG1> for G1 {
|
||||
fn from(val: PyG1) -> Self {
|
||||
G1 {
|
||||
x: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.x),
|
||||
y: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.y),
|
||||
z: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.z),
|
||||
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
|
||||
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
|
||||
z: crate::pfsys::string_to_field_montgomery::<Fq>(&val.z),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -108,8 +106,8 @@ pub struct PyG1Affine {
|
||||
impl From<G1Affine> for PyG1Affine {
|
||||
fn from(g1: G1Affine) -> Self {
|
||||
PyG1Affine {
|
||||
x: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.x),
|
||||
y: crate::pfsys::field_to_vecu64_montgomery::<Fq>(&g1.y),
|
||||
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
|
||||
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,8 +115,8 @@ impl From<G1Affine> for PyG1Affine {
|
||||
impl From<PyG1Affine> for G1Affine {
|
||||
fn from(val: PyG1Affine) -> Self {
|
||||
G1Affine {
|
||||
x: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.x),
|
||||
y: crate::pfsys::vecu64_to_field_montgomery::<Fq>(&val.y),
|
||||
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
|
||||
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,7 +144,7 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
#[pyo3(get, set)]
|
||||
pub lookup_range: (i128, i128),
|
||||
pub lookup_range: crate::circuit::table::Range,
|
||||
#[pyo3(get, set)]
|
||||
pub logrows: u32,
|
||||
#[pyo3(get, set)]
|
||||
@@ -159,6 +157,12 @@ struct PyRunArgs {
|
||||
pub param_visibility: Visibility,
|
||||
#[pyo3(get, set)]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub check_mode: CheckMode,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -185,6 +189,9 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,6 +210,9 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -211,10 +221,10 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
#[pyfunction(signature = (
|
||||
array,
|
||||
))]
|
||||
fn vecu64_to_felt(array: PyFelt) -> PyResult<String> {
|
||||
fn string_to_felt(array: PyFelt) -> PyResult<String> {
|
||||
Ok(format!(
|
||||
"{:?}",
|
||||
crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array)
|
||||
crate::pfsys::string_to_field_montgomery::<Fr>(&array)
|
||||
))
|
||||
}
|
||||
|
||||
@@ -222,8 +232,8 @@ fn vecu64_to_felt(array: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
array,
|
||||
))]
|
||||
fn vecu64_to_int(array: PyFelt) -> PyResult<i128> {
|
||||
let felt = crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array);
|
||||
fn string_to_int(array: PyFelt) -> PyResult<i128> {
|
||||
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
|
||||
let int_rep = felt_to_i128(felt);
|
||||
Ok(int_rep)
|
||||
}
|
||||
@@ -233,8 +243,8 @@ fn vecu64_to_int(array: PyFelt) -> PyResult<i128> {
|
||||
array,
|
||||
scale
|
||||
))]
|
||||
fn vecu64_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::vecu64_to_field_montgomery::<Fr>(&array);
|
||||
fn string_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
|
||||
let int_rep = felt_to_i128(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier;
|
||||
@@ -246,11 +256,11 @@ fn vecu64_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
input,
|
||||
scale
|
||||
))]
|
||||
fn float_to_vecu64(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
fn float_to_string(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = i128_to_felt(int_rep);
|
||||
Ok(crate::pfsys::field_to_vecu64_montgomery::<Fr>(&felt))
|
||||
Ok(crate::pfsys::field_to_string_montgomery::<Fr>(&felt))
|
||||
}
|
||||
|
||||
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
|
||||
@@ -318,7 +328,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
|
||||
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
let message: Vec<Fr> = message
|
||||
.iter()
|
||||
.map(crate::pfsys::vecu64_to_field_montgomery::<Fr>)
|
||||
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let output =
|
||||
@@ -329,7 +339,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
|
||||
let hash = output[0]
|
||||
.iter()
|
||||
.map(crate::pfsys::field_to_vecu64_montgomery::<Fr>)
|
||||
.map(crate::pfsys::field_to_string_montgomery::<Fr>)
|
||||
.collect::<Vec<_>>();
|
||||
Ok(hash)
|
||||
}
|
||||
@@ -337,8 +347,8 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
/// Generate a kzg commitment.
|
||||
#[pyfunction(signature = (
|
||||
message,
|
||||
vk_path,
|
||||
settings_path,
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
fn kzg_commit(
|
||||
@@ -349,7 +359,7 @@ fn kzg_commit(
|
||||
) -> PyResult<Vec<PyG1Affine>> {
|
||||
let message: Vec<Fr> = message
|
||||
.iter()
|
||||
.map(crate::pfsys::vecu64_to_field_montgomery::<Fr>)
|
||||
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let settings = GraphSettings::load(&settings_path)
|
||||
@@ -387,9 +397,9 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
|
||||
|
||||
/// Generates a vk from a pk for a model circuit and saves it to a file
|
||||
#[pyfunction(signature = (
|
||||
path_to_pk,
|
||||
circuit_settings_path,
|
||||
vk_output_path
|
||||
path_to_pk=PathBuf::from(DEFAULT_PK),
|
||||
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK),
|
||||
))]
|
||||
fn gen_vk_from_pk_single(
|
||||
path_to_pk: PathBuf,
|
||||
@@ -413,8 +423,8 @@ fn gen_vk_from_pk_single(
|
||||
|
||||
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
|
||||
#[pyfunction(signature = (
|
||||
path_to_pk,
|
||||
vk_output_path
|
||||
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
|
||||
))]
|
||||
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(path_to_pk, ())
|
||||
@@ -511,7 +521,9 @@ fn gen_settings(
|
||||
target = CalibrationTarget::default(), // default is "resources
|
||||
lookup_safety_margin = DEFAULT_LOOKUP_SAFETY_MARGIN.parse().unwrap(),
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
data: PathBuf,
|
||||
@@ -520,7 +532,9 @@ fn calibrate_settings(
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
@@ -529,6 +543,8 @@ fn calibrate_settings(
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
@@ -543,7 +559,7 @@ fn calibrate_settings(
|
||||
#[pyfunction(signature = (
|
||||
data=PathBuf::from(DEFAULT_DATA),
|
||||
model=PathBuf::from(DEFAULT_MODEL),
|
||||
output=None,
|
||||
output=PathBuf::from(DEFAULT_WITNESS),
|
||||
vk_path=None,
|
||||
srs_path=None,
|
||||
))]
|
||||
@@ -604,7 +620,8 @@ fn mock_aggregate(
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
pk_path=PathBuf::from(DEFAULT_PK),
|
||||
srs_path=None,
|
||||
witness_path = None
|
||||
witness_path = None,
|
||||
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
|
||||
))]
|
||||
fn setup(
|
||||
model: PathBuf,
|
||||
@@ -612,8 +629,17 @@ fn setup(
|
||||
pk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
witness_path: Option<PathBuf>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::setup(model, srs_path, vk_path, pk_path, witness_path).map_err(|e| {
|
||||
crate::execute::setup(
|
||||
model,
|
||||
srs_path,
|
||||
vk_path,
|
||||
pk_path,
|
||||
witness_path,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run setup: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
@@ -661,14 +687,23 @@ fn prove(
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
srs_path=None,
|
||||
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
|
||||
))]
|
||||
fn verify(
|
||||
proof_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
non_reduced_srs: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
|
||||
crate::execute::verify(
|
||||
proof_path,
|
||||
settings_path,
|
||||
vk_path,
|
||||
srs_path,
|
||||
non_reduced_srs,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
@@ -682,7 +717,8 @@ fn verify(
|
||||
pk_path=PathBuf::from(DEFAULT_PK_AGGREGATED),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
split_proofs = false,
|
||||
srs_path = None
|
||||
srs_path = None,
|
||||
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
|
||||
))]
|
||||
fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
@@ -691,6 +727,7 @@ fn setup_aggregate(
|
||||
logrows: u32,
|
||||
split_proofs: bool,
|
||||
srs_path: Option<PathBuf>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::setup_aggregate(
|
||||
sample_snarks,
|
||||
@@ -699,6 +736,7 @@ fn setup_aggregate(
|
||||
srs_path,
|
||||
logrows,
|
||||
split_proofs,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to setup aggregate: {}", e);
|
||||
@@ -794,6 +832,7 @@ fn verify_aggr(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier(
|
||||
vk_path: PathBuf,
|
||||
@@ -801,12 +840,20 @@ fn create_evm_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
@@ -872,7 +919,7 @@ fn setup_test_evm_witness(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
@@ -889,6 +936,39 @@ fn deploy_evm(
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_vk_evm(
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
@@ -940,26 +1020,28 @@ fn deploy_da_evm(
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF),
|
||||
rpc_url=None,
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
fn verify_evm(
|
||||
addr_verifier: &str,
|
||||
proof_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<&str>,
|
||||
addr_vk: Option<&str>,
|
||||
) -> Result<bool, PyErr> {
|
||||
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
let addr_verifier = H160Flag::from(addr_verifier);
|
||||
let addr_da = if let Some(addr_da) = addr_da {
|
||||
let addr_da = H160::from_str(addr_da).map_err(|e| {
|
||||
let err_str = format!("address is invalid: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
let addr_da = H160Flag::from(addr_da);
|
||||
Some(addr_da)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let addr_vk = if let Some(addr_vk) = addr_vk {
|
||||
let addr_vk = H160Flag::from(addr_vk);
|
||||
Some(addr_vk)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
@@ -968,6 +1050,7 @@ fn verify_evm(
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify_evm: {}", e);
|
||||
@@ -985,6 +1068,7 @@ fn verify_evm(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier_aggr(
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
@@ -993,6 +1077,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path: PathBuf,
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
@@ -1001,6 +1086,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
|
||||
@@ -1009,32 +1095,21 @@ fn create_evm_verifier_aggr(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// print hex representation of a proof
|
||||
#[pyfunction(signature = (proof_path))]
|
||||
fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load proof"))?;
|
||||
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
|
||||
// Python Module
|
||||
#[pymodule]
|
||||
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
// NOTE: DeployVerifierEVM and SendProofEVM will be implemented in python in pyezkl
|
||||
pyo3_log::init();
|
||||
m.add_class::<PyRunArgs>()?;
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add_function(wrap_pyfunction!(vecu64_to_felt, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(string_to_felt, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(string_to_int, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(string_to_float, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(float_to_string, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;
|
||||
@@ -1055,9 +1130,9 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
|
||||
|
||||
@@ -30,11 +30,11 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use std::cmp::max;
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
use thiserror::Error;
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
|
||||
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
|
||||
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
|
||||
/// ```
|
||||
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner = self.inner.clone();
|
||||
let remainder = self.len() % n;
|
||||
if remainder != 0 {
|
||||
inner.resize(self.len() + n - remainder, T::zero().unwrap());
|
||||
inner.resize(self.len() + n - remainder, pad);
|
||||
}
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
}
|
||||
@@ -1452,6 +1452,43 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
/// # Arguments
|
||||
/// * `self` - Tensor
|
||||
/// * `rhs` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use std::ops::Rem;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.rem(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the broadcasted shape of two tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::get_broadcasted_shape;
|
||||
|
||||
@@ -243,7 +243,7 @@ pub fn and<
|
||||
/// Some(&[1, 0, 1, 0, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = equals(&a, &b).unwrap().0;
|
||||
/// let result = equals(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
@@ -260,7 +260,7 @@ pub fn equals<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let a = a.clone();
|
||||
let b = b.clone();
|
||||
|
||||
@@ -268,7 +268,7 @@ pub fn equals<
|
||||
|
||||
let result = nonlinearities::kronecker_delta(&diff);
|
||||
|
||||
Ok((result, vec![diff]))
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Greater than operation.
|
||||
@@ -289,7 +289,7 @@ pub fn equals<
|
||||
/// ).unwrap();
|
||||
/// let result = greater(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater<
|
||||
T: TensorType
|
||||
@@ -302,7 +302,7 @@ pub fn greater<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mask_inter = (a.clone() - b.clone())?;
|
||||
let mask = mask_inter.map(|x| {
|
||||
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
|
||||
@@ -311,7 +311,7 @@ pub fn greater<
|
||||
T::zero().ok_or(TensorError::Unsupported).unwrap()
|
||||
}
|
||||
});
|
||||
Ok((mask, vec![mask_inter]))
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
/// Greater equals than operation.
|
||||
@@ -332,7 +332,7 @@ pub fn greater<
|
||||
/// ).unwrap();
|
||||
/// let result = greater_equal(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_equal<
|
||||
T: TensorType
|
||||
@@ -345,7 +345,7 @@ pub fn greater_equal<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mask_inter = (a.clone() - b.clone())?;
|
||||
let mask = mask_inter.map(|x| {
|
||||
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
|
||||
@@ -354,7 +354,7 @@ pub fn greater_equal<
|
||||
T::zero().ok_or(TensorError::Unsupported).unwrap()
|
||||
}
|
||||
});
|
||||
Ok((mask, vec![mask_inter]))
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
/// Less than to operation.
|
||||
@@ -375,7 +375,7 @@ pub fn greater_equal<
|
||||
/// ).unwrap();
|
||||
/// let result = less(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn less<
|
||||
@@ -389,7 +389,7 @@ pub fn less<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
// a < b <=> b > a
|
||||
greater(b, a)
|
||||
}
|
||||
@@ -412,7 +412,7 @@ pub fn less<
|
||||
/// ).unwrap();
|
||||
/// let result = less_equal(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn less_equal<
|
||||
@@ -426,7 +426,7 @@ pub fn less_equal<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
// a < b <=> b > a
|
||||
greater_equal(b, a)
|
||||
}
|
||||
@@ -950,8 +950,7 @@ pub fn neg<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sy
|
||||
/// Elementwise multiplies multiple tensors.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Tensor
|
||||
/// * `t` - Tensors
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -2301,12 +2300,12 @@ pub fn deconv<
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
/// &[1, 1, 3, 3],
|
||||
/// ).unwrap();
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
|
||||
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(pooled, expected);
|
||||
///
|
||||
/// // This time with normalization
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
|
||||
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(pooled, expected);
|
||||
/// ```
|
||||
@@ -2316,7 +2315,7 @@ pub fn sumpool(
|
||||
stride: (usize, usize),
|
||||
kernel_shape: (usize, usize),
|
||||
normalize: bool,
|
||||
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
|
||||
) -> Result<Tensor<i128>, TensorError> {
|
||||
let image_dims = image.dims();
|
||||
let batch_size = image_dims[0];
|
||||
let image_channels = image_dims[1];
|
||||
@@ -2346,15 +2345,12 @@ pub fn sumpool(
|
||||
let mut combined = res.combine()?;
|
||||
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
let mut inter = vec![];
|
||||
|
||||
if normalize {
|
||||
inter.push(combined.clone());
|
||||
let norm = kernel.len();
|
||||
combined = nonlinearities::const_div(&combined, norm as f64);
|
||||
}
|
||||
|
||||
Ok((combined, inter))
|
||||
Ok(combined)
|
||||
}
|
||||
|
||||
/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
|
||||
@@ -3049,11 +3045,7 @@ pub mod nonlinearities {
|
||||
}
|
||||
|
||||
/// softmax layout
|
||||
pub fn softmax_axes(
|
||||
a: &Tensor<i128>,
|
||||
scale: f64,
|
||||
axes: &[usize],
|
||||
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
|
||||
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
|
||||
// we want this to be as small as possible so we set the output scale to 1
|
||||
let dims = a.dims();
|
||||
|
||||
@@ -3061,8 +3053,6 @@ pub mod nonlinearities {
|
||||
return softmax(a, scale);
|
||||
}
|
||||
|
||||
let mut intermediate_values = vec![];
|
||||
|
||||
let cartesian_coord = dims[..dims.len() - 1]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
@@ -3085,8 +3075,7 @@ pub mod nonlinearities {
|
||||
|
||||
let res = softmax(&softmax_input, scale);
|
||||
|
||||
outputs.push(res.0);
|
||||
intermediate_values.extend(res.1);
|
||||
outputs.push(res);
|
||||
}
|
||||
|
||||
let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
|
||||
@@ -3094,7 +3083,7 @@ pub mod nonlinearities {
|
||||
.combine()
|
||||
.unwrap();
|
||||
res.reshape(dims).unwrap();
|
||||
(res, intermediate_values)
|
||||
res
|
||||
}
|
||||
|
||||
/// Applies softmax
|
||||
@@ -3111,24 +3100,20 @@ pub mod nonlinearities {
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = softmax(&x, 128.0).0;
|
||||
/// let result = softmax(&x, 128.0);
|
||||
/// // doubles the scale of the input
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
|
||||
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
|
||||
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
|
||||
let mut intermediate_values = vec![];
|
||||
|
||||
intermediate_values.push(a.clone());
|
||||
|
||||
let exp = exp(a, scale);
|
||||
|
||||
let sum = sum(&exp).unwrap();
|
||||
intermediate_values.push(sum.clone());
|
||||
let inv_denom = recip(&sum, scale.powf(2.0));
|
||||
let inv_denom = recip(&sum, scale, scale);
|
||||
|
||||
((exp * inv_denom).unwrap(), intermediate_values)
|
||||
(exp * inv_denom).unwrap()
|
||||
}
|
||||
|
||||
/// Applies range_check_percent
|
||||
@@ -3163,7 +3148,7 @@ pub mod nonlinearities {
|
||||
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
|
||||
let scale = input_scale * output_scale;
|
||||
let diff: Tensor<i128> = sub(t).unwrap();
|
||||
let recip = recip(&t[0], scale as f64);
|
||||
let recip = recip(&t[0], input_scale as f64, output_scale as f64);
|
||||
let product = mult(&[diff, recip]).unwrap();
|
||||
let _tol = ((tol / 100.0) * scale as f32).round() as f64;
|
||||
let upper_bound = greater_than(&product, _tol);
|
||||
@@ -3774,14 +3759,15 @@ pub mod nonlinearities {
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2_f64;
|
||||
/// let result = recip(&x, k);
|
||||
/// let result = recip(&x, 1.0, k);
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn recip(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
|
||||
pub fn recip(a: &Tensor<i128>, input_scale: f64, out_scale: f64) -> Tensor<i128> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let denom = (1_f64) / (a_i as f64 + f64::EPSILON);
|
||||
let d_inv_x = scale * denom;
|
||||
let rescaled = (a_i as f64) / input_scale;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as i128)
|
||||
})
|
||||
.unwrap()
|
||||
|
||||
@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `pad_to_zero_rem` on the inner tensor.
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.pad_to_zero_rem(n)?;
|
||||
*v = v.pad_to_zero_rem(n, pad)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
Ok(indices)
|
||||
}
|
||||
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
|
||||
ValTensor::Instance { .. } => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
Ok(indices)
|
||||
}
|
||||
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
|
||||
ValTensor::Instance { .. } => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
if indices.is_empty() {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
@@ -871,3 +875,30 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
/// inverts the inner values
|
||||
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let mut cloned_self = self.clone();
|
||||
|
||||
match &mut cloned_self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.map(|x| match x {
|
||||
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
ValType::AssignedValue(v.value_field().invert())
|
||||
}
|
||||
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
|
||||
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
|
||||
});
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
}
|
||||
};
|
||||
Ok(cloned_self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,10 +33,7 @@ pub enum VarTensor {
|
||||
impl VarTensor {
|
||||
///
|
||||
pub fn is_advice(&self) -> bool {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
47
src/wasm.rs
47
src/wasm.rs
@@ -72,7 +72,7 @@ pub fn encodeVerifierCalldata(
|
||||
/// Converts 4 u64s to a field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn vecU64ToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
pub fn stringToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(format!("{:?}", felt))
|
||||
@@ -81,7 +81,7 @@ pub fn vecU64ToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsE
|
||||
/// Converts 4 u64s representing a field element directly to an integer
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn vecU64ToInt(
|
||||
pub fn stringToInt(
|
||||
array: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
@@ -95,7 +95,7 @@ pub fn vecU64ToInt(
|
||||
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn vecU64ToFloat(
|
||||
pub fn stringToFloat(
|
||||
array: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
scale: crate::Scale,
|
||||
) -> Result<f64, JsError> {
|
||||
@@ -109,23 +109,23 @@ pub fn vecU64ToFloat(
|
||||
/// Converts a floating point element to 4 u64s representing a fixed point field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn floatToVecU64(
|
||||
pub fn floatTostring(
|
||||
input: f64,
|
||||
scale: crate::Scale,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = i128_to_felt(int_rep);
|
||||
let vec = crate::pfsys::field_to_vecu64_montgomery::<halo2curves::bn256::Fr>(&felt);
|
||||
let vec = crate::pfsys::field_to_string_montgomery::<halo2curves::bn256::Fr>(&felt);
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize vecu64_montgomery{}", e)),
|
||||
|e| JsError::new(&format!("Failed to serialize string_montgomery{}", e)),
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn bufferToVecOfVecU64(
|
||||
pub fn bufferToVecOfstring(
|
||||
buffer: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
// Convert the buffer to a slice
|
||||
@@ -224,6 +224,7 @@ pub fn genWitness(
|
||||
pub fn genVk(
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
// Read in kzg params
|
||||
let mut reader = std::io::BufReader::new(¶ms_ser[..]);
|
||||
@@ -235,9 +236,13 @@ pub fn genVk(
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?;
|
||||
|
||||
// Create verifying key
|
||||
let vk = create_vk_wasm::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, ¶ms)
|
||||
.map_err(Box::<dyn std::error::Error>::from)
|
||||
.map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?;
|
||||
let vk = create_vk_wasm::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
¶ms,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(Box::<dyn std::error::Error>::from)
|
||||
.map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?;
|
||||
|
||||
let mut serialized_vk = Vec::new();
|
||||
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
|
||||
@@ -306,13 +311,19 @@ pub fn verify(
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
circuit_settings,
|
||||
circuit_settings.clone(),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
|
||||
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);
|
||||
let result = verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
snark,
|
||||
&vk,
|
||||
strategy,
|
||||
1 << circuit_settings.run_args.logrows,
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
@@ -382,15 +393,6 @@ pub fn prove(
|
||||
.into_bytes())
|
||||
}
|
||||
|
||||
/// print hex representation of a proof
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
|
||||
let hex_str = hex::encode(proof.proof);
|
||||
Ok(format!("0x{}", hex_str))
|
||||
}
|
||||
// VALIDATION FUNCTIONS
|
||||
|
||||
/// Witness file validation
|
||||
@@ -497,6 +499,7 @@ pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsErro
|
||||
pub fn create_vk_wasm<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
compress_selectors: bool,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
@@ -506,7 +509,7 @@ where
|
||||
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
|
||||
|
||||
// Initialize the verifying key
|
||||
let vk = keygen_vk(params, &empty_circuit)?;
|
||||
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
|
||||
Ok(vk)
|
||||
}
|
||||
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
|
||||
|
||||
@@ -170,9 +170,9 @@ mod native_tests {
|
||||
}
|
||||
}
|
||||
|
||||
const PF_FAILURE: &str = "examples/test_failure.proof";
|
||||
const PF_FAILURE: &str = "examples/test_failure_proof.json";
|
||||
|
||||
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr.proof";
|
||||
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
|
||||
|
||||
const LARGE_TESTS: [&str; 5] = [
|
||||
"self_attention",
|
||||
@@ -182,95 +182,96 @@ mod native_tests {
|
||||
"mnist_gan",
|
||||
];
|
||||
|
||||
const ACCURACY_CAL_TESTS: [&str; 5] = [
|
||||
const ACCURACY_CAL_TESTS: [&str; 6] = [
|
||||
"accuracy",
|
||||
"1l_mlp",
|
||||
"4l_relu_conv_fc",
|
||||
"1l_elu",
|
||||
"1l_prelu",
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 77] = [
|
||||
"1l_mlp",
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
"1l_flatten",
|
||||
// "1l_average",
|
||||
"1l_div",
|
||||
"1l_pad",
|
||||
"1l_pad", // 5
|
||||
"1l_reshape",
|
||||
"1l_eltwise_div",
|
||||
"1l_sigmoid",
|
||||
"1l_sqrt",
|
||||
"1l_softmax",
|
||||
"1l_softmax", //10
|
||||
// "1l_instance_norm",
|
||||
"1l_batch_norm",
|
||||
"1l_prelu",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
// "1l_gelu_tanh_appx",
|
||||
"1l_relu",
|
||||
"1l_relu", //15
|
||||
"1l_downsample",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small",
|
||||
"2l_relu_small", //20
|
||||
"2l_relu_sigmoid",
|
||||
"1l_conv",
|
||||
"2l_sigmoid_small",
|
||||
"2l_relu_sigmoid_conv",
|
||||
"3l_relu_conv_fc",
|
||||
"3l_relu_conv_fc", //25
|
||||
"4l_relu_conv_fc",
|
||||
"1l_erf",
|
||||
"1l_var",
|
||||
"1l_elu", //30
|
||||
"min",
|
||||
"1l_elu",
|
||||
"min", //30
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample", //35
|
||||
"1l_identity",
|
||||
"1l_upsample",
|
||||
"1l_identity", //35
|
||||
"idolmodel",
|
||||
"trig",
|
||||
"prelu_gmm",
|
||||
"lstm", //40
|
||||
"rnn",
|
||||
"lstm",
|
||||
"rnn", //40
|
||||
"quantize_dequantize",
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // "variable_cnn",
|
||||
"decision_tree", // 45
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
"xgboost", //50
|
||||
"lightgbm",
|
||||
"xgboost",
|
||||
"lightgbm", //50
|
||||
"hummingbird_decision_tree",
|
||||
"oh_decision_tree",
|
||||
"linear_svc",
|
||||
"gather_elements",
|
||||
"less",
|
||||
"less", //55
|
||||
"xgboost_reg",
|
||||
"1l_powf",
|
||||
"scatter_elements",
|
||||
"1l_linear", //60
|
||||
"linear_regression",
|
||||
"1l_linear",
|
||||
"linear_regression", //60
|
||||
"sklearn_mlp",
|
||||
"1l_mean",
|
||||
"rounding_ops",
|
||||
// "mean_as_constrain",
|
||||
"arange",
|
||||
"layernorm",
|
||||
"layernorm", //65
|
||||
"bitwise_ops",
|
||||
"blackman_window",
|
||||
"softsign", //70
|
||||
"softsign", //68
|
||||
"softplus",
|
||||
"selu",
|
||||
"selu", //70
|
||||
"hard_sigmoid",
|
||||
"log_softmax",
|
||||
"eye",
|
||||
"ltsf",
|
||||
"remainder",
|
||||
"remainder", //75
|
||||
"bitshift",
|
||||
];
|
||||
|
||||
@@ -360,7 +361,7 @@ mod native_tests {
|
||||
#[cfg(feature = "icicle")]
|
||||
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
|
||||
|
||||
const TESTS_EVM: [&str; 21] = [
|
||||
const TESTS_EVM: [&str; 23] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
@@ -376,12 +377,14 @@ mod native_tests {
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
"2l_relu_fc",
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"idolmodel",
|
||||
"1l_identity",
|
||||
"lstm",
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
];
|
||||
|
||||
const TESTS_EVM_AGGR: [&str; 18] = [
|
||||
@@ -474,6 +477,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
use rand::Rng;
|
||||
use tempdir::TempDir;
|
||||
|
||||
#[test]
|
||||
@@ -487,13 +491,13 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
seq!(N in 0..=4 {
|
||||
seq!(N in 0..=5 {
|
||||
#(#[test_case(ACCURACY_CAL_TESTS[N])])*
|
||||
fn mock_accuracy_cal_tests(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -510,13 +514,23 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 1.2);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -526,7 +540,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 1.2);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -536,7 +550,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 1.2);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -547,7 +561,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -556,7 +570,18 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_tolerance_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -567,7 +592,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -576,7 +601,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -585,7 +610,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -594,7 +619,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -603,7 +628,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -612,7 +637,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -621,7 +646,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -631,7 +656,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -641,7 +666,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -650,7 +675,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -660,7 +685,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -669,7 +694,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -679,7 +704,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -689,7 +714,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None);
|
||||
mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -699,7 +724,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -709,7 +734,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -719,7 +744,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -824,10 +849,10 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
@@ -837,7 +862,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
test_dir.close().unwrap();
|
||||
@@ -853,7 +878,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single");
|
||||
kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -863,7 +888,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6]));
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -880,7 +905,8 @@ mod native_tests {
|
||||
use crate::native_tests::TESTS_EVM_AGGR;
|
||||
use test_case::test_case;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify;
|
||||
use crate::native_tests::run_js_tests;
|
||||
use crate::native_tests::kzg_evm_prove_and_verify_render_seperately;
|
||||
|
||||
use crate::native_tests::kzg_evm_on_chain_input_prove_and_verify;
|
||||
use crate::native_tests::kzg_evm_aggr_prove_and_verify;
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
@@ -955,7 +981,7 @@ mod native_tests {
|
||||
});
|
||||
|
||||
|
||||
seq!(N in 0..= 17 {
|
||||
seq!(N in 0..=17 {
|
||||
// these take a particularly long time to run
|
||||
#(#[test_case(TESTS_EVM_AGGR[N])])*
|
||||
#[ignore]
|
||||
@@ -971,7 +997,7 @@ mod native_tests {
|
||||
});
|
||||
|
||||
|
||||
seq!(N in 0..= 20 {
|
||||
seq!(N in 0..=22 {
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_prove_and_verify_(test: &str) {
|
||||
@@ -979,9 +1005,22 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "private", "private", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS_EVM[N])])*
|
||||
fn kzg_evm_prove_and_verify_render_seperately_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
@@ -993,9 +1032,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "hashed", "private", "private");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1010,9 +1049,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "kzgcommit", "private", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1023,9 +1062,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "private", "hashed", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
@@ -1036,9 +1075,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "private", "private", "hashed");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1049,9 +1088,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "private", "kzgcommit", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "kzgcommit", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1062,9 +1101,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "private", "private", "kzgcommit");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "kzgcommit");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1074,9 +1113,9 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1246,6 +1285,7 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1257,6 +1297,8 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
tolerance,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1283,22 +1325,31 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: f32,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
"-M".to_string(),
|
||||
format!("{}/{}/network.onnx", test_dir, example_name),
|
||||
format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
format!("--variables=batch_size->{}", batch_size),
|
||||
format!("--input-visibility={}", input_visibility),
|
||||
format!("--param-visibility={}", param_visibility),
|
||||
format!("--output-visibility={}", output_visibility),
|
||||
format!("--num-inner-cols={}", num_inner_columns),
|
||||
format!("--tolerance={}", tolerance),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-settings",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
&format!("--variables=batch_size={}", batch_size),
|
||||
&format!("--input-visibility={}", input_visibility),
|
||||
&format!("--param-visibility={}", param_visibility),
|
||||
&format!("--output-visibility={}", output_visibility),
|
||||
&format!("--num-inner-cols={}", num_inner_columns),
|
||||
])
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1367,6 +1418,7 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// Mock prove (fast, but does not cover some potential issues)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn accuracy_measurement(
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
@@ -1376,6 +1428,7 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1387,6 +1440,8 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
0.0,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -1418,7 +1473,7 @@ mod native_tests {
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
"-O",
|
||||
format!("{}/{}/render.png", test_dir, example_name).as_str(),
|
||||
"--lookup-range=(-32768,32768)",
|
||||
"--lookup-range=-32768->32768",
|
||||
"-K=17",
|
||||
])
|
||||
.status()
|
||||
@@ -1645,6 +1700,8 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -1707,6 +1764,30 @@ mod native_tests {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// load settings file
|
||||
let settings =
|
||||
std::fs::read_to_string(settings_path.clone()).expect("failed to read settings file");
|
||||
|
||||
let graph_settings = serde_json::from_str::<GraphSettings>(&settings)
|
||||
.expect("failed to parse settings file");
|
||||
|
||||
// get_srs for the graph_settings_num_instances
|
||||
download_srs(graph_settings.log2_total_instances());
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"verify",
|
||||
format!("--settings-path={}", settings_path).as_str(),
|
||||
"--proof-path",
|
||||
&format!("{}/{}/proof.pf", test_dir, example_name),
|
||||
"--vk-path",
|
||||
&format!("{}/{}/key.vk", test_dir, example_name),
|
||||
"--reduced-srs",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
@@ -1721,6 +1802,8 @@ mod native_tests {
|
||||
"resources",
|
||||
None,
|
||||
2,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -1740,6 +1823,7 @@ mod native_tests {
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
fn kzg_evm_prove_and_verify(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
input_visibility: &str,
|
||||
@@ -1755,7 +1839,7 @@ mod native_tests {
|
||||
input_visibility,
|
||||
param_visibility,
|
||||
output_visibility,
|
||||
2,
|
||||
num_inner_columns,
|
||||
None,
|
||||
false,
|
||||
"single",
|
||||
@@ -1830,6 +1914,137 @@ mod native_tests {
|
||||
assert!(!status.success());
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
fn kzg_evm_prove_and_verify_render_seperately(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
example_name: String,
|
||||
input_visibility: &str,
|
||||
param_visibility: &str,
|
||||
output_visibility: &str,
|
||||
) {
|
||||
let anvil_url = ANVIL_URL.as_str();
|
||||
|
||||
kzg_prove_and_verify(
|
||||
test_dir,
|
||||
example_name.clone(),
|
||||
"safe",
|
||||
input_visibility,
|
||||
param_visibility,
|
||||
output_visibility,
|
||||
num_inner_columns,
|
||||
None,
|
||||
false,
|
||||
"single",
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
init_params(settings_path.clone().into());
|
||||
|
||||
let vk_arg = format!("{}/{}/key.vk", test_dir, example_name);
|
||||
let rpc_arg = format!("--rpc-url={}", anvil_url);
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name);
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
let sol_arg = format!("--sol-code-path={}/{}/kzg.sol", test_dir, example_name);
|
||||
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--render-vk-seperately",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
let args = vec![
|
||||
"create-evm-vk",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg_vk,
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm-verifier",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr = std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
|
||||
// deploy the vk
|
||||
let args = vec![
|
||||
"deploy-evm-vk",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg_vk.as_str(),
|
||||
sol_arg_vk.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address
|
||||
let addr_vk = std::fs::read_to_string(format!("{}/{}/addr_vk.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);
|
||||
|
||||
// now verify the proof
|
||||
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
|
||||
let mut args = vec![
|
||||
"verify-evm",
|
||||
"--proof-path",
|
||||
pf_arg.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
deployed_addr_arg.as_str(),
|
||||
deployed_addr_arg_vk.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// As sanity check, add example that should fail.
|
||||
args[2] = PF_FAILURE;
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
|
||||
let status = Command::new("pnpm")
|
||||
@@ -1862,8 +2077,10 @@ mod native_tests {
|
||||
1,
|
||||
"resources",
|
||||
// we need the accuracy
|
||||
Some(vec![7, 8]),
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
0.0,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
@@ -1901,6 +2118,20 @@ mod native_tests {
|
||||
.map(|h| vec![FileSourceInner::Field(*h)])
|
||||
.collect(),
|
||||
));
|
||||
} else {
|
||||
input.output_data = Some(DataSource::File(
|
||||
witness
|
||||
.pretty_elements
|
||||
.unwrap()
|
||||
.rescaled_outputs
|
||||
.iter()
|
||||
.map(|o| {
|
||||
o.iter()
|
||||
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
|
||||
.collect()
|
||||
})
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
|
||||
input.save(data_path.clone().into()).unwrap();
|
||||
|
||||
@@ -12,7 +12,7 @@ def get_ezkl_output(witness_file, settings_file):
|
||||
outputs = witness_output['outputs']
|
||||
with open(settings_file) as f:
|
||||
settings = json.load(f)
|
||||
ezkl_outputs = [[ezkl.vecu64_to_float(
|
||||
ezkl_outputs = [[ezkl.string_to_float(
|
||||
outputs[i][j], settings['model_output_scales'][i]) for j in range(len(outputs[i]))] for i in range(len(outputs))]
|
||||
return ezkl_outputs
|
||||
|
||||
@@ -78,14 +78,20 @@ def compare_outputs(zk_output, onnx_output):
|
||||
|
||||
zip_object = zip(np.array(zk_output).flatten(),
|
||||
np.array(onnx_output).flatten())
|
||||
for list1_i, list2_i in zip_object:
|
||||
for (i, (list1_i, list2_i)) in enumerate(zip_object):
|
||||
if list1_i == 0.0 and list2_i == 0.0:
|
||||
res.append(0)
|
||||
else:
|
||||
diff = list1_i - list2_i
|
||||
res.append(100 * (diff) / (list2_i))
|
||||
# iterate and print the diffs if they are greater than 0.0
|
||||
if abs(diff) > 0.0:
|
||||
print("------- index: ", i)
|
||||
print("------- diff: ", diff)
|
||||
print("------- zk_output: ", list1_i)
|
||||
print("------- onnx_output: ", list2_i)
|
||||
|
||||
|
||||
print("res: ", res)
|
||||
|
||||
return np.mean(np.abs(res))
|
||||
|
||||
|
||||
@@ -118,38 +118,38 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 32] = [
|
||||
"proof_splitting.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
"hashed_vis.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb", // 4
|
||||
"hashed_vis.ipynb", // 5
|
||||
"simple_demo_all_public.ipynb",
|
||||
"data_attest.ipynb",
|
||||
"variance.ipynb",
|
||||
"little_transformer.ipynb",
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
"xgboost.ipynb",
|
||||
"lightgbm.ipynb",
|
||||
"svm.ipynb",
|
||||
"simple_demo_public_input_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb", // 20
|
||||
"gcn.ipynb",
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"kzg_vis.ipynb",
|
||||
"kzg_vis.ipynb", // 25
|
||||
"kmeans.ipynb",
|
||||
"solvency.ipynb",
|
||||
"sklearn_mlp.ipynb",
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
];
|
||||
|
||||
|
||||
@@ -56,9 +56,9 @@ def test_poseidon_hash():
|
||||
Test for poseidon_hash
|
||||
"""
|
||||
message = [1.0, 2.0, 3.0, 4.0]
|
||||
message = [ezkl.float_to_vecu64(x, 7) for x in message]
|
||||
message = [ezkl.float_to_string(x, 7) for x in message]
|
||||
res = ezkl.poseidon_hash(message)
|
||||
assert ezkl.vecu64_to_felt(
|
||||
assert ezkl.string_to_felt(
|
||||
res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
|
||||
|
||||
|
||||
@@ -70,14 +70,14 @@ def test_field_serialization():
|
||||
|
||||
input = 890
|
||||
scale = 7
|
||||
felt = ezkl.float_to_vecu64(input, scale)
|
||||
roundtrip_input = ezkl.vecu64_to_float(felt, scale)
|
||||
felt = ezkl.float_to_string(input, scale)
|
||||
roundtrip_input = ezkl.string_to_float(felt, scale)
|
||||
assert input == roundtrip_input
|
||||
|
||||
input = -700
|
||||
scale = 7
|
||||
felt = ezkl.float_to_vecu64(input, scale)
|
||||
roundtrip_input = ezkl.vecu64_to_float(felt, scale)
|
||||
felt = ezkl.float_to_string(input, scale)
|
||||
roundtrip_input = ezkl.string_to_float(felt, scale)
|
||||
assert input == roundtrip_input
|
||||
|
||||
|
||||
@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -147,13 +147,13 @@ def test_calibrate():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -183,7 +183,7 @@ def test_model_compile():
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'network.onnx'
|
||||
)
|
||||
compiled_model_path = os.path.join(
|
||||
@@ -205,7 +205,7 @@ def test_forward():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_average',
|
||||
'1l_relu',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
@@ -392,9 +392,7 @@ def test_prove_evm():
|
||||
assert res['transcript_type'] == 'EVM'
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
res = ezkl.print_proof_hex(proof_path)
|
||||
# to figure out a better way of testing print_proof_hex
|
||||
assert type(res) == str
|
||||
|
||||
|
||||
|
||||
def test_create_evm_verifier():
|
||||
|
||||
@@ -8,10 +8,10 @@ mod wasm32 {
|
||||
use ezkl::graph::GraphWitness;
|
||||
use ezkl::pfsys;
|
||||
use ezkl::wasm::{
|
||||
bufferToVecOfVecU64, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
|
||||
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
|
||||
prove, settingsValidation, srsValidation, u8_array_to_u128_le, vecU64ToFelt, vecU64ToFloat,
|
||||
vecU64ToInt, verify, vkValidation, witnessValidation,
|
||||
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
|
||||
genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove,
|
||||
settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
|
||||
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
|
||||
};
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
@@ -26,7 +26,7 @@ mod wasm32 {
|
||||
pub const NETWORK_COMPILED: &[u8] = include_bytes!("../tests/wasm/model.compiled");
|
||||
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
|
||||
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
|
||||
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/test.proof");
|
||||
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
|
||||
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
|
||||
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
|
||||
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
|
||||
@@ -78,19 +78,19 @@ mod wasm32 {
|
||||
let serialized = serde_json::to_vec(&field_element).unwrap();
|
||||
let clamped = wasm_bindgen::Clamped(serialized);
|
||||
let scale = 2;
|
||||
let floating_point = vecU64ToFloat(clamped.clone(), scale)
|
||||
let floating_point = stringToFloat(clamped.clone(), scale)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
assert_eq!(floating_point, (i as f64) / 4.0);
|
||||
|
||||
let integer: i128 = serde_json::from_slice(
|
||||
&vecU64ToInt(clamped.clone()).map_err(|_| "failed").unwrap(),
|
||||
&stringToInt(clamped.clone()).map_err(|_| "failed").unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(integer, i as i128);
|
||||
|
||||
let hex_string = format!("{:?}", field_element);
|
||||
let returned_string = vecU64ToFelt(clamped).map_err(|_| "failed").unwrap();
|
||||
let returned_string = stringToFelt(clamped).map_err(|_| "failed").unwrap();
|
||||
assert_eq!(hex_string, returned_string);
|
||||
}
|
||||
}
|
||||
@@ -101,7 +101,7 @@ mod wasm32 {
|
||||
let mut buffer = string_high.clone().into_bytes();
|
||||
let clamped = wasm_bindgen::Clamped(buffer.clone());
|
||||
|
||||
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
|
||||
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
|
||||
|
||||
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
|
||||
|
||||
@@ -118,7 +118,7 @@ mod wasm32 {
|
||||
let buffer = string_sample.clone().into_bytes();
|
||||
let clamped = wasm_bindgen::Clamped(buffer.clone());
|
||||
|
||||
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
|
||||
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
|
||||
|
||||
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
|
||||
|
||||
@@ -133,7 +133,7 @@ mod wasm32 {
|
||||
let buffer = string_concat.into_bytes();
|
||||
let clamped = wasm_bindgen::Clamped(buffer.clone());
|
||||
|
||||
let field_elements_ser = bufferToVecOfVecU64(clamped).map_err(|_| "failed").unwrap();
|
||||
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
|
||||
|
||||
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
|
||||
|
||||
@@ -186,6 +186,7 @@ mod wasm32 {
|
||||
let vk = genVk(
|
||||
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
|
||||
wasm_bindgen::Clamped(SRS.to_vec()),
|
||||
true,
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
@@ -206,6 +207,7 @@ mod wasm32 {
|
||||
let vk = genVk(
|
||||
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
|
||||
wasm_bindgen::Clamped(SRS.to_vec()),
|
||||
true,
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
@@ -218,6 +220,7 @@ mod wasm32 {
|
||||
let vk = genVk(
|
||||
wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()),
|
||||
wasm_bindgen::Clamped(SRS.to_vec()),
|
||||
true,
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
@@ -255,15 +258,6 @@ mod wasm32 {
|
||||
assert!(value);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn print_proof_hex_test() {
|
||||
let proof = printProofHex(wasm_bindgen::Clamped(PROOF.to_vec()))
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
|
||||
assert!(proof.len() > 0);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn verify_validations() {
|
||||
// Run witness validation on network (should fail)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -1 +1,61 @@
|
||||
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":0,"scale_rebase_multiplier":10,"lookup_range":[-2,0],"logrows":6,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_rows":16,"total_assignments":32,"total_const_size":8,"model_instance_shapes":[[1,4]],"model_output_scales":[0],"model_input_scales":[0],"module_sizes":{"kzg":[],"poseidon":[0,[0]]},"required_lookups":["ReLU"],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1702474230544}
|
||||
{
|
||||
"run_args": {
|
||||
"tolerance": {
|
||||
"val": 0.0,
|
||||
"scale": 1.0
|
||||
},
|
||||
"input_scale": 0,
|
||||
"param_scale": 0,
|
||||
"scale_rebase_multiplier": 10,
|
||||
"lookup_range": [
|
||||
-2,
|
||||
0
|
||||
],
|
||||
"logrows": 6,
|
||||
"num_inner_cols": 2,
|
||||
"variables": [
|
||||
[
|
||||
"batch_size",
|
||||
1
|
||||
]
|
||||
],
|
||||
"input_visibility": "Private",
|
||||
"output_visibility": "Public",
|
||||
"param_visibility": "Private",
|
||||
"div_rebasing": false,
|
||||
"rebase_frac_zero_constants": false,
|
||||
"check_mode": "UNSAFE"
|
||||
},
|
||||
"num_rows": 16,
|
||||
"total_assignments": 32,
|
||||
"total_const_size": 8,
|
||||
"model_instance_shapes": [
|
||||
[
|
||||
1,
|
||||
4
|
||||
]
|
||||
],
|
||||
"model_output_scales": [
|
||||
0
|
||||
],
|
||||
"model_input_scales": [
|
||||
0
|
||||
],
|
||||
"module_sizes": {
|
||||
"kzg": [],
|
||||
"poseidon": [
|
||||
0,
|
||||
[
|
||||
0
|
||||
]
|
||||
]
|
||||
},
|
||||
"required_lookups": [
|
||||
"ReLU"
|
||||
],
|
||||
"required_range_checks": [],
|
||||
"check_mode": "UNSAFE",
|
||||
"version": "0.0.0",
|
||||
"num_blinding_factors": null,
|
||||
"timestamp": 1702474230544
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,6 @@
|
||||
import * as fs from 'fs/promises';
|
||||
import * as fsSync from 'fs'
|
||||
import JSONBig from 'json-bigint';
|
||||
import { vecU64ToFelt } from '@ezkljs/engine/nodejs'
|
||||
const solc = require('solc');
|
||||
|
||||
// import os module
|
||||
|
||||
Binary file not shown.
@@ -1 +1 @@
|
||||
{"inputs":[[[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574],[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287],[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287]]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[[[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1}
|
||||
Reference in New Issue
Block a user