Compare commits

...

37 Commits

Author SHA1 Message Date
dante 523c77c912 feat: lookupless sqrt and rsqrt (#867) 2024-11-10 15:56:38 +00:00
dante 948e5cd4b9 chore: version proof and witness (#865) 2024-11-08 02:55:35 +00:00
dante 00155e585f feat: bounded lookup log argument (#864) 2024-11-07 12:16:55 +00:00
dante 0876faa12c feat: bounded lookup round half to even (#863) 2024-11-01 00:51:15 -04:00
dante a3c131dac0 feat: lookupless rounding ops (#862) 2024-10-31 11:29:46 -04:00
sebastiandanconia fd9c2305ac docs: improve cli friendliness (#861) 2024-10-30 17:25:47 -04:00

    * Improve clarity of an info!() message
    * Replace references to EZKL_REPO_PATH in `--help' output

    Command `--help' messages aren't meant to be unduly verbose; we can
    write them for common/simple use cases. We continue to support
    EZKL_REPO_PATH for users who need it, for example to support
    containerized server use cases.

    To be clear, by default, EZKL_REPO_PATH = $HOME/.ezkl

dante a0060f341d chore: rm lookup recip (#859) 2024-10-29 15:57:38 -04:00
dante 17f1d42739 chore: unify leakyrelu and relu (#858) 2024-10-29 10:43:40 -04:00
dante ebaee9e2b1 feat: lookupless min/max ops (#854) 2024-10-26 08:00:27 -04:00
dante d51cba589a feat: dynamic lookup overflow (#853) 2024-10-23 23:12:00 -04:00
Artem 1cb1b6e143 feat: iOS Bindings (#846) 2024-10-23 09:58:55 -04:00
Ethan Cemer d2b683b527 feat: reusable verifier (#821) 2024-10-22 09:10:24 -04:00
Jseam a06b09ef1f docs: add batch demo (#849) 2024-10-15 08:36:24 +01:00
dante e5aa48fbd6 chore: support all padding types (#848) 2024-10-05 10:43:12 -04:00
dante 64fbc8a1c9 refactor: lookup-less sign relu abs (#845) 2024-09-17 11:58:58 -04:00
dante c9f9d17f16 chore: optimized release builds by default (#844) 2024-09-05 15:52:03 +02:00
Ethan Cemer b49b0487c4 fix: remove console import from index.ts in in-browser-evm-verifier package (#841) 2024-08-30 18:47:59 +01:00
dante 61b7a8e9b5 chore: perf updates (#838) 2024-08-27 09:45:40 -04:00
dante 5dbc7d5176 chore: cache lookup tables (#835) 2024-08-19 00:24:53 -04:00
dante ada45a3197 chore: swap h2 collections hash for rustc hasher (#832) 2024-08-04 17:28:36 -04:00
dante 616b421967 fix: bump compiler to latest to accommodate latest serde diagnostic (#830) 2024-07-25 07:56:21 -04:00
dante f64f0ecd23 fix: instance order when using processed params (#829) 2024-07-24 07:58:46 -04:00
dante 5be12b7a54 fix: num groups for conv operations should be specified at load time (#828) 2024-07-18 09:58:41 -04:00
dante 2fd877c716 chore: small worm example (#568) 2024-07-15 09:20:37 -04:00
dante 8197340985 chore: const filtering optimizations (#825) 2024-07-12 12:37:02 +01:00
dante 6855ea1947 feat: parallel polynomial reads in halo2 (#826) 2024-07-12 00:33:14 +01:00
dante 2ca57bde2c chore: bump tract (#823) 2024-06-29 01:53:18 +01:00
Ethan Cemer 390de88194 feat: create_evm_vk python bindings (#818) 2024-06-23 22:21:59 -04:00
dante cd91f0af26 chore: allow for float lookup safety margins (#817) 2024-06-19 09:12:01 -04:00
dante 4771192823 fix: more verbose io / rw errors (#815) 2024-06-14 12:31:50 -04:00
dante a863ccc868 chore: update cmd feature flag (#814) 2024-06-11 16:07:49 -04:00
Ethan Cemer 8e6ccc863d feat: all file source kzg commit DA (#812) 2024-06-11 09:32:54 -04:00
dante 00d6873f9a fix: should update using bash when possible (#813) 2024-06-10 12:13:59 -04:00
dante c97ff84198 refactor: rm boxed errors (opaque) (#810) 2024-06-08 22:41:47 -04:00
dante f5f8ef56f7 chore: ezkl self update (#809) 2024-06-07 10:30:21 -04:00
Ethan Cemer 685487c853 feat: DA swap proof commitments (#807) 2024-06-07 10:19:38 -04:00
dante 33d41c7e49 fix: calldata cmd 2024-05-31 13:19:33 -04:00
137 changed files with 12188 additions and 6489 deletions

View File

@@ -22,15 +22,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -40,7 +40,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
@@ -86,7 +86,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Checkout repo
@@ -106,27 +106,27 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: aarch64-unknown-linux-gnu
steps:
@@ -181,9 +181,14 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary
- name: Build release binary (no asm)
if: matrix.build != 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"
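
Note: the release workflow now builds the linux-gnu target with --features asm (the asm feature maps to halo2curves/asm and halo2_proofs/asm in the Cargo.toml diff further down), while every other target keeps the portable build. A minimal sketch of the feature-gating pattern this relies on, with a stand-in body rather than real ezkl code:

#[cfg(feature = "asm")]
fn field_mul(a: u64, b: u64) -> u64 {
    // stand-in for the assembly-accelerated implementation, only compiled
    // when built with --features asm
    a.wrapping_mul(b)
}

#[cfg(not(feature = "asm"))]
fn field_mul(a: u64, b: u64) -> u64 {
    // portable fallback for all other targets
    a.wrapping_mul(b)
}

fn main() {
    assert_eq!(field_mul(3, 4), 12);
}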

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -65,40 +65,40 @@ jobs:
- name: Library tests (original lookup)
run: cargo nextest run --lib --verbose --no-default-features --features ezkl
ultra-overflow-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run --release lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# ultra-overflow-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@v2
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
# - name: lookup overflow
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Matmul overflow
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv overflow
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv + relu overflow
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
ultra-overflow-tests_og-lookup:
runs-on: non-gpu
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -184,22 +184,24 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- uses: nanasess/setup-chromedriver@v2
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +214,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -229,17 +231,21 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and bounded lookup log
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
@@ -258,6 +264,8 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
@@ -286,7 +294,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -304,7 +312,7 @@ jobs:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -323,12 +331,12 @@ jobs:
cd in-browser-evm-verifier
pnpm build:commonjs
cd ..
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -341,6 +349,12 @@ jobs:
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)
@@ -359,15 +373,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -425,40 +442,40 @@ jobs:
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
prove-and-verify-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
# prove-and-verify-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
# - uses: actions/checkout@v3
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (kzg outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public inputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# - name: KZG prove and verify tests (fixed params)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# - name: KZG prove and verify tests (hashed outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
@@ -467,7 +484,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -485,7 +502,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -493,7 +510,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
@@ -502,7 +519,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -519,17 +536,17 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -540,7 +557,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -560,17 +577,17 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
@@ -586,7 +603,7 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -635,17 +652,17 @@ jobs:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
@@ -671,3 +688,68 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Reusable verifier tutorial
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
ios-integration-tests:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
runs-on: macos-latest
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

View File

@@ -0,0 +1,75 @@
name: Build and Publish EZKL iOS SPM package
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
jobs:
build-and-update:
runs-on: macos-latest
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Commit and Push Changes to feat/ezkl-direct-integration
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git add Sources/EzklCoreBindings
git commit -m "Automatically updated EzklCoreBindings for EZKL"
git tag ${{ github.event.inputs.tag }}
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
git push origin
git push origin --tags
env:
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

6
.gitignore vendored
View File

@@ -46,7 +46,7 @@ var/
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/assets/pk.key
!tests/assets/vk.key
docs/python/build
!tests/wasm/vk_aggr.key
!tests/assets/vk_aggr.key

1025
Cargo.lock generated

File diff suppressed because it is too large.

View File

@@ -4,6 +4,7 @@ cargo-features = ["profile-rustflags"]
name = "ezkl"
version = "0.0.0"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -11,90 +12,93 @@ edition = "2021"
# Name to be imported within python
# Example: import ezkl
name = "ezkl"
crate-type = ["cdylib", "rlib"]
crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.5.3", features = ["derive"] }
clap_complete = "4.5.2"
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
"raw_value",
], optional = true }
log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = ["circuit-params"] }
rand = { version = "0.8", default-features = false }
itertools = { version = "0.10.3", default-features = false }
clap = { version = "4.5.3", features = ["derive"], optional = true }
serde = { version = "1.0.126", features = ["derive"] }
clap_complete = { version = "4.5.2", optional = true }
log = { version = "0.4.17", default-features = false }
thiserror = { version = "1.0.38", default-features = false }
hex = { version = "0.4.3", default-features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
serde_json = { version = "1.0.97", features = [
"float_roundtrip",
"raw_value",
] }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
ethabi = "18"
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
"provider-http",
"signers",
"contract",
"rpc-types-eth",
"signer-wallet",
"node-bindings",
], optional = true }
foundry-compilers = { version = "0.4.1", features = ["svm-solc"], optional = true }
ethabi = { version = "18", optional = true }
indicatif = { version = "0.17.5", features = ["rayon"], optional = true }
gag = { version = "1.0.0", default-features = false, optional = true }
instant = { version = "0.1" }
reqwest = { version = "0.12.4", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
reqwest = { version = "0.12.4", default-features = false, features = ["default-tls", "multipart", "stream"], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = ["macros", "rt-multi-thread"], optional = true }
pyo3 = { version = "0.21.2", features = ["extension-module", "abi3-py37", "macros"], default-features = false, optional = true }
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = ["attributes", "tokio-runtime"], default-features = false, optional = true }
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
uniffi_bindgen = { version = "=0.28.0", optional = true }
camino = { version = "^1.1", optional = true }
uuid = { version = "1.10.0", features = ["v4"], optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true }
env_logger = { version = "0.10.0", default_features = false, optional = true }
chrono = "0.4.31"
sha256 = "1.4.0"
colored = { version = "2.0.0", default-features = false, optional = true }
env_logger = { version = "0.10.0", default-features = false, optional = true }
chrono = { version = "0.4.31", optional = true }
sha256 = { version = "1.4.0", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
serde_json = { version = "1.0.97", default-features = false, features = [
"float_roundtrip",
"raw_value",
] }
getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
@@ -107,8 +111,14 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
[build-dependencies]
uniffi = { version = "0.28", features = ["build"], optional = true }
[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -159,16 +169,20 @@ harness = false
[[bench]]
name = "relu"
name = "sigmoid"
harness = false
[[bench]]
name = "accum_matmul_relu"
name = "relu_lookupless"
harness = false
[[bench]]
name = "accum_matmul_sigmoid"
harness = false
[[bench]]
name = "accum_matmul_relu_overflow"
name = "accum_matmul_sigmoid_overflow"
harness = false
[[bin]]
@@ -177,40 +191,73 @@ test = false
bench = false
required-features = ["ezkl"]
[[bin]]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "no-banner"]
default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
"serde",
"serde_json",
"log",
"colored",
"env_logger",
"dep:colored",
"dep:env_logger",
"tabled/color",
"serde_json/std",
"colored_json",
"halo2_proofs/circuit-params",
"dep:alloy",
"dep:foundry-compilers",
"dep:ethabi",
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:openssl",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:regex",
"dep:tokio",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:portable-atomic",
"dep:clap_complete",
"dep:halo2_solidity_verifier",
"dep:semver",
"dep:clap",
"dep:tosubcommand",
]
parallel-poly-read = ["halo2_proofs/circuit-params", "halo2_proofs/parallel-poly-read"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
metal = ["dep:metal", "dep:objc"]
no-update = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
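
Note: the manifest refactor applies one pattern throughout: native-only dependencies become optional = true and are pulled in via "dep:" entries on the ezkl feature, so slim builds (wasm32, ios-bindings) never compile them. A minimal sketch of how call sites pair with such a manifest:

#[cfg(feature = "ezkl")]
fn startup_banner() {
    // only compiled when the full CLI feature set (and its optional
    // dependencies) is enabled
    println!("full CLI build");
}

#[cfg(not(feature = "ezkl"))]
fn startup_banner() {
    // slim builds (wasm, ios-bindings) get this stub instead
}

fn main() {
    startup_banner();
}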

View File

@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
}),
)
.unwrap();
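
Note: every circuit in this changeset passes the same two extra trailing arguments to RegionCtx::new. The diff itself does not name them; given the base-1024, two-leg decompositions used by the new lookupless ops (see the relu_lookupless bench below), a plausible reading, with guessed parameter names, is:

// before: RegionCtx::new(region, offset, num_inner_cols)
// after:  RegionCtx::new(region, offset, num_inner_cols, decomp_base, decomp_legs)
// e.g. base 1024 with 2 legs covers magnitudes up to 1024^2 = 2^20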

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
K,
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -75,14 +83,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},
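
Note: with ReLU promoted to a lookup-free op elsewhere in this changeset, the lookup benches switch their table to Sigmoid. A rough standalone sketch of what one fixed-point table entry computes, assuming the single scale parameter applies to both input and output (the real table generation lives in ezkl's lookup code):

fn sigmoid_table_entry(x: i64, scale: f64) -> i64 {
    // dequantize, apply the nonlinearity, requantize
    let s = 1.0 / (1.0 + (-(x as f64) / scale).exp());
    (s * scale).round() as i64
}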

View File

@@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
k,
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -76,14 +84,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();

View File

@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();

150
benches/relu_lookupless.rs Normal file
View File

@@ -0,0 +1,150 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut LEN: usize = 4;
const K: usize = 16;
#[derive(Clone)]
struct NLCircuit {
pub input: ValTensor<Fr>,
}
impl Circuit<Fr> for NLCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unsafe {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let mut config = Config::default();
config
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
.unwrap();
config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
config
}
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_range_checks(&mut layouter).unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runrelu(c: &mut Criterion) {
let mut group = c.benchmark_group("relu");
let mut rng = rand::thread_rng();
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 8].iter() {
unsafe {
LEN = len;
};
let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runrelu
}
criterion_main!(benches);
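
Note: this new bench replaces the old ReLU lookup table with two range checks. My reading of the arithmetic being constrained, written as plain Rust rather than circuit constraints (the decomposition details are inferred from the constants above, not taken from ezkl's source):

fn relu_reference(x: i64) -> i64 {
    const BASE: i64 = 1024; // the 1024 passed to RegionCtx::new above
    const LEGS: u32 = 2; // the trailing 2
    let sign = x.signum(); // covered by configure_range_check(.., (-1, 1), ..)
    // each leg lands in [0, 1023], covered by configure_range_check(.., (0, 1023), ..)
    let legs: Vec<i64> = (0..LEGS).map(|i| (x.abs() / BASE.pow(i)) % BASE).collect();
    let recomposed: i64 = legs.iter().enumerate().map(|(i, &l)| l * BASE.pow(i as u32)).sum();
    assert_eq!(recomposed, x.abs()); // faithful while |x| < 1024^2
    x * (1 + sign) / 2 // x when positive, 0 when negative or zero
}

fn main() {
    assert_eq!(relu_reference(-7), 0);
    assert_eq!(relu_reference(9), 9);
}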

View File

@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -41,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
let mut config = Config::default();
@@ -62,9 +63,13 @@ impl Circuit<Fr> for NLCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},
@@ -84,7 +89,7 @@ fn runrelu(c: &mut Criterion) {
};
let input: Tensor<Value<Fr>> =
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),

7
build.rs Normal file
View File

@@ -0,0 +1,7 @@
fn main() {
if cfg!(feature = "ios-bindings-test") {
println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
}
println!("cargo::rerun-if-changed=build.rs");
}
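
Note: cfg!(feature = ...) works here because cargo compiles build scripts with the package's feature flags. The same check can be written against the CARGO_FEATURE_* environment variables cargo sets while running a build script (feature names upper-cased, dashes turned into underscores):

fn main() {
    // equivalent to the cfg!(feature = "ios-bindings-test") check above
    if std::env::var_os("CARGO_FEATURE_IOS_BINDINGS_TEST").is_some() {
        println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
    }
    println!("cargo::rerun-if-changed=build.rs");
}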

View File

@@ -93,6 +93,76 @@ contract LoadInstances {
}
}
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"";
contract SwapProofCommitments {
/**
* @dev Swap the proof commitments
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function checkKzgCommits(
bytes calldata encoded
) internal pure returns (bool equal) {
bytes4 funcSig;
uint256 proof_offset;
uint256 proof_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch proof offset which is 4 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
proof_offset := calldataload(
add(
encoded.offset,
add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
proof_length := calldataload(
add(add(encoded.offset, 0x04), proof_offset)
)
}
// Check the length of the commitment against the proof bytes
if (proof_length < COMMITMENT_KZG.length) {
return false;
}
// Load COMMITMENT_KZG into memory
bytes memory commitment = COMMITMENT_KZG;
// Compare the first N bytes of the proof with COMMITMENT_KZG
uint words = (commitment.length + 31) / 32; // Calculate the number of 32-byte words
assembly {
// Now we compare the commitment with the proof,
// ensuring that the commitments divided up into 32 byte words are all equal.
for {
let i := 0x20
} lt(i, add(mul(words, 0x20), 0x20)) {
i := add(i, 0x20)
} {
let wordProof := calldataload(
add(add(encoded.offset, add(i, 0x04)), proof_offset)
)
let wordCommitment := mload(add(commitment, i))
equal := eq(wordProof, wordCommitment)
if eq(equal, 0) {
return(0, 0)
}
}
}
return equal; // Return true if the commitment comparison passed
} /// end checkKzgCommits
}
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
@@ -104,9 +174,10 @@ contract LoadInstances {
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestation is LoadInstances {
contract DataAttestation is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
@@ -350,12 +421,18 @@ contract DataAttestation is LoadInstances {
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
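
Note: the assembly loop in checkKzgCommits amounts to the following off-chain check (Rust, helper name is mine; equivalent as long as the commitment blob is a whole number of 32-byte words, which a list of serialized KZG commitments is):

fn commitments_match(proof: &[u8], commitment_kzg: &[u8]) -> bool {
    if proof.len() < commitment_kzg.len() {
        return false; // mirrors the length guard before the assembly loop
    }
    // the proof bytes must start with the hardcoded commitments, compared
    // one 32-byte word at a time
    proof[..commitment_kzg.len()]
        .chunks(32)
        .zip(commitment_kzg.chunks(32))
        .all(|(p, c)| p == c)
}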

View File

@@ -2,8 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
@@ -42,8 +41,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -66,8 +65,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -90,8 +89,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -147,6 +146,8 @@ where
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::configure(
@@ -157,15 +158,11 @@ where
);
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();
layer_config
@@ -196,15 +193,21 @@ where
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
};
let x = config
.layer_config
@@ -221,7 +224,14 @@ where
let x = config
.layer_config
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
let mut x = config
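
Note: the ReLU lookup becomes PolyOp::LeakyReLU with slope 0.0, matching the "unify leakyrelu and relu" commit in the list above. The identity relied on:

// relu(x) == leaky_relu(x, 0.0) for every x
fn leaky_relu(x: f64, slope: f64) -> f64 {
    if x >= 0.0 { x } else { slope * x }
}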
@@ -281,7 +291,7 @@ where
}
pub fn runconv() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
const KERNEL_HEIGHT: usize = 5;
@@ -315,7 +325,11 @@ pub fn runconv() {
.test_set_length(10_000)
.finalize();
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
let mut train_data = Tensor::from(
trn_img
.iter()
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
);
train_data.reshape(&[50_000, 28, 28]).unwrap();
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
@@ -343,8 +357,8 @@ pub fn runconv() {
.map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}),
);
@@ -355,7 +369,8 @@ pub fn runconv() {
let l0_kernels = l0_kernels.try_into().unwrap();
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
let mut l0_bias =
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let l0_bias = l0_bias.try_into().unwrap();
@@ -363,8 +378,8 @@ pub fn runconv() {
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
@@ -374,8 +389,8 @@ pub fn runconv() {
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
@@ -401,13 +416,13 @@ pub fn runconv() {
l2_params: [l2_weights, l2_biases],
};
let public_input: Tensor<i32> = vec![
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
let public_input: Tensor<IntegerRep> = vec![
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
]
.into_iter()
.into();
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
println!("MOCK PROVING");
let now = Instant::now();

View File

@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
let output = VarTensor::new_advice(cs, K, 1, LEN);
// tells the config layer to add an affine op to the circuit gate
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::<F>::configure(
cs,
&[input.clone(), params.clone()],
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
CheckMode::SAFE,
);
// sets up a new ReLU table and reuses it for l1 and l3 non-linearities
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();
// sets up a new ReLU table and reuses it for l1 and l3 non-linearities
@@ -104,11 +103,16 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let x = config
.layer_config
.layout(
@@ -141,7 +145,14 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
println!("x shape: {:?}", x.dims());
let mut x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap()
.unwrap();
println!("3");
@@ -177,7 +188,14 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
println!("x shape: {:?}", x.dims());
let x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap();
println!("6");
println!("offset: {}", region.row());
@@ -212,36 +230,36 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
}
pub fn runmlp() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
// parameters
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
// input data, with 1 padding to allow for bias
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
.unwrap()
.into();
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
let circuit = MyCircuit::<4, -8192, 8192> {
@@ -251,12 +269,12 @@ pub fn runmlp() {
_marker: PhantomData,
};
let public_input: Vec<i32> = unsafe {
let public_input: Vec<IntegerRep> = unsafe {
vec![
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
(2849f32 / 128f32).to_int_unchecked::<i32>(),
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
]
};
@@ -265,7 +283,10 @@ pub fn runmlp() {
let prover = MockProver::run(
K as u32,
&circuit,
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
vec![public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect()],
)
.unwrap();
prover.assert_satisfied();

View File

@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -648,10 +648,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -0,0 +1,604 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-kzg-vis\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source and the params and outputs are committed to using kzg-commitments. \n",
"\n",
"In this setup:\n",
"- the inputs and outputs are publicly known to the prover and verifier\n",
"- the on chain inputs will be fetched and then fed directly into the circuit\n",
"- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n",
"- The kzg commitment to the params and inputs will be read from the proof and checked to make sure it matches the expected commitment stored on-chain.\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"polycommitment\" \n",
"- `output_visibility`: \"polycommitment\"\n",
"\n",
"**Note**:\n",
"When we set this to polycommitment, we are saying that the model parameters are committed to using a polynomial commitment scheme. This commitment will be stored on chain as a constant stored in the DA contract, and the proof will contain the commitment to the parameters. The DA verification will then check that the commitment in the proof matches the commitment stored on chain. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"polycommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
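{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the decimals convention above concrete, here is a minimal sketch (with made-up values, not part of the EZKL API) of how a uint256 returned by a view call maps back to the floating point value that gets quantized: \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative sketch only -- hypothetical values, not fetched from a real chain\n",
"onchain_value = 1234567  # uint256 returned by the view function\n",
"decimals = 5             # decimal places declared alongside the call_data\n",
"\n",
"# the floating point value that EZKL quantizes at proving time\n",
"float_value = onchain_value / 10**decimals\n",
"print(float_value)  # 12.34567\n"
]
},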
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
" data_path,\n",
" input_source=ezkl.PyTestDataSource.OnChain,\n",
" output_source=ezkl.PyTestDataSource.File,\n",
" rpc_url=RPC_URL)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"When deploying a DA with kzg commitments, we need to make sure to also pass a witness file that contains the commitments to the parameters and inputs. This is because the verifier will need to check that the commitments in the proof match the commitments stored on chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" witness_path = witness_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.7"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -232,7 +232,7 @@
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
@@ -404,7 +404,7 @@
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"
"run_args.logrows = 15\n"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reusable Verifiers \n",
"\n",
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
"\n",
"- `1l_mlp sigmoid`\n",
"- `1l_mlp relu`\n",
"- `1l_conv sigmoid`\n",
"- `1l_conv relu`\n",
"\n",
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
"\n",
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we deploy a unique and much smaller verifying key artifact (VKA) contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well circuit specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier."
]
},
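{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of where we are headed (the paths, addresses, and variables below are hypothetical placeholders; the real, working calls appear further down in this notebook), the reusable flow boils down to three steps:\n",
"\n",
"```python\n",
"# 1) deploy the shared verifier once (any model's test.sol works; they are all identical)\n",
"await ezkl.deploy_evm('addr_verifier.txt', 'mlp_sigmoid/test.sol', RPC_URL, 'verifier/reusable')\n",
"\n",
"# 2) deploy one small verifying key artifact (VKA) contract per model\n",
"await ezkl.deploy_evm('addr_vk.txt', 'mlp_sigmoid/test_key.sol', RPC_URL, 'vka')\n",
"\n",
"# 3) verify each proof against the (verifier, VKA) address pair\n",
"await ezkl.verify_evm(verifier_addr, 'mlp_sigmoid/proof.json', RPC_URL, addr_vk=vka_addr)\n",
"```"
]
},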
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.onnx\n",
"\n",
"# Define the models\n",
"class MLP_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Sigmoid, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class MLP_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Relu, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"class Conv_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Sigmoid, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class Conv_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Relu, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"# Instantiate the models\n",
"mlp_sigmoid = MLP_Sigmoid()\n",
"mlp_relu = MLP_Relu()\n",
"conv_sigmoid = Conv_Sigmoid()\n",
"conv_relu = Conv_Relu()\n",
"\n",
"# Dummy input tensor for mlp\n",
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
"input_mlp_path = 'mlp_input.json'\n",
"\n",
"# Dummy input tensor for conv\n",
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
"input_conv_path = 'conv_input.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import ezkl\n",
"\n",
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
" # Create a new directory for the model if it doesn't exist\n",
" if not os.path.exists(name):\n",
" os.mkdir(name)\n",
" # Store the paths in each of their respective directories\n",
" model_path = os.path.join(name, \"network.onnx\")\n",
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
" pk_path = os.path.join(name, \"test.pk\")\n",
" vk_path = os.path.join(name, \"test.vk\")\n",
" settings_path = os.path.join(name, \"settings.json\")\n",
"\n",
" witness_path = os.path.join(name, \"witness.json\")\n",
" sol_code_path = os.path.join(name, 'test.sol')\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" abi_path = os.path.join(name, 'test.abi')\n",
" proof_path = os.path.join(name, \"proof.json\")\n",
"\n",
" # Flips the neural net into inference mode\n",
" model.eval()\n",
"\n",
" # Export the model\n",
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
" do_constant_folding=True, input_names=['input'],\n",
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
" data = dict(input_data=[data_array])\n",
" json.dump(data, open(input_path, 'w'))\n",
"\n",
" py_run_args = ezkl.PyRunArgs()\n",
" py_run_args.input_visibility = \"private\"\n",
" py_run_args.output_visibility = \"public\"\n",
" py_run_args.param_visibility = \"fixed\" # private by default\n",
"\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
" assert res == True\n",
"\n",
" await ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
"\n",
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
" assert res == True\n",
"\n",
" res = await ezkl.get_srs(settings_path)\n",
" assert res == True\n",
"\n",
" # now generate the witness file\n",
" res = await ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
" assert os.path.isfile(witness_path) == True\n",
"\n",
" # SETUP \n",
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
" assert res == True\n",
"\n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, sol_key_code_path, abi_path)\n",
" assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_anvil()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import filecmp\n",
"\n",
"def compare_files(file1, file2):\n",
" return filecmp.cmp(file1, file2, shallow=False)\n",
"\n",
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
"\n",
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
"\n",
"\n",
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we deploy separate verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(addr_path_verifier, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
" addr_vk = file.read().rstrip()\n",
" \n",
" proof_path = os.path.join(name, \"proof.json\")\n",
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" addr_vk = addr_vk\n",
" )\n",
" assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -482,7 +482,7 @@
"source": [
"import pytest\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path_faulty,\n",
" settings_path,\n",
@@ -514,9 +514,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "171702d3",
"metadata": {},
"outputs": [],
@@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
@@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
@@ -399,9 +399,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -478,12 +478,11 @@
"import pytest\n",
"\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"# Run the test function\n",
@@ -510,9 +509,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"model = convert(sk_model, \"torch\").model\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,33 +84,6 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -119,14 +92,14 @@
"outputs": [],
"source": [
"\n",
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"data = dict(input_data=[d])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -441,9 +413,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

Binary file not shown.

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -9,7 +9,9 @@ class MyModel(nn.Module):
super(MyModel, self).__init__()
def forward(self, w, x, y, z):
return [((x & y)) == (x & (y | (z ^ w)))]
a = (x & y)
b = (y & (z ^ w))
return [a & b]
circuit = MyModel()

View File

@@ -1 +1 @@
{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}

View File

@@ -1,21 +1,17 @@
(binary ONNX protobuf diff, not human-readable: the logical-ops model re-exported with pytorch 2.2.2 instead of pytorch 1.12.1, with the And/Xor/Or graph node names updated)

File diff suppressed because one or more lines are too long

Binary file not shown.

examples/onnx/log/gen.py Normal file (42 lines)
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.log(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 3)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}

View File

@@ -0,0 +1,14 @@
(new binary ONNX protobuf, not human-readable: a pytorch 2.2.2 export of a graph with a single Log node and a dynamic batch_size dimension)

View File

@@ -21,9 +21,9 @@ def main():
torch_model = Circuit()
# Input to the model
shape = [3, 2, 3]
w = 0.1*torch.rand(1, *shape, requires_grad=True)
x = 0.1*torch.rand(1, *shape, requires_grad=True)
y = 0.1*torch.rand(1, *shape, requires_grad=True)
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
torch_out = torch_model(w, x, y)
# Export the model
torch.onnx.export(torch_model, # model being run

View File

@@ -1 +1,148 @@
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
{
"input_shapes": [
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
]
],
"input_data": [
[
0.5,
1.5,
-0.04514765739440918,
0.5936200618743896,
0.9271858930587769,
0.6688600778579712,
-0.20331168174743652,
-0.7016235589981079,
0.025863051414489746,
-0.19426143169403076,
0.9827852249145508,
0.4897397756576538,
-1.5,
-0.5,
0.9278832674026489,
0.5943725109100342,
-0.573331356048584,
0.3675816059112549
],
[
0.7803324460983276,
-0.9616303443908691,
0.6070173978805542,
-0.028337717056274414,
-0.5080242156982422,
-0.9280107021331787,
0.6150380373001099,
0.3865993022918701,
-0.43668973445892334,
0.17152702808380127,
0.5144252777099609,
-0.28881049156188965,
0.8932310342788696,
0.059034109115600586,
0.6865451335906982,
0.009820222854614258,
0.23011493682861328,
-0.9492779970169067
],
[
-0.21352827548980713,
-0.16015326976776123,
-0.38964390754699707,
0.13464701175689697,
-0.8814496994018555,
0.5037975311279297,
-0.804405927658081,
0.9858957529067993,
0.19567716121673584,
0.9777265787124634,
0.6151977777481079,
0.568595290184021,
0.10584986209869385,
-0.8975653648376465,
0.6235959529876709,
-0.547879695892334,
0.9289869070053101,
0.7567293643951416
]
],
"output_data": [
[
1.0,
0.0,
-0.0,
1.0,
1.0,
1.0,
-0.0,
-1.0,
0.0,
-0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
-1.0,
0.0
],
[
0.0,
-1.0,
0.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
-1.0
],
[
-0.0,
-0.0,
-0.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0
]
]
}

View File

@@ -1,10 +1,11 @@
(binary ONNX protobuf diff, not human-readable: the Round/Floor/Ceil rounding-ops model re-exported with pytorch 2.2.2 instead of pytorch 2.0.1)

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
# reciprocal sqrt
m = 1 / torch.sqrt(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}

View File

@@ -0,0 +1,17 @@
(new binary ONNX protobuf, not human-readable: a pytorch 2.2.2 export of an rsqrt graph composed of Sqrt followed by Reciprocal, with a dynamic batch_size dimension)

View File

@@ -0,0 +1 @@
network.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,47 @@
## The worm
This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, which is a VAE / latent-space representation of the C. elegans connectome.
The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).
In effect this is a generative model for a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given previous connectome state.
Using ezkl you can create a zk circuit equivalent to the WormVAE model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm state that can be verified on chain.
To do so you'll first want to fetch the files using git-lfs (as the onnx file is too large to be stored in git).
```bash
git lfs fetch --all
```
You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
> Note: the model is large and thus we recommend a machine with at least 512GB of RAM to run the above commands. If you're ever compute constrained you can always use the lilith service to generate the zk circuit. Message us on discord or telegram for more details :)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
size 82095882

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -9,7 +9,6 @@ import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
import { error } from 'console'
async function deployContract(
vm: VM,
@@ -66,7 +65,7 @@ async function verify(
vkAddress = new Uint8Array(uint8Array.buffer);
// convert uint8array of length
error('vkAddress', vkAddress)
console.error('vkAddress', vkAddress)
}
const data = encodeVerifierCalldata(proof, vkAddress)

View File

@@ -99,6 +99,10 @@ fi
echo "Removing old ezkl binary if it exists"
[ -e file ] && rm file
# echo platform and architecture
echo "Platform: $PLATFORM"
echo "Architecture: $ARCHITECTURE"
# download the release and unpack the right tarball
if [ "$PLATFORM" == "windows-msvc" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-02-06"
channel = "nightly-2024-07-18"
components = ["rustfmt", "clippy"]

View File

@@ -1,33 +1,33 @@
// ignore file if compiling for wasm
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::commands::Cli;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(feature = "icicle")]
use std::env;
#[cfg(not(target_arch = "wasm32"))]
use std::error::Error;
#[tokio::main(flavor = "current_thread")]
#[cfg(not(target_arch = "wasm32"))]
pub async fn main() -> Result<(), Box<dyn Error>> {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
let args = Cli::parse();
if let Some(generator) = args.generator {
ezkl::commands::print_completions(generator, &mut Cli::command());
Ok(())
} else if let Some(command) = args.command {
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
@@ -38,19 +38,28 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
} else {
info!("Running with CPU");
}
info!("command: \n {}", &command.as_json().to_colored_json_auto()?);
info!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);
let res = run(command).await;
match &res {
Ok(_) => info!("succeeded"),
Err(e) => error!("failed: {}", e),
};
res.map(|_| ())
Ok(_) => {
info!("succeeded");
}
Err(e) => {
error!("{}", e);
std::process::exit(1)
}
}
} else {
Err("No command provided".into())
init_logger();
error!("No command provided");
std::process::exit(1)
}
}
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub fn main() {}
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]

src/bin/ios_gen_bindings.rs Normal file (269 lines)
View File

@@ -0,0 +1,269 @@
use camino::Utf8Path;
use std::fs;
use std::fs::remove_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
use uniffi_bindgen::bindings::SwiftBindingGenerator;
use uniffi_bindgen::library_mode::generate_bindings;
use uuid::Uuid;
fn main() {
let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set");
let mode = determine_build_mode();
build_bindings(&library_name, mode);
}
/// Determines the build mode based on the CONFIGURATION environment variable.
/// Defaults to "release" if not set or unrecognized.
/// "release" mode takes longer to build but produces optimized code, which has smaller size and is faster.
fn determine_build_mode() -> &'static str {
match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) {
Ok(ref config) if config == "debug" => "debug",
_ => "release",
}
}
/// Builds the Swift bindings and XCFramework for the specified library and build mode.
fn build_bindings(library_name: &str, mode: &str) {
// Get the root directory of this Cargo project
let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR")
.map(PathBuf::from)
.unwrap_or_else(|| std::env::current_dir().unwrap());
// Define the build directory inside the manifest directory
let build_dir = manifest_dir.join("build");
// Create a temporary directory to store the bindings and combined library
let tmp_dir = mktemp_local(&build_dir);
// Define directories for Swift bindings and output bindings
let swift_bindings_dir = tmp_dir.join("SwiftBindings");
let bindings_out = create_bindings_out_dir(&tmp_dir);
let framework_out = bindings_out.join("EzklCore.xcframework");
// Define target architectures for building
// We currently only support iOS devices and simulators running on ARM Macs
// This keeps the library under GitHub's 100MB size limit for committed files
// To support older Macs (Intel), follow the instructions in the comments below
#[allow(clippy::useless_vec)]
let target_archs = vec![
vec!["aarch64-apple-ios"], // iOS device
vec!["aarch64-apple-ios-sim"], // iOS simulator ARM Mac
// vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older Macs (Intel)
];
// Build the library for each architecture and combine them
let out_lib_paths: Vec<PathBuf> = target_archs
.iter()
.map(|archs| build_combined_archs(library_name, archs, &build_dir, mode))
.collect();
// Generate the path to the built dynamic library (.dylib)
let out_dylib_path = build_dir.join(format!(
"{}/{}/lib{}.dylib",
target_archs[0][0], mode, library_name
));
// Generate Swift bindings using uniffi_bindgen
generate_ios_bindings(&out_dylib_path, &swift_bindings_dir)
.expect("Failed to generate iOS bindings");
// Move the generated Swift file to the bindings output directory
fs::rename(
swift_bindings_dir.join(format!("{}.swift", library_name)),
bindings_out.join("EzklCore.swift"),
)
.expect("Failed to copy swift bindings file");
// Rename the `ios_ezklFFI.modulemap` file to `module.modulemap`
fs::rename(
swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)),
swift_bindings_dir.join("module.modulemap"),
)
.expect("Failed to rename modulemap file");
// Create the XCFramework from the combined libraries and Swift bindings
create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out);
// Define the destination directory for the bindings
let bindings_dest = build_dir.join("EzklCoreBindings");
if bindings_dest.exists() {
fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory");
}
// Move the bindings output to the destination directory
fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place");
// Clean up temporary directories
cleanup_temp_dirs(&build_dir);
}
/// Creates the output directory for the bindings.
/// Returns the path to the bindings output directory.
fn create_bindings_out_dir(base_dir: &Path) -> PathBuf {
let bindings_out = base_dir.join("EzklCoreBindings");
fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory");
bindings_out
}
/// Builds the library for each architecture and combines them into a single library using lipo.
/// Returns the path to the combined library.
fn build_combined_archs(
library_name: &str,
archs: &[&str],
build_dir: &Path,
mode: &str,
) -> PathBuf {
// Build the library for each architecture
let out_lib_paths: Vec<PathBuf> = archs
.iter()
.map(|&arch| {
build_for_arch(arch, build_dir, mode);
build_dir
.join(arch)
.join(mode)
.join(format!("lib{}.a", library_name))
})
.collect();
// Create a unique temporary directory for the combined library
let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name));
// Combine the libraries using lipo
let mut lipo_cmd = Command::new("lipo");
lipo_cmd
.arg("-create")
.arg("-output")
.arg(lib_out.to_str().unwrap());
for lib_path in &out_lib_paths {
lipo_cmd.arg(lib_path.to_str().unwrap());
}
let status = lipo_cmd.status().expect("Failed to run lipo command");
if !status.success() {
panic!("lipo command failed with status: {}", status);
}
lib_out
}
/// Builds the library for a specific architecture.
fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) {
// Ensure the target architecture is installed
install_arch(arch);
// Run cargo build for the specified architecture and mode
let mut build_cmd = Command::new("cargo");
build_cmd
.arg("build")
.arg("--no-default-features")
.arg("--features")
.arg("ios-bindings");
if mode == "release" {
build_cmd.arg("--release");
}
build_cmd
.arg("--lib")
.env("CARGO_BUILD_TARGET_DIR", build_dir)
.env("CARGO_BUILD_TARGET", arch);
let status = build_cmd.status().expect("Failed to run cargo build");
if !status.success() {
panic!("cargo build failed for architecture: {}", arch);
}
}
/// Installs the specified target architecture using rustup.
fn install_arch(arch: &str) {
let status = Command::new("rustup")
.arg("target")
.arg("add")
.arg(arch)
.status()
.expect("Failed to run rustup command");
if !status.success() {
panic!("Failed to install target architecture: {}", arch);
}
}
/// Generates Swift bindings for the iOS library using uniffi_bindgen.
fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> {
// Remove existing binding directory if it exists
if binding_dir.exists() {
remove_dir_all(binding_dir)?;
}
// Generate the Swift bindings using uniffi_bindgen
generate_bindings(
Utf8Path::from_path(dylib_path).ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path")
})?,
None,
&SwiftBindingGenerator,
None,
Utf8Path::from_path(binding_dir).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid Swift bindings directory",
)
})?,
true,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
Ok(())
}
/// Creates an XCFramework from the combined libraries and Swift bindings.
fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) {
let mut xcbuild_cmd = Command::new("xcodebuild");
xcbuild_cmd.arg("-create-xcframework");
// Add each library and its corresponding headers to the xcodebuild command
for lib_path in lib_paths {
println!("Including library: {:?}", lib_path);
xcbuild_cmd.arg("-library");
xcbuild_cmd.arg(lib_path.to_str().unwrap());
xcbuild_cmd.arg("-headers");
xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap());
}
xcbuild_cmd.arg("-output");
xcbuild_cmd.arg(framework_out.to_str().unwrap());
let status = xcbuild_cmd.status().expect("Failed to run xcodebuild");
if !status.success() {
panic!("xcodebuild failed with status: {}", status);
}
}
/// Creates a temporary directory inside the build path with a unique UUID.
/// This ensures unique build artifacts for concurrent builds.
fn mktemp_local(build_path: &Path) -> PathBuf {
let dir = tmp_local(build_path).join(Uuid::new_v4().to_string());
fs::create_dir(&dir).expect("Failed to create temporary directory");
dir
}
/// Gets the path to the local temporary directory inside the build path.
fn tmp_local(build_path: &Path) -> PathBuf {
let tmp_path = build_path.join("tmp");
if let Ok(metadata) = fs::metadata(&tmp_path) {
if !metadata.is_dir() {
panic!("Expected 'tmp' to be a directory");
}
} else {
fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory");
}
tmp_path
}
/// Cleans up temporary directories inside the build path.
fn cleanup_temp_dirs(build_dir: &Path) {
let tmp_dir = build_dir.join("tmp");
if tmp_dir.exists() {
fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories");
}
}
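Every external tool driven by this build script (rustup, cargo, lipo, xcodebuild) goes through the same spawn-and-check-status pattern. A self-contained sketch of that pattern; the rustup invocation is just a harmless stand-in for the real commands:

use std::process::Command;

fn run_checked(program: &str, args: &[&str]) {
    let status = Command::new(program)
        .args(args)
        .status()
        .unwrap_or_else(|e| panic!("failed to spawn {}: {}", program, e));
    if !status.success() {
        panic!("{} failed with status: {}", program, status);
    }
}

fn main() {
    run_checked("rustup", &["--version"]);
}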

src/bindings/mod.rs (new file)

@@ -0,0 +1,12 @@
/// Python bindings
#[cfg(feature = "python-bindings")]
pub mod python;
/// Universal bindings for all platforms
#[cfg(any(
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown")
))]
pub mod universal;
/// wasm prover and verifier
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
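The new module tree compiles `universal` once and exposes it to both the iOS and in-browser wasm frontends through a single cfg gate. A toy illustration of the same gating technique (feature name taken from the diff; the native fallback is added only so the sketch compiles anywhere):

mod bindings {
    // compiled only when either frontend gate is on, mirroring src/bindings/mod.rs
    #[cfg(any(feature = "ios-bindings", all(target_arch = "wasm32", target_os = "unknown")))]
    pub mod universal {
        pub fn ping() -> &'static str {
            "pong (ios/wasm)"
        }
    }

    // fallback so this sketch also compiles as a plain native binary
    #[cfg(not(any(feature = "ios-bindings", all(target_arch = "wasm32", target_os = "unknown"))))]
    pub mod universal {
        pub fn ping() -> &'static str {
            "pong (native)"
        }
    }
}

fn main() {
    println!("{}", bindings::universal::ping());
}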


@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
use crate::circuit::modules::Module;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_i64, i64_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
@@ -180,9 +180,6 @@ struct PyRunArgs {
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
@@ -191,6 +188,15 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
/// int: The base used for decomposition
#[pyo3(get, set)]
pub decomp_base: usize,
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
}
/// default instantiation of PyRunArgs
@@ -206,6 +212,7 @@ impl PyRunArgs {
impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
@@ -217,10 +224,11 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
}
}
}
@@ -228,6 +236,7 @@ impl From<PyRunArgs> for RunArgs {
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
@@ -239,10 +248,11 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
}
}
}
@@ -331,9 +341,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
#[pyfunction(signature = (
felt,
))]
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i64(felt);
let int_rep = felt_to_integer_rep(felt);
Ok(int_rep)
}
@@ -357,7 +367,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
))]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i64(felt);
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
@@ -385,7 +395,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = i64_to_felt(int_rep);
let felt = integer_rep_to_felt(int_rep);
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}
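felt_to_float and float_to_felt are inverses up to rounding: a float x at scale s is stored as round(x * 2^s) and recovered by dividing the multiplier back out. A minimal sketch of that fixed-point convention, with a plain i64 standing in for IntegerRep:

fn scale_to_multiplier(scale: i32) -> f64 {
    2f64.powi(scale)
}

fn quantize(x: f64, scale: i32) -> i64 {
    (x * scale_to_multiplier(scale)).round() as i64
}

fn dequantize(rep: i64, scale: i32) -> f64 {
    rep as f64 / scale_to_multiplier(scale)
}

fn main() {
    let x = 0.123;
    let rep = quantize(x, 7); // scale 7 => multiplier 128, so rep = 16
    let err = (dequantize(rep, 7) - x).abs();
    assert!(err <= 0.5 / 128.0); // round-trip error is at most half a quantum
}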
@@ -863,8 +873,6 @@ fn gen_settings(
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
@@ -879,7 +887,6 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
py: Python,
@@ -887,11 +894,10 @@ fn calibrate_settings(
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i64,
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::calibrate(
@@ -902,7 +908,6 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.await
@@ -1490,8 +1495,8 @@ fn encode_evm_calldata<'a>(
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// reusable: bool
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
///
/// Returns
/// -------
@@ -1503,7 +1508,7 @@ fn encode_evm_calldata<'a>(
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
fn create_evm_verifier(
py: Python,
@@ -1512,7 +1517,7 @@ fn create_evm_verifier(
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_verifier(
@@ -1521,7 +1526,7 @@ fn create_evm_verifier(
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
reusable,
)
.await
.map_err(|e| {
@@ -1533,6 +1538,57 @@ fn create_evm_verifier(
})
}
/// Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.
/// This is useful for deploying verifiers that would otherwise be too big to fit on chain and would require aggregation.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path at which to create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
fn create_evm_vka(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1558,6 +1614,7 @@ fn create_evm_verifier(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
fn create_evm_data_attestation(
py: Python,
@@ -1565,6 +1622,7 @@ fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_data_attestation(
@@ -1572,6 +1630,7 @@ fn create_evm_data_attestation(
sol_code_path,
abi_path,
input_data,
witness_path,
)
.await
.map_err(|e| {
@@ -1650,6 +1709,7 @@ fn setup_test_evm_witness(
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
@@ -1658,6 +1718,7 @@ fn deploy_evm(
addr_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -1668,42 +1729,7 @@ fn deploy_evm(
addr_path,
optimizer_runs,
private_key,
"Halo2Verifier",
)
.await
.map_err(|e| {
let err_str = format!("Failed to run deploy_evm: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// deploys the solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
fn deploy_vk_evm(
py: Python,
addr_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2VerifyingKey",
contract_type,
)
.await
.map_err(|e| {
@@ -1759,7 +1785,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
/// The verifier contract's address as a hex string
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1771,7 +1797,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation?
///
/// addr_vk: str
///
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
@@ -1839,8 +1865,8 @@ fn verify_evm<'a>(
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// reusable: bool
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
///
/// Returns
/// -------
@@ -1853,7 +1879,7 @@ fn verify_evm<'a>(
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
fn create_evm_verifier_aggr(
py: Python,
@@ -1863,7 +1889,7 @@ fn create_evm_verifier_aggr(
abi_path: PathBuf,
logrows: u32,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_aggregate_verifier(
@@ -1873,7 +1899,7 @@ fn create_evm_verifier_aggr(
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
reusable,
)
.await
.map_err(|e| {
@@ -1922,8 +1948,8 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;

src/bindings/universal.rs (new file)

@@ -0,0 +1,579 @@
use halo2_proofs::{
plonk::*,
poly::{
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::{ProverIPA, VerifierIPA},
strategy::SingleStrategy as IPASingleStrategy,
},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use std::fmt::Display;
use std::io::BufReader;
use std::str::FromStr;
use crate::{
circuit::region::RegionSettings,
graph::GraphSettings,
pfsys::{
create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
},
tensor::TensorType,
CheckMode, Commitments, EZKLError as InnerEZKLError,
};
use crate::graph::{GraphCircuit, GraphWitness};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
};
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
/// Wrapper around the error message
#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))]
#[derive(Debug)]
pub enum EZKLError {
/// An internal error carrying a descriptive message
InternalError(String),
}
impl Display for EZKLError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EZKLError::InternalError(e) => write!(f, "Internal error: {}", e),
}
}
}
impl From<InnerEZKLError> for EZKLError {
fn from(e: InnerEZKLError) -> Self {
EZKLError::InternalError(e.to_string())
}
}
/// Encode verifier calldata from proof and ethereum vk_address
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn encode_verifier_calldata(
// TODO - should it be pub(crate) or pub or pub(super)?
proof: Vec<u8>,
vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, EZKLError> {
let snark: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address {
let array: [u8; 20] =
serde_json::from_slice(&vk_address[..]).map_err(InnerEZKLError::from)?;
Some(array)
} else {
None
};
let flattened_instances = snark.instances.into_iter().flatten();
let encoded = encode_calldata(
vk_address,
&snark.proof,
&flattened_instances.collect::<Vec<_>>(),
);
Ok(encoded)
}
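Before calldata encoding, the per-column instance values are flattened into one contiguous sequence. The shape transformation in isolation, with u64 standing in for field elements:

fn main() {
    let instances: Vec<Vec<u64>> = vec![vec![1, 2], vec![3, 4, 5]];
    let flattened: Vec<u64> = instances.into_iter().flatten().collect();
    assert_eq!(flattened, vec![1, 2, 3, 4, 5]);
}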
/// Generate witness from compiled circuit and input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
})?;
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
let mut input = circuit
.load_graph_input(&input)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
None,
None,
RegionSettings::all_true(
circuit.settings().run_args.decomp_base,
circuit.settings().run_args.decomp_legs,
),
)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
serde_json::to_vec(&witness)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
}
/// Generate verifying key from compiled circuit, and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_vk(
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
compress_selectors: bool,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| {
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
})?;
Ok(serialized_vk)
}
/// Generate proving key from vk, compiled circuit and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_pk(
vk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
let pk = create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, &params)
.map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?;
let mut serialized_pk = Vec::new();
pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?;
Ok(serialized_pk)
}
/// Verify proof with vk, proof json, circuit settings json and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn verify(
proof: Vec<u8>,
vk: Vec<u8>,
settings: Vec<u8>,
srs: Vec<u8>,
) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?;
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = BufReader::new(&srs[..]);
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> = get_params(&mut reader)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(EZKLError::InternalError(format!(
"Verification failed: {}",
e
))),
}
}
/// Verify aggregate proof with vk, proof, circuit settings and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn verify_aggr(
proof: Vec<u8>,
vk: Vec<u8>,
logrows: u64,
srs: Vec<u8>,
commitment: &str,
) -> Result<bool, EZKLError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment)
.map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?;
let orig_n = 1 << logrows;
let mut reader = BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)),
)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
result
.map(|_| true)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))
}
/// Prove in browser with compiled circuit, witness json, proving key, and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn prove(
witness: Vec<u8>,
pk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
#[cfg(feature = "det-prove")]
log::set_max_level(log::LevelFilter::Debug);
#[cfg(not(feature = "det-prove"))]
log::set_max_level(log::LevelFilter::Info);
let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
circuit
.load_graph_witness(&data)
.map_err(InnerEZKLError::from)?;
let public_inputs = circuit
.prepare_public_inputs(&data)
.map_err(InnerEZKLError::from)?;
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
let mut reader = BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
KZGCommitmentScheme<Bn256>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
KZGSingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::KZG,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
IPACommitmentScheme<G1Affine>,
_,
ProverIPA<_>,
VerifierIPA<_>,
IPASingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::IPA,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
}
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?)
}
/// Validate the witness json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the compiled circuit
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
})?;
Ok(true)
}
/// Validate the input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::graph::input::GraphData =
serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the proof json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the verifying key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
Ok(true)
}
/// Validate the proving key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
Ok(true)
}
/// Validate the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let _: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize params: {}", e))
})?;
Ok(true)
}
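All of the *_validation helpers above share one shape: attempt a deserialization, discard the value, and report success. A generic sketch of the pattern, assuming only serde and serde_json:

use serde::de::DeserializeOwned;

fn validate_json<T: DeserializeOwned>(bytes: &[u8]) -> Result<bool, String> {
    let _: T = serde_json::from_slice(bytes).map_err(|e| e.to_string())?;
    Ok(true)
}

fn main() {
    // Vec<u64> stands in for GraphWitness, GraphSettings, etc.
    assert!(validate_json::<Vec<u64>>(b"[1, 2, 3]").unwrap());
    assert!(validate_json::<Vec<u64>>(b"not json").is_err());
}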
// HELPER FUNCTIONS
fn get_params<
Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>,
>(
mut reader: &mut BufReader<&[u8]>,
) -> Result<Scheme, EZKLError> {
halo2_proofs::poly::commitment::Params::<G1Affine>::read(&mut reader)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)))
}
/// Creates a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_vk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the verifying key
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
Ok(vk)
}
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_pk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
vk: VerifyingKey<Scheme::Curve>,
circuit: &C,
params: &'_ Scheme::ParamsProver,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the proving key
let pk = keygen_pk(params, vk, &empty_circuit)?;
Ok(pk)
}

src/bindings/wasm.rs (new file)

@@ -0,0 +1,372 @@
use crate::{
circuit::modules::{
polycommit::PolyCommitChip,
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
},
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
};
use console_error_panic_hook;
use halo2_proofs::{
plonk::*,
poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
};
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
};
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use crate::bindings::universal::{
compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness,
input_validation, pk_validation, proof_validation, settings_validation, srs_validation,
verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError,
};
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
impl From<ExternalEZKLError> for JsError {
fn from(e: ExternalEZKLError) -> Self {
JsError::new(&format!("{}", e))
}
}
#[wasm_bindgen]
/// Initialize logger for wasm
pub fn init_logger() {
log::set_logger(&DEFAULT_LOGGER).unwrap();
}
#[wasm_bindgen]
/// Initialize panic hook for wasm
pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
/// Wrapper around the halo2 encode call data method
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn encodeVerifierCalldata(
proof: wasm_bindgen::Clamped<Vec<u8>>,
vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from)
}
/// Converts a serialized felt to a big endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
Ok(b)
}
/// Converts a serialized felt to its integer representation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_integer_rep(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
/// Converts felts to a floating point element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
input: f64,
scale: crate::Scale,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
/// Generate a kzg commitment.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn kzgCommit(
message: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let mut reader = std::io::BufReader::new(&params_ser[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&params,
);
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
))
}
/// Converts a byte buffer to a vector of field elements, packing 16 little-endian bytes per element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice
let buffer: &[u8] = &buffer;
// Divide the buffer into chunks of 16 bytes
let chunks = buffer.chunks_exact(16);
// Get the remainder
let remainder = chunks.remainder();
// Add 0s to the remainder to make it 16 bytes
let mut remainder = remainder.to_vec();
// Collect chunks into a Vec<[u8; 16]>.
let chunks: Result<Vec<[u8; 16]>, JsError> = chunks
.map(|slice| {
let array: [u8; 16] = slice
.try_into()
.map_err(|_| JsError::new("failed to slice input chunks"))?;
Ok(array)
})
.collect();
let mut chunks = chunks?;
if remainder.len() != 0 {
remainder.resize(16, 0);
// Convert the Vec<u8> to [u8; 16]
let remainder_array: [u8; 16] = remainder
.try_into()
.map_err(|_| JsError::new("failed to slice remainder"))?;
// append the remainder to the chunks
chunks.push(remainder_array);
}
// Convert each chunk to a field element
let field_elements: Vec<Fr> = chunks
.iter()
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&field_elements)
.map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?,
))
}
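The chunking logic above (exact 16-byte chunks plus a zero-padded tail) can be pulled into one small helper; a sketch:

fn chunk_pad_16(buffer: &[u8]) -> Vec<[u8; 16]> {
    let chunks = buffer.chunks_exact(16);
    let remainder = chunks.remainder();
    let mut out: Vec<[u8; 16]> = chunks
        .map(|slice| slice.try_into().expect("chunks_exact guarantees length 16"))
        .collect();
    if !remainder.is_empty() {
        // zero-pad the tail chunk up to the full 16 bytes
        let mut last = [0u8; 16];
        last[..remainder.len()].copy_from_slice(remainder);
        out.push(last);
    }
    out
}

fn main() {
    assert_eq!(chunk_pad_16(&[1u8; 20]).len(), 2);
    assert_eq!(chunk_pad_16(&[]).len(), 0);
}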
/// Generate a poseidon hash in browser from an input message
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn poseidonHash(
message: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
|e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)),
)?))
}
/// Generate a witness file from input.json, compiled model and a settings.json file.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genWitness(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
input: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_witness(compiled_circuit.0, input.0).map_err(JsError::from)
}
/// Generate verifying key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genVk(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
compress_selectors: bool,
) -> Result<Vec<u8>, JsError> {
gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from)
}
/// Generate proving key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genPk(
vk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from)
}
/// Verify proof in browser using wasm
#[wasm_bindgen]
pub fn verify(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
}
/// Verify aggregate proof in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from)
}
/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(
witness: wasm_bindgen::Clamped<Vec<u8>>,
pk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from)
}
// VALIDATION FUNCTIONS
/// Witness file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn witnessValidation(witness: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
witness_validation(witness.0).map_err(JsError::from)
}
/// Compiled circuit validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn compiledCircuitValidation(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from)
}
/// Input file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn inputValidation(input: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
input_validation(input.0).map_err(JsError::from)
}
/// Proof file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn proofValidation(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
proof_validation(proof.0).map_err(JsError::from)
}
/// Vk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vkValidation(
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
vk_validation(vk.0, settings.0).map_err(JsError::from)
}
/// Pk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn pkValidation(
pk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
pk_validation(pk.0, settings.0).map_err(JsError::from)
}
/// Settings file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn settingsValidation(settings: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
settings_validation(settings.0).map_err(JsError::from)
}
/// Srs file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
srs_validation(srs.0).map_err(JsError::from)
}
// HELPER FUNCTIONS
pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
for &b in arr.iter().rev() {
n <<= 8;
n |= b as u128;
}
n
}
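A quick test one could append under this helper to pin down the little-endian convention (index 0 holds the least significant byte):

#[test]
fn le_convention() {
    let mut arr = [0u8; 16];
    arr[0] = 0x01; // least significant byte
    arr[1] = 0x02;
    assert_eq!(u8_array_to_u128_le(arr), 0x0201);
}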


@@ -0,0 +1,25 @@
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum ModuleError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Wrong input type for a module
#[error("wrong input type {0} must be {1}")]
WrongInputType(String, String),
/// A constant was not previously assigned
#[error("constant was not previously assigned")]
ConstantNotAssigned,
/// Input length is wrong
#[error("input length is wrong {0}")]
InputWrongLength(usize),
}
impl From<ModuleError> for PlonkError {
fn from(_e: ModuleError) -> PlonkError {
PlonkError::Synthesis
}
}
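With thiserror's #[from] on the Halo2Error variant and the manual From<ModuleError> for PlonkError above, conversions flow in both directions through the ? operator. A dependency-free sketch of that two-way bridge:

#[derive(Debug)]
struct PlonkError;

#[allow(dead_code)]
#[derive(Debug)]
enum ModuleError {
    Halo2(PlonkError),
}

// what thiserror's #[from] attribute derives
impl From<PlonkError> for ModuleError {
    fn from(e: PlonkError) -> Self {
        ModuleError::Halo2(e)
    }
}

// the lossy conversion back to the halo2 error type
impl From<ModuleError> for PlonkError {
    fn from(_e: ModuleError) -> Self {
        PlonkError
    }
}

fn assign() -> Result<(), PlonkError> {
    Err(PlonkError)
}

fn layout() -> Result<(), ModuleError> {
    assign()?; // `?` applies From<PlonkError> for ModuleError
    Ok(())
}

fn synthesize() -> Result<(), PlonkError> {
    layout()?; // and From<ModuleError> for PlonkError on the way back up
    Ok(())
}

fn main() {
    assert!(synthesize().is_err());
}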


@@ -6,10 +6,11 @@ pub mod polycommit;
///
pub mod planner;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Error},
};
///
pub mod errors;
use halo2_proofs::{circuit::Layouter, plonk::ConstraintSystem};
use halo2curves::ff::PrimeField;
pub use planner::*;
@@ -35,14 +36,14 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Name
fn name(&self) -> &'static str;
/// Run the operation the module represents
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>>;
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, errors::ModuleError>;
/// Layout inputs
fn layout_inputs(
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
) -> Result<Self::InputAssignments, errors::ModuleError>;
/// Layout
fn layout(
&self,
@@ -50,7 +51,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
) -> Result<ValTensor<F>, errors::ModuleError>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
/// Number of rows used by the module


@@ -18,6 +18,7 @@ use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the PolyCommit hash function
@@ -110,7 +111,7 @@ impl Module<Fp> for PolyCommitChip {
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
Ok(())
}
@@ -123,28 +124,30 @@ impl Module<Fp> for PolyCommitChip {
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
layouter
.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
.map_err(|e| e.into())
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
Ok(vec![message])
}
@@ -216,7 +219,7 @@ mod tests {
fn polycommit_chip_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
{
@@ -244,7 +247,7 @@ mod tests {
#[test]
#[ignore]
fn polycommit_chip_much_longer_input() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
let rng = rand::rngs::OsRng;


@@ -21,6 +21,7 @@ use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the Poseidon hash function
@@ -174,7 +175,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
@@ -185,78 +186,82 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, Error> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
match value {
ValType::Value(v) => region.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);
Ok(res)
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
e
);
Err(Error::Synthesis)
}
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect()
}
};
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
let offset = message.len() / WIDTH + 1;
@@ -277,7 +282,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
message.len(),
start_time.elapsed()
);
res
res.map_err(|e| e.into())
}
/// L is the number of inputs to the hash function
@@ -289,7 +294,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
@@ -301,7 +306,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, Error> = input_cells
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
@@ -332,7 +337,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
hash
})
.collect();
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
@@ -348,7 +354,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) => v,
_ => {
log::error!("wrong input type, must be previously assigned");
return Err(Error::Synthesis);
return Err(Error::Synthesis.into());
}
};
@@ -380,7 +386,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let mut hash_inputs = message;
let len = hash_inputs.len();
@@ -400,7 +406,11 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
block.extend(vec![Fp::ZERO; L - remainder].iter());
}
let message = block.try_into().map_err(|_| Error::Synthesis)?;
let block_len = block.len();
let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
@@ -411,7 +421,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, Error>>()?;
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}
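To make the fold above easier to follow, here is an illustrative, dependency-free sketch of the same control flow (the `tree_hash` helper and its closure are stand-ins, not crate code): inputs are hashed in chunks of `L` until a single value remains, and the `one_iter` flag forces at least one pass so a single-element input still gets hashed.

```rust
// Illustrative only: a minimal tree-fold mirroring the `run` loop above.
// `hash` stands in for the Poseidon primitive; panics on empty input.
fn tree_hash<T: Copy>(mut inputs: Vec<T>, l: usize, hash: impl Fn(&[T]) -> T) -> T {
    let mut one_iter = false;
    while inputs.len() > 1 || !one_iter {
        inputs = inputs.chunks(l).map(|block| hash(block)).collect();
        one_iter = true;
    }
    inputs[0]
}

fn main() {
    // Toy "hash": sum of the chunk. Four leaves with L = 2 fold in two levels.
    let root = tree_hash(vec![1u64, 2, 3, 4], 2, |xs| xs.iter().sum());
    assert_eq!(root, 10);
}
```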
@@ -550,7 +560,7 @@ mod tests {
fn hash_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
{

View File

@@ -1,7 +1,5 @@
use std::str::FromStr;
use thiserror::Error;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
@@ -16,6 +14,7 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use crate::{
@@ -24,33 +23,13 @@ use crate::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, Op};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
/// circuit related errors.
#[derive(Debug, Error)]
pub enum CircuitError {
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
}
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -71,6 +50,7 @@ impl std::fmt::Display for CheckMode {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for CheckMode {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
@@ -110,6 +90,7 @@ impl std::fmt::Display for Tolerance {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
@@ -196,7 +177,7 @@ impl<'source> FromPyObject<'source> for Tolerance {
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
pub lookup_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub table_selectors: Vec<Selector>,
/// Inputs:
@@ -228,7 +209,7 @@ impl DynamicLookups {
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many shuffle ops.
pub input_selectors: BTreeMap<(usize, usize), Selector>,
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the shuffle references
pub reference_selectors: Vec<Selector>,
/// Inputs:
@@ -349,7 +330,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {
@@ -513,18 +494,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
return Err(CircuitError::WrongColumnType(index.name().to_string()));
}
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
if !output.is_advice() {
return Err("wrong input type for lookup output".into());
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -654,68 +635,84 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of blocks
if tables
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
.windows(2)
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"tables inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
let s_ltable = cs.complex_selector();
for q in 0..tables[0].num_blocks() {
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((x, y))
.or_insert(s_lookup);
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((q, (x, y)))
.or_insert(s_lookup);
}
}
self.dynamic_lookups.table_selectors.push(s_ltable);
}
self.dynamic_lookups.table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookups.tables.is_empty() {
@@ -737,68 +734,83 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all references have the same number of blocks
if references
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
.windows(2)
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"references inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
let s_reference = cs.complex_selector();
for q in 0..references[0].num_blocks() {
let s_reference = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.shuffles
.input_selectors
.entry((x, y))
.or_insert(s_input);
expression
});
self.shuffles
.input_selectors
.entry((q, (x, y)))
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
}
self.shuffles.reference_selectors.push(s_reference);
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
@@ -822,12 +834,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -918,7 +930,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
}
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
@@ -939,7 +951,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
@@ -959,7 +971,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)
}
}
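The re-indentation above carries a real structural change: dynamic-lookup and shuffle selectors are now created per table block `q` and keyed by `(q, (x, y))` rather than `(x, y)`. A minimal sketch of just that bookkeeping, using a hypothetical `Selector` stand-in (the real type comes from halo2):

```rust
use std::collections::BTreeMap;

// Stand-in for halo2's `Selector`; only the keying scheme matters here.
#[derive(Debug, Clone, Copy)]
struct Selector(usize);

fn configure_selectors(
    table_blocks: usize,
    lookup_blocks: usize,
    inner_cols: usize,
) -> (BTreeMap<(usize, (usize, usize)), Selector>, Vec<Selector>) {
    let mut next = 0;
    let mut lookup_selectors = BTreeMap::new();
    let mut table_selectors = Vec::new();
    for q in 0..table_blocks {
        // one selector per table block...
        let s_ltable = Selector(next);
        next += 1;
        for x in 0..lookup_blocks {
            for y in 0..inner_cols {
                // ...and one lookup selector per (block, inner column) under it
                let s_lookup = Selector(next);
                next += 1;
                lookup_selectors.entry((q, (x, y))).or_insert(s_lookup);
            }
        }
        table_selectors.push(s_ltable);
    }
    (lookup_selectors, table_selectors)
}

fn main() {
    let (lookups, tables) = configure_selectors(2, 1, 3);
    assert_eq!(lookups.len(), 6); // 2 table blocks x 1 lookup block x 3 cols
    assert_eq!(tables.len(), 2);
}
```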

100 src/circuit/ops/errors.rs Normal file
View File

@@ -0,0 +1,100 @@
use std::convert::Infallible;
use crate::{fieldutils::IntegerRep, tensor::TensorError};
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum CircuitError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] TensorError),
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
#[error("failed to flush, linear coord is not aligned with the next row")]
FlushError,
/// Constrain error
#[error("constrain_equal: one of the tensors is assigned and the other is not")]
ConstrainError,
/// Failed to get lookups
#[error("failed to get lookups for op: {0}")]
GetLookupsError(String),
/// Failed to get range checks
#[error("failed to get range checks for op: {0}")]
GetRangeChecksError(String),
/// Failed to get dynamic lookup
#[error("failed to get dynamic lookup for op: {0}")]
GetDynamicLookupError(String),
/// Failed to get shuffle
#[error("failed to get shuffle for op: {0}")]
GetShuffleError(String),
/// Failed to get constants
#[error("failed to get constants for op: {0}")]
GetConstantsError(String),
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Invalid min/max lookup range
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
InvalidMinMaxRange(IntegerRep, IntegerRep),
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
/// Mismatched lookup length
#[error("mismatched lookup lengths: {0} and {1}")]
MismatchedLookupLength(usize, usize),
/// Mismatched shuffle length
#[error("mismatched shuffle lengths: {0} and {1}")]
MismatchedShuffleLength(usize, usize),
/// Mismatched lookup table lengths
#[error("mismatched lookup table lengths: {0} and {1}")]
MismatchedLookupTableLength(usize, usize),
/// Wrong column type for lookup
#[error("wrong column type for lookup: {0}")]
WrongColumnType(String),
/// Wrong column type for dynamic lookup
#[error("wrong column type for dynamic lookup: {0}")]
WrongDynamicColumnType(String),
/// Missing selectors
#[error("missing selectors for op: {0}")]
MissingSelectors(String),
/// Table lookup error
#[error("value ({0}) out of range: ({1}, {2})")]
TableOOR(IntegerRep, IntegerRep, IntegerRep),
/// Lookup not configured
#[error("lookup not configured: {0}")]
LookupNotConfigured(String),
/// Range check not configured
#[error("range check not configured: {0}")]
RangeCheckNotConfigured(String),
/// Missing layout
#[error("missing layout for op: {0}")]
MissingLayout(String),
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
}
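For readers unfamiliar with `thiserror`, the `#[from]` attributes in this new file are what let the rest of the diff swap `Box<dyn Error>` for `CircuitError` without touching `?` call sites: each `#[from]` generates a `From` impl, so wrapped errors convert implicitly. A self-contained sketch with simplified types (not the crate's own definitions):

```rust
use thiserror::Error;

#[derive(Error, Debug)]
enum TensorError {
    #[error("dimension mismatch: {0}")]
    DimError(String),
}

#[derive(Error, Debug)]
enum CircuitError {
    // `#[from]` derives `From<TensorError> for CircuitError`.
    #[error("[tensor] {0}")]
    Tensor(#[from] TensorError),
}

fn check_dims(dims: &[usize]) -> Result<(), TensorError> {
    if dims.is_empty() {
        return Err(TensorError::DimError("empty shape".into()));
    }
    Ok(())
}

fn layout(dims: &[usize]) -> Result<(), CircuitError> {
    // `?` converts TensorError -> CircuitError via the generated From impl.
    check_dims(dims)?;
    Ok(())
}

fn main() {
    // Prints "[tensor] dimension mismatch: empty shape".
    println!("{}", layout(&[]).unwrap_err());
}
```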

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::i64_to_felt,
fieldutils::integer_rep_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -13,14 +13,38 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Ln {
scale: utils::F32,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
},
Floor {
scale: utils::F32,
legs: usize,
},
Round {
scale: utils::F32,
legs: usize,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
use_range_check_for_int: bool,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
@@ -45,6 +69,8 @@ pub enum HybridOp {
ReduceArgMin {
dim: usize,
},
Max,
Min,
Softmax {
input_scale: utils::F32,
output_scale: utils::F32,
@@ -71,12 +97,19 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
_ => vec![],
}
}
@@ -88,21 +121,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
fn as_string(&self) -> String {
match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
HybridOp::Max => "MAX".to_string(),
HybridOp::Min => "MIN".to_string(),
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => format!(
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
input_scale, output_scale, use_range_check_for_int
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
@@ -135,10 +179,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
HybridOp::GreaterEqual => "GREATEREQUAL".into(),
HybridOp::Less => "LESS".into(),
HybridOp::LessEqual => "LESSEQUAL".into(),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
HybridOp::LessEqual => "LESSEQUAL".to_string(),
HybridOp::Equals => "EQUALS".into(),
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
HybridOp::TopK { k, dim, largest } => {
@@ -155,8 +199,36 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Floor { scale, legs } => {
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Round { scale, legs } => {
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
HybridOp::SumPool {
padding,
stride,
@@ -174,42 +246,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
if input_scale.0.fract() == 0.0
&& output_scale.0.fract() == 0.0
&& *use_range_check_for_int
{
layouts::recip(
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
layouts::div(
config,
region,
values[..].try_into()?,
i64_to_felt(input_scale.0 as i64),
i64_to_felt(output_scale.0 as i64),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Recip {
input_scale: *input_scale,
output_scale: *output_scale,
},
)?
}
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
config,
region,
values[..].try_into()?,
i64_to_felt(denom.0 as i64),
integer_rep_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
@@ -287,7 +337,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
@@ -296,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Softmax {
output_scale,
input_scale,
..
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)

File diff suppressed because it is too large

View File

@@ -1,12 +1,10 @@
use super::*;
use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i64, i64_to_felt},
graph::multiplier_to_scale,
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
tensor::{self, Tensor, TensorError, TensorType},
};
use super::Op;
@@ -16,225 +14,142 @@ use halo2curves::ff::PrimeField;
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Abs,
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
ReLU,
Max {
scale: utils::F32,
a: utils::F32,
},
Min {
scale: utils::F32,
a: utils::F32,
},
Ceil {
scale: utils::F32,
},
Floor {
scale: utils::F32,
},
Round {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
Rsqrt {
scale: utils::F32,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Ln {
scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Cos {
scale: utils::F32,
},
ACos {
scale: utils::F32,
},
Cosh {
scale: utils::F32,
},
ACosh {
scale: utils::F32,
},
Sin {
scale: utils::F32,
},
ASin {
scale: utils::F32,
},
Sinh {
scale: utils::F32,
},
ASinh {
scale: utils::F32,
},
Tan {
scale: utils::F32,
},
ATan {
scale: utils::F32,
},
Tanh {
scale: utils::F32,
},
ATanh {
scale: utils::F32,
},
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
Sign,
KroneckerDelta,
Pow {
scale: utils::F32,
a: utils::F32,
},
HardSwish {
scale: utils::F32,
},
Div { denom: utils::F32 },
IsOdd,
PowersOfTwo { scale: utils::F32 },
Ln { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Exp { scale: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
ACosh { scale: utils::F32 },
Sin { scale: utils::F32 },
ASin { scale: utils::F32 },
Sinh { scale: utils::F32 },
ASinh { scale: utils::F32 },
Tan { scale: utils::F32 },
ATan { scale: utils::F32 },
Tanh { scale: utils::F32 },
ATanh { scale: utils::F32 },
Erf { scale: utils::F32 },
Pow { scale: utils::F32, a: utils::F32 },
HardSwish { scale: utils::F32 },
}
impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i64;
let range = range as IntegerRep;
(-range, range)
}
/// Render this op as a path-friendly string identifier
pub fn as_path(&self) -> String {
match self {
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
LookupOp::IsOdd => "is_odd".to_string(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
LookupOp::ACosh { scale } => format!("acosh_{}", scale),
LookupOp::Sin { scale } => format!("sin_{}", scale),
LookupOp::ASin { scale } => format!("asin_{}", scale),
LookupOp::Sinh { scale } => format!("sinh_{}", scale),
LookupOp::ASinh { scale } => format!("asinh_{}", scale),
LookupOp::Tan { scale } => format!("tan_{}", scale),
LookupOp::ATan { scale } => format!("atan_{}", scale),
LookupOp::ATanh { scale } => format!("atanh_{}", scale),
LookupOp::Tanh { scale } => format!("tanh_{}", scale),
LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i64(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())),
LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())),
LookupOp::RoundHalfToEven { scale } => Ok(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)),
LookupOp::Max { scale, a } => Ok(tensor::ops::nonlinearities::max(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Min { scale, a } => Ok(tensor::ops::nonlinearities::min(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)),
LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than(
&x,
f32::from(*a).into(),
)),
LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThanEqual { a } => Ok(
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
),
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*denom).into(),
)),
LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*scale).into(),
)),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res =
match &self {
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::PowersOfTwo { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
}
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
}
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
LookupOp::ACos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acos(&x, scale.into()))
}
LookupOp::Cosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cosh(&x, scale.into()))
}
LookupOp::ACosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acosh(&x, scale.into()))
}
LookupOp::Sin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sin(&x, scale.into()))
}
LookupOp::ASin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asin(&x, scale.into()))
}
LookupOp::Sinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sinh(&x, scale.into()))
}
LookupOp::ASinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asinh(&x, scale.into()))
}
LookupOp::Tan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tan(&x, scale.into()))
}
LookupOp::ATan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atan(&x, scale.into()))
}
LookupOp::ATanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atanh(&x, scale.into()))
}
LookupOp::Tanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tanh(&x, scale.into()))
}
LookupOp::HardSwish { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
LookupOp::LeakyReLU { slope: a } => {
Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())),
LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())),
LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())),
LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())),
LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())),
LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())),
LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())),
LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())),
LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())),
LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())),
LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())),
LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())),
LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())),
LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())),
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
LookupOp::HardSwish { scale } => {
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
let output = res.map(|x| i64_to_felt(x));
let output = res.map(|x| integer_rep_to_felt(x));
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
@@ -243,36 +158,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::Abs => "ABS".into(),
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,
} => format!(
"RECIP(input_scale={}, output_scale={})",
input_scale, output_scale
),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::ReLU => "RELU".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
LookupOp::IsOdd => "IS_ODD".to_string(),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
@@ -295,7 +187,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,
region,
@@ -305,21 +197,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
}
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::Sign
| LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
| LookupOp::GreaterThanEqual { .. }
| LookupOp::LessThanEqual { .. }
| LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = inputs_scale[0];
Ok(scale)
}
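The `bit_range` helper above maps a table length to the symmetric signed range it can represent. A standalone check of that arithmetic, using `i128` where the crate uses its `IntegerRep` alias:

```rust
// A table with max_len rows covers [-r, r], where r = floor((max_len - 1) / 2).
fn bit_range(max_len: usize) -> (i128, i128) {
    let range = ((max_len - 1) as f64 / 2f64) as i128;
    (-range, range)
}

fn main() {
    // A 2^17-row table covers [-65535, 65535].
    assert_eq!(bit_range(1 << 17), (-65535, 65535));
    // Odd lengths center exactly: 5 rows cover [-2, 2].
    assert_eq!(bit_range(5), (-2, 2));
}
```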

View File

@@ -1,10 +1,10 @@
use std::{any::Any, error::Error};
use std::any::Any;
use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -15,6 +15,8 @@ pub mod base;
///
pub mod chip;
///
pub mod errors;
///
pub mod hybrid;
/// Layouts for specific functions (composed of base ops)
pub mod layouts;
@@ -25,14 +27,16 @@ pub mod poly;
///
pub mod region;
pub use errors::CircuitError;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Returns a string representation of the operation.
@@ -44,10 +48,10 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>>;
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError>;
/// Do any of the inputs to this op require homogenous input scales?
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
@@ -71,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
fn as_any(&self) -> &dyn Any;
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -138,8 +142,8 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.scale)
}
@@ -156,7 +160,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
match self.datum_type {
@@ -193,8 +197,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
@@ -209,8 +213,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
Err(Box::new(super::CircuitError::UnsupportedOp))
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
@@ -220,7 +224,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -230,7 +234,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash +
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -240,7 +244,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box<dyn Error>> {
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
@@ -251,7 +255,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
self.raw_values = Tensor::new(None, &[0]).unwrap();
}
///
/// Pre-assign a value
pub fn pre_assign(&mut self, val: ValTensor<F>) {
self.pre_assigned_val = Some(val)
}
@@ -263,8 +267,7 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
+ IntoI64,
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
@@ -279,7 +282,7 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
@@ -293,7 +296,7 @@ impl<
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.quantized_values.scale().unwrap())
}
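Several impls above forward `Clone` for `Box<dyn Op<F>>` through `clone_dyn`. This is the usual object-safe-clone pattern: `Clone` is not object-safe, so the trait exposes a method returning a fresh box instead. A stripped-down sketch (no generics or crate types):

```rust
// Minimal version of the clone_dyn pattern used by the Op trait above.
trait Op: std::fmt::Debug {
    fn clone_dyn(&self) -> Box<dyn Op>;
}

#[derive(Clone, Debug)]
struct Identity;

impl Op for Identity {
    fn clone_dyn(&self) -> Box<dyn Op> {
        Box::new(self.clone()) // forward to the derive(Clone) impl
    }
}

impl Clone for Box<dyn Op> {
    fn clone(&self) -> Self {
        self.clone_dyn()
    }
}

fn main() {
    let ops: Vec<Box<dyn Op>> = vec![Box::new(Identity)];
    let copied = ops.clone(); // works because Box<dyn Op> is now Clone
    println!("{:?}", copied);
}
```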

View File

@@ -1,5 +1,8 @@
use crate::{
circuit::layouts,
circuit::{
layouts,
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
};
@@ -9,6 +12,12 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non-lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
Abs,
Sign,
LeakyReLU {
slope: utils::F32,
scale: i32,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
@@ -33,6 +42,7 @@ pub enum PolyOp {
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
},
Downsample {
axis: usize,
@@ -43,6 +53,7 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
},
Add,
Sub,
@@ -97,8 +108,7 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
+ IntoI64,
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
@@ -108,6 +118,9 @@ impl<
fn as_string(&self) -> String {
match &self {
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
PolyOp::Abs => "ABS".to_string(),
PolyOp::Sign => "SIGN".to_string(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
@@ -148,17 +161,25 @@ impl<
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
PolyOp::Conv {
stride,
padding,
group,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
group,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -179,8 +200,13 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
@@ -212,9 +238,18 @@ impl<
PolyOp::Prod { axes, .. } => {
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::Conv {
padding,
stride,
group,
} => layouts::conv(
config,
region,
values[..].try_into()?,
padding,
stride,
*group,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -261,6 +296,7 @@ impl<
padding,
output_padding,
stride,
group,
} => layouts::deconv(
config,
region,
@@ -268,6 +304,7 @@ impl<
padding,
output_padding,
stride,
*group,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -278,9 +315,10 @@ impl<
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
return Err(Box::new(TensorError::DimError(
return Err(TensorError::DimError(
"Pad operation requires a single input".to_string(),
)));
)
.into());
}
let mut input = values[0].clone();
input.pad(p.clone(), 0)?;
@@ -297,8 +335,14 @@ impl<
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
// a zero slope corresponds to a plain ReLU, which leaves the scale unchanged
PolyOp::LeakyReLU {
slope: F32(0.0), ..
} => in_scales[0],
// a nonzero slope is a leaky ReLU whose rescaling shifts the output scale
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
@@ -346,6 +390,7 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
_ => in_scales[0],
};
Ok(scale)
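As elsewhere in this diff, the scale rules assume log2 fixed-point scales: multiplying two values of scale `s` yields scale `2s`, which is where the `MeanOfSquares` and `Pow` arms come from, and why a `LeakyReLU` with a nonzero rescale shifts the output scale. A small sketch of that arithmetic:

```rust
// Assumed convention: scale = log2 of the fixed-point multiplier.
fn pow_out_scale(in_scale: i32, pow: u32) -> i32 {
    // x has multiplier 2^s, so x^p has multiplier 2^(p*s).
    in_scale * pow as i32
}

fn mean_of_squares_out_scale(in_scale: i32) -> i32 {
    2 * in_scale // squaring doubles the scale; taking the mean does not
}

fn main() {
    assert_eq!(pow_out_scale(7, 3), 21);
    assert_eq!(mean_of_squares_out_scale(7), 14);
}
```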

View File

@@ -1,15 +1,17 @@
use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
fieldutils::IntegerRep,
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI64 as AtomicInt;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
@@ -19,7 +21,7 @@ use std::{
},
};
use super::lookup::LookupOp;
use super::{lookup::LookupOp, CircuitError};
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
@@ -84,43 +86,87 @@ impl ShuffleIndex {
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
/// wrap other regions
#[error("Wrapped region: {0}")]
Wrapped(String),
#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
/// whether we are in witness generation mode
pub witness_gen: bool,
/// whether we should check range checks for validity
pub check_range: bool,
/// base for decompositions
pub base: usize,
/// number of legs for decompositions
pub legs: usize,
}
impl From<String> for RegionError {
fn from(e: String) -> Self {
Self::Wrapped(e)
#[allow(unsafe_code)]
unsafe impl Sync for RegionSettings {}
#[allow(unsafe_code)]
unsafe impl Send for RegionSettings {}
impl RegionSettings {
/// Create a new [RegionSettings]
pub fn new(witness_gen: bool, check_range: bool, base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen,
check_range,
base,
legs,
}
}
/// Create a new [RegionSettings] with witness generation and range checking enabled
pub fn all_true(base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen: true,
check_range: true,
base,
legs,
}
}
/// Create a new [RegionSettings] with witness generation and range checking disabled
pub fn all_false(base: usize, legs: usize) -> RegionSettings {
RegionSettings {
witness_gen: false,
check_range: false,
base,
legs,
}
}
}
impl From<&str> for RegionError {
fn from(e: &str) -> Self {
Self::Wrapped(e.to_string())
#[derive(Debug, Default, Clone)]
/// Region statistics
pub struct RegionStatistics {
/// the current maximum value of the lookup inputs
pub max_lookup_inputs: IntegerRep,
/// the current minimum value of the lookup inputs
pub min_lookup_inputs: IntegerRep,
/// the current maximum value of the range size
pub max_range_size: IntegerRep,
/// the current set of used lookups
pub used_lookups: HashSet<LookupOp>,
/// the current set of used range checks
pub used_range_checks: HashSet<Range>,
}
impl RegionStatistics {
/// update the statistics with another set of statistics
pub fn update(&mut self, other: &RegionStatistics) {
self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
self.max_range_size = self.max_range_size.max(other.max_range_size);
self.used_lookups.extend(other.used_lookups.clone());
self.used_range_checks
.extend(other.used_range_checks.clone());
}
}
impl From<TensorError> for RegionError {
fn from(e: TensorError) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Error> for RegionError {
fn from(e: Error) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Box<dyn std::error::Error>> for RegionError {
fn from(e: Box<dyn std::error::Error>) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
#[allow(unsafe_code)]
unsafe impl Sync for RegionStatistics {}
#[allow(unsafe_code)]
unsafe impl Send for RegionStatistics {}
#[derive(Debug)]
/// A context for a region
@@ -131,22 +177,33 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i64,
min_lookup_inputs: i64,
max_range_size: i64,
witness_gen: bool,
check_lookup_range: bool,
statistics: RegionStatistics,
settings: RegionSettings,
assigned_constants: ConstantsMap<F>,
max_dynamic_input_len: usize,
}
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
/// get the region's decomposition base
pub fn base(&self) -> usize {
self.settings.base
}
/// get the region's decomposition legs
pub fn legs(&self) -> usize {
self.settings.legs
}
/// get the max dynamic input len
pub fn max_dynamic_input_len(&self) -> usize {
self.max_dynamic_input_len
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
///
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={}, max_dynamic_input_len={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
@@ -154,7 +211,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
self.shuffle_col_coord().to_string().green(),
self.max_dynamic_input_len().to_string().green()
);
}
///
@@ -172,6 +231,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.dynamic_lookup_index.index += n;
}
/// update the max dynamic input len to the larger of the current and given values
pub fn update_max_dynamic_input_len(&mut self, n: usize) {
self.max_dynamic_input_len = self.max_dynamic_input_len.max(n);
}
///
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
self.dynamic_lookup_index.col_coord += n;
@@ -189,16 +253,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn witness_gen(&self) -> bool {
self.witness_gen
self.settings.witness_gen
}
///
pub fn check_lookup_range(&self) -> bool {
self.check_lookup_range
pub fn check_range(&self) -> bool {
self.settings.check_range
}
///
pub fn statistics(&self) -> &RegionStatistics {
&self.statistics
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
let linear_coord = row * num_inner_cols;
@@ -209,14 +284,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: true,
check_lookup_range: true,
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(decomp_base, decomp_legs),
assigned_constants: HashMap::new(),
max_dynamic_input_len: 0,
}
}
@@ -225,45 +296,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
region,
num_inner_cols,
linear_coord,
row,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: false,
check_lookup_range: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
witness_gen: bool,
check_lookup_range: bool,
settings: RegionSettings,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -275,14 +321,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
max_dynamic_input_len: 0,
}
}
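For orientation, a sketch of how these dummy constructors are typically driven (assumed usage inferred from the signatures in this hunk; `Fr` stands in for a concrete field): a dummy context carries no backing halo2 `Region`, so layout passes can run against it to count rows and gather statistics without producing a witness.
// assumed usage sketch, not verbatim ezkl code
use halo2curves::bn256::Fr;
let settings = RegionSettings::all_true(128, 2); // decomp base 128, 2 legs
let mut dry_run: RegionCtx<Fr> = RegionCtx::new_dummy(0, 1, settings);
// ops laid out against `dry_run` advance row/linear_coord and fold into
// `RegionStatistics`, but assign nothing; afterwards `dry_run.row()` is
// the row count the real region will need.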
@@ -291,8 +333,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row: usize,
linear_coord: usize,
num_inner_cols: usize,
witness_gen: bool,
check_lookup_range: bool,
settings: RegionSettings,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -302,14 +343,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
max_dynamic_input_len: 0,
}
}
@@ -317,10 +354,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn apply_in_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
if self.is_dummy() {
self.dummy_loop(output, inner_loop_function)?;
} else {
@@ -333,8 +370,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn real_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>,
) -> Result<(), RegionError> {
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>,
) -> Result<(), CircuitError> {
output
.iter_mut()
.enumerate()
@@ -342,7 +379,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
*o = inner_loop_function(i, self)?;
Ok(())
})
.collect::<Result<Vec<_>, RegionError>>()?;
.collect::<Result<Vec<_>, CircuitError>>()?;
Ok(())
}
@@ -353,115 +390,71 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
*output = output.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.settings.clone(),
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
// update the statistics
let mut statistics = statistics.lock().unwrap();
statistics.update(local_reg.statistics());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
res
})?;
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
self.statistics = Arc::try_unwrap(statistics)
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
})?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?;
Ok(())
}
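The merge shape used here is worth isolating: every parallel worker operates on a thread-local dummy context, and the parent folds the per-worker results back through atomics and mutex-guarded collections. A minimal self-contained sketch of that pattern (plain rayon standing in for ezkl's maybe_rayon wrapper; the tally is illustrative, not ezkl code):
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use rayon::prelude::*;

fn parallel_tally(items: &[usize]) -> (usize, HashSet<usize>) {
    let rows = AtomicUsize::new(0);        // mirrors `row` / `linear_coord`
    let seen = Mutex::new(HashSet::new()); // mirrors `statistics` / `constants`
    items.par_iter().for_each(|&item| {
        let local_rows = item % 7; // per-worker work on a local context
        rows.fetch_add(local_rows, Ordering::SeqCst);
        seen.lock().unwrap().insert(item); // fold local results into shared state
    });
    (rows.into_inner(), seen.into_inner().unwrap())
}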
@@ -470,29 +463,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min forcefully
pub fn update_max_min_lookup_inputs_force(
&mut self,
min: IntegerRep,
max: IntegerRep,
) -> Result<(), CircuitError> {
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
return Err(CircuitError::InvalidMinMaxRange(range.0, range.1));
}
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
Ok(())
}
@@ -506,14 +507,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
self.used_lookups.insert(lookup);
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
self.used_range_checks.insert(range);
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
self.statistics.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
@@ -554,27 +555,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
self.statistics.used_lookups.clone()
}
/// get used range checks
pub fn used_range_checks(&self) -> HashSet<Range> {
self.used_range_checks.clone()
self.statistics.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i64 {
self.max_lookup_inputs
pub fn max_lookup_inputs(&self) -> IntegerRep {
self.statistics.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i64 {
self.min_lookup_inputs
pub fn min_lookup_inputs(&self) -> IntegerRep {
self.statistics.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i64 {
self.max_range_size
pub fn max_range_size(&self) -> IntegerRep {
self.statistics.max_range_size
}
/// Assign a valtensor to a vartensor
@@ -582,18 +583,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -609,20 +610,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<(ValTensor<F>, usize), CircuitError> {
self.update_max_dynamic_input_len(values.len());
if let Some(region) = &self.region {
var.assign(
Ok(var.assign_exact_column(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
let flush_len = var.get_column_flush(self.combined_dynamic_shuffle_coord(), values)?;
// get the diff between the current column and the next row
Ok((values.clone(), flush_len))
}
}
@@ -631,7 +639,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<(ValTensor<F>, usize), CircuitError> {
self.assign_dynamic_lookup(var, values)
}
@@ -640,27 +648,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign_with_omissions(
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
)?)
} else {
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
@@ -707,7 +712,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// constrain equal
pub fn constrain_equal(&mut self, a: &ValTensor<F>, b: &ValTensor<F>) -> Result<(), Error> {
pub fn constrain_equal(
&mut self,
a: &ValTensor<F>,
b: &ValTensor<F>,
) -> Result<(), CircuitError> {
if let Some(region) = &self.region {
let a = a.get_inner_tensor().unwrap();
let b = b.get_inner_tensor().unwrap();
@@ -717,12 +726,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let b = b.get_prev_assigned();
// if they're both assigned, we can constrain them
if let (Some(a), Some(b)) = (&a, &b) {
region.borrow_mut().constrain_equal(a.cell(), b.cell())
region
.borrow_mut()
.constrain_equal(a.cell(), b.cell())
.map_err(|e| e.into())
} else if a.is_some() || b.is_some() {
log::error!(
"constrain_equal: one of the tensors is assigned and the other is not"
);
return Err(Error::Synthesis);
return Err(CircuitError::ConstrainError);
} else {
Ok(())
}
@@ -748,7 +757,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// flush row to the next row
pub fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub fn flush(&mut self) -> Result<(), CircuitError> {
// increment by the difference between the current linear coord and the next row
let remainder = self.linear_coord % self.num_inner_cols;
if remainder != 0 {
@@ -756,7 +765,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.increment(diff);
}
if self.linear_coord % self.num_inner_cols != 0 {
return Err("flush: linear coord is not aligned with the next row".into());
return Err(CircuitError::FlushError);
}
Ok(())
}
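For intuition, the alignment arithmetic in `flush` reduces to the following worked instance (illustrative only):
// with 3 inner columns and a linear coordinate of 7:
let (num_inner_cols, mut linear_coord) = (3usize, 7usize);
let remainder = linear_coord % num_inner_cols;  // 1: we are mid-row
if remainder != 0 {
    linear_coord += num_inner_cols - remainder; // advance by 2, landing on 9
}
assert_eq!(linear_coord % num_inner_cols, 0);   // aligned with the next row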

View File

@@ -1,4 +1,4 @@
use std::{error::Error, marker::PhantomData};
use std::marker::PhantomData;
use halo2curves::ff::PrimeField;
@@ -11,20 +11,33 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
circuit::CircuitError,
fieldutils::i64_to_felt,
tensor::{IntoI64, Tensor, TensorType},
fieldutils::{integer_rep_to_felt, IntegerRep},
tensor::{Tensor, TensorType},
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::execute::EZKL_REPO_PATH;
use crate::circuit::lookup::LookupOp;
/// The range of the lookup table.
pub type Range = (i64, i64);
pub type Range = (IntegerRep, IntegerRep);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i64 = 2;
pub const RANGE_MULTIPLIER: IntegerRep = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
lazy_static::lazy_static! {
/// an optional directory to read and write the lookup table cache
pub static ref LOOKUP_CACHE: String = format!("{}/cache", *EZKL_REPO_PATH);
}
/// The lookup table cache is disabled on wasm32 target.
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub const LOOKUP_CACHE: &str = "";
#[derive(Debug, Clone)]
///
pub struct SelectorConstructor<F: PrimeField> {
@@ -96,21 +109,22 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
i64_to_felt(chunk)
integer_rep_to_felt(chunk)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
let chunk = chunk as i64;
let chunk = chunk as IntegerRep;
// we index from 1 to prevent soundness issues
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
let first_element =
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
@@ -130,12 +144,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
}
/// number of columns required to store a range of the given length
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i64)) as usize + 1
(range_len / (col_size as IntegerRep)) as usize + 1
}
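A worked instance of the formula: the integer division counts the full columns, and the `+ 1` covers the partial remainder column.
// e.g. a range of length 200_000 in columns of 2^16 = 65_536 rows each:
// 200_000 / 65_536 = 3 full columns, plus one partial column.
assert_eq!(num_cols_required(200_000, 65_536), 4);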
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get largest element represented by the range
pub fn largest(&self) -> IntegerRep {
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
}
fn name(&self) -> String {
format!(
"{}_{}_{}",
self.nonlinearity.as_path(),
self.range.0,
self.largest()
)
}
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -194,16 +220,59 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
&mut self,
layouter: &mut impl Layouter<F>,
preassigned_input: bool,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;
let largest = self.range.1;
let largest = self.largest();
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
.par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(integer_rep_to_felt(x)))?;
let evals = self.nonlinearity.f(&[inputs.clone()])?;
Ok((inputs, evals.output))
};
let (inputs, evals) = if !LOOKUP_CACHE.is_empty() {
let cache = std::path::Path::new(&*LOOKUP_CACHE);
let cache_path = cache.join(self.name());
let input_path = cache_path.join("inputs");
let output_path = cache_path.join("outputs");
if cache_path.exists() {
log::info!("Loading lookup table from cache: {:?}", cache_path);
let (input_cache, output_cache) =
(Tensor::load(&input_path)?, Tensor::load(&output_path)?);
(input_cache, output_cache)
} else {
log::info!(
"Generating lookup table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path).map_err(|e| {
CircuitError::TensorError(crate::tensor::TensorError::FileSaveError(
e.to_string(),
))
})?;
let (inputs, evals) = gen_table()?;
inputs.save(&input_path)?;
evals.save(&output_path)?;
(inputs, evals)
}
} else {
log::info!(
"Generating lookup table {} without cache",
self.nonlinearity.as_path()
);
gen_table()?
};
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -226,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
if !preassigned_input {
table.assign_cell(
|| format!("nl_i_col row {}", row_offset),
@@ -235,7 +305,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
)?;
}
let output = evals.output[row_offset];
let output = evals[row_offset];
table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
@@ -272,12 +342,17 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// as path
pub fn as_path(&self) -> String {
format!("rangecheck_{}_{}", self.range.0, self.range.1)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i64;
let chunk = chunk as IntegerRep;
// we index from 1 to prevent soundness issues
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}
///
@@ -293,14 +368,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
i64_to_felt(chunk)
integer_rep_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
@@ -342,15 +417,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let inputs: Tensor<F> = if !LOOKUP_CACHE.is_empty() {
let cache = std::path::Path::new(&*LOOKUP_CACHE);
let cache_path = cache.join(self.as_path());
let input_path = cache_path.join("inputs");
if cache_path.exists() {
log::info!("Loading range check table from cache: {:?}", cache_path);
Tensor::load(&input_path)?
} else {
log::info!(
"Generating range check table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;
let inputs = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
inputs.save(&input_path)?;
inputs
}
} else {
log::info!("Generating range check {} without cache", self.as_path());
Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x))
};
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

View File

@@ -8,6 +8,10 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
)))]
use ops::lookup::LookupOp;
use ops::region::RegionCtx;
use rand::rngs::OsRng;
@@ -55,7 +59,7 @@ mod matmul {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -132,7 +136,7 @@ mod matmul_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -206,7 +210,7 @@ mod matmul_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -243,7 +247,10 @@ mod matmul_col_overflow {
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::kzg::{
@@ -256,7 +263,7 @@ mod matmul_col_ultra_overflow_double_col {
use super::*;
const K: usize = 4;
const LEN: usize = 20;
const LEN: usize = 10;
const NUM_INNER_COLS: usize = 2;
#[derive(Clone)]
@@ -290,7 +297,7 @@ mod matmul_col_ultra_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -361,7 +368,10 @@ mod matmul_col_ultra_overflow_double_col {
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::kzg::{
@@ -374,7 +384,7 @@ mod matmul_col_ultra_overflow {
use super::*;
const K: usize = 4;
const LEN: usize = 20;
const LEN: usize = 10;
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -407,7 +417,7 @@ mod matmul_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -518,7 +528,7 @@ mod dot {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -595,7 +605,7 @@ mod dot_col_overflow_triple_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 3);
let mut region = RegionCtx::new(region, 0, 3, 128, 2);
config
.layout(
&mut region,
@@ -668,7 +678,7 @@ mod dot_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -741,7 +751,7 @@ mod sum {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -811,7 +821,7 @@ mod sum_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
config
.layout(
&mut region,
@@ -880,7 +890,7 @@ mod sum_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -951,7 +961,7 @@ mod composition {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let _ = config
.layout(
&mut region,
@@ -1042,7 +1052,7 @@ mod conv {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -1050,6 +1060,7 @@ mod conv {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1143,7 +1154,10 @@ mod conv {
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::{
@@ -1158,7 +1172,7 @@ mod conv_col_ultra_overflow {
use super::*;
const K: usize = 4;
const LEN: usize = 28;
const LEN: usize = 10;
#[derive(Clone)]
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -1192,7 +1206,7 @@ mod conv_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
@@ -1200,6 +1214,7 @@ mod conv_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1283,7 +1298,10 @@ mod conv_col_ultra_overflow {
#[cfg(test)]
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::kzg::{
@@ -1295,8 +1313,8 @@ mod conv_relu_col_ultra_overflow {
use super::*;
const K: usize = 4;
const LEN: usize = 28;
const K: usize = 8;
const LEN: usize = 15;
#[derive(Clone)]
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -1315,15 +1333,23 @@ mod conv_relu_col_ultra_overflow {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
let mut base_config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU)
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();
base_config
.configure_range_check(cs, &a, &b, (0, 1), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);
base_config.clone()
}
@@ -1332,12 +1358,12 @@ mod conv_relu_col_ultra_overflow {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
config.layout_range_checks(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
let output = config
.layout(
&mut region,
@@ -1345,6 +1371,7 @@ mod conv_relu_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis);
@@ -1352,7 +1379,10 @@ mod conv_relu_col_ultra_overflow {
.layout(
&mut region,
&[output.unwrap().unwrap()],
Box::new(LookupOp::ReLU),
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
@@ -1473,7 +1503,7 @@ mod add_w_shape_casting {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1489,7 +1519,7 @@ mod add_w_shape_casting {
// parameters
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i as u64 + 1))));
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i + 1))));
let circuit = MyCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(b)],
@@ -1540,7 +1570,7 @@ mod add {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1624,7 +1654,7 @@ mod dynamic_lookup {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
@@ -1766,7 +1796,7 @@ mod shuffle {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
@@ -1881,7 +1911,7 @@ mod add_with_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1983,7 +2013,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(
|| "model",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.base
.layout(&mut region, &inputs, Box::new(PolyOp::Add))
@@ -2089,7 +2119,7 @@ mod sub {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.map_err(|_| Error::Synthesis)
@@ -2156,7 +2186,7 @@ mod mult {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.map_err(|_| Error::Synthesis)
@@ -2223,7 +2253,7 @@ mod pow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.map_err(|_| Error::Synthesis)
@@ -2255,7 +2285,6 @@ mod matmul_relu {
const K: usize = 18;
const LEN: usize = 32;
use crate::circuit::lookup::LookupOp;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -2285,11 +2314,17 @@ mod matmul_relu {
let mut base_config =
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU)
.configure_range_check(cs, &a, &b, (-1, 1), K)
.unwrap();
base_config
.configure_range_check(cs, &a, &b, (0, 1023), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 8, false);
MyConfig { base_config }
}
@@ -2298,11 +2333,14 @@ mod matmul_relu {
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.base_config.layout_tables(&mut layouter).unwrap();
config
.base_config
.layout_range_checks(&mut layouter)
.unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
@@ -2312,7 +2350,14 @@ mod matmul_relu {
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
},
@@ -2351,6 +2396,8 @@ mod relu {
plonk::{Circuit, ConstraintSystem, Error},
};
const K: u32 = 8;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
@@ -2367,16 +2414,26 @@ mod relu {
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
.map(|_| VarTensor::new_advice(cs, 8, 1, 3))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let mut config = BaseConfig::default();
let mut config = BaseConfig::configure(
cs,
&[advices[0].clone(), advices[1].clone()],
&advices[2],
CheckMode::SAFE,
);
config
.configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl)
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K as usize)
.unwrap();
config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1), K as usize)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K as usize, 8, false);
config
}
@@ -2385,15 +2442,22 @@ mod relu {
mut config: Self::Config,
mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
config.layout_range_checks(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.map_err(|_| Error::Synthesis)
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
Ok(config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap())
},
)
.unwrap();
@@ -2411,13 +2475,16 @@ mod relu {
input: ValTensor::from(input),
};
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
let prover = MockProver::run(K, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
mod lookup_ultra_overflow {
use super::*;
use halo2_proofs::{
@@ -2432,11 +2499,11 @@ mod lookup_ultra_overflow {
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
}
impl Circuit<F> for ReLUCircuit<F> {
impl Circuit<F> for SigmoidCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -2450,7 +2517,7 @@ mod lookup_ultra_overflow {
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
let mut config = BaseConfig::default();
@@ -2478,9 +2545,13 @@ mod lookup_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2492,13 +2563,13 @@ mod lookup_ultra_overflow {
#[test]
#[ignore]
fn relucircuit() {
fn sigmoidcircuit() {
// get some logs fam
crate::logger::init_logger();
// parameters
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
let circuit = ReLUCircuit::<F> {
let circuit = SigmoidCircuit::<F> {
input: ValTensor::from(a),
};
@@ -2508,7 +2579,7 @@ mod lookup_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
ReLUCircuit<F>,
SigmoidCircuit<F>,
>(&circuit, &params, true)
.unwrap();

View File

@@ -141,23 +141,23 @@ mod tests {
#[test]
fn f32_eq() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(std::f32::NAN));
assert!(F32(f32::NAN) == F32(f32::NAN));
assert!(F32(f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}
#[test]
fn f32_cmp() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(std::f32::NAN));
assert!(F32(f32::NAN) == F32(f32::NAN));
assert!(F32(f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}
#[test]
fn f32_hash() {
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
assert!(calculate_hash(&F32(f32::NAN)) == calculate_hash(&F32(-f32::NAN)));
}
}

View File

@@ -1,4 +1,3 @@
#[cfg(not(target_arch = "wasm32"))]
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
@@ -17,7 +16,6 @@ use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::TestDataSource;
use crate::pfsys::TranscriptType;
@@ -81,8 +79,10 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk separately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default render reusable verifier
pub const DEFAULT_RENDER_REUSABLE: &str = "false";
/// Default contract deployment type
pub const DEFAULT_CONTRACT_DEPLOYMENT_TYPE: &str = "verifier";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
@@ -181,28 +181,85 @@ impl From<&str> for CalibrationTarget {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// Determines what type of contract (verifier, verifier/reusable, vka) should be deployed
pub enum ContractType {
/// Deploys a verifier contract tailored to the circuit and not reusable
Verifier {
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain, since you need only deploy a verifying key artifact (vka) for a given circuit, which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
reusable: bool,
},
/// Deploys a verifying key artifact that the reusable verifier loads into memory at runtime. Encodes the circuit-specific data that was otherwise hardcoded onto the stack.
VerifyingKeyArtifact,
}
impl Default for ContractType {
fn default() -> Self {
ContractType::Verifier {
reusable: false,
}
}
}
impl std::fmt::Display for ContractType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_string()
},
ContractType::Verifier {
reusable: false,
} => "verifier".to_string(),
ContractType::VerifyingKeyArtifact => "vka".to_string(),
}
)
}
}
impl ToFlags for ContractType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<&str> for ContractType {
fn from(s: &str) -> Self {
match s {
"verifier" => ContractType::Verifier { reusable: false },
"verifier/reusable" => ContractType::Verifier { reusable: true },
"vka" => ContractType::VerifyingKeyArtifact,
_ => {
log::error!("Invalid value for ContractType");
log::warn!("Defaulting to verifier");
ContractType::default()
}
}
}
}
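Taken together, the `Display` and `From<&str>` impls round-trip the three flag spellings; a small usage sketch (assuming the `PartialEq` derive above):
let ct = ContractType::from("verifier/reusable");
assert_eq!(ct, ContractType::Verifier { reusable: true });
assert_eq!(ct.to_string(), "verifier/reusable");
// unrecognized strings log an error and fall back to the plain verifier
assert_eq!(ContractType::from("bogus"), ContractType::default());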
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
inner: H160,
}
#[cfg(not(target_arch = "wasm32"))]
impl From<H160Flag> for H160 {
fn from(val: H160Flag) -> H160 {
val.inner
}
}
#[cfg(not(target_arch = "wasm32"))]
impl ToFlags for H160Flag {
fn to_flags(&self) -> Vec<String> {
vec![format!("{:#x}", self.inner)]
}
}
#[cfg(not(target_arch = "wasm32"))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
@@ -243,17 +300,38 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
lazy_static! {
/// The version of the ezkl library
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
"source - no compatibility guaranteed"
} else {
env!("CARGO_PKG_VERSION")
};
#[cfg(feature = "python-bindings")]
/// Converts ContractType into a PyObject (Required for ContractType to be compatible with Python)
impl IntoPy<PyObject> for ContractType {
fn into_py(self, py: Python) -> PyObject {
match self {
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_object(py)
}
ContractType::Verifier {
reusable: false,
} => "verifier".to_object(py),
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
}
}
}
#[cfg(feature = "python-bindings")]
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
impl<'source> FromPyObject<'source> for ContractType {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"verifier" => Ok(ContractType::Verifier {
reusable: false,
}),
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
"vka" => Ok(ContractType::VerifyingKeyArtifact),
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
}
}
}
/// Get the styles for the CLI
@@ -305,7 +383,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
pub struct Cli {
/// If provided, outputs the completion file for given shell
#[clap(long = "generate", value_parser)]
@@ -365,8 +443,7 @@ pub enum Commands {
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
#[cfg(not(target_arch = "wasm32"))]
CalibrateSettings {
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
@@ -379,9 +456,9 @@ pub enum Commands {
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
/// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be ceil(2^k * lookup_safety_margin). larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
lookup_safety_margin: i64,
lookup_safety_margin: f64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
scales: Option<Vec<crate::Scale>>,
@@ -397,9 +474,6 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long, value_hint = clap::ValueHint::Other)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
only_range_check_rebase: Option<bool>,
},
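Since `lookup_safety_margin` is now an `f64`, the documented sizing rule `ceil(2^k * lookup_safety_margin)` admits fractional margins; a hypothetical helper (not ezkl API) showing the arithmetic:
fn margined_lookup_rows(max_lookup: u64, lookup_safety_margin: f64) -> u64 {
    (max_lookup as f64 * lookup_safety_margin).ceil() as u64
}
// with the default margin of 2.0, a max lookup of 2^12 = 4096 requires
// 8192 table rows, so calibration must choose logrows >= 13 (before
// accounting for reserved blinding rows).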
/// Generates a dummy SRS
@@ -416,11 +490,10 @@ pub enum Commands {
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Gets an SRS from a circuit settings file.
/// Gets an SRS from a circuit settings file.
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Path to the circuit settings .json file to read in logrows from. Overridden by logrows if specified.
@@ -467,7 +540,7 @@ pub enum Commands {
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
@@ -494,7 +567,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
#[arg(
@@ -536,7 +609,7 @@ pub enum Commands {
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to output the verification key file to
@@ -552,8 +625,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -577,8 +649,7 @@ pub enum Commands {
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
#[cfg(not(target_arch = "wasm32"))]
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that the admin account is able to update this call data.
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that the admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
@@ -591,8 +662,7 @@ pub enum Commands {
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Swaps the positions in the transcript that correspond to commitments
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
@@ -602,8 +672,7 @@ pub enum Commands {
witness_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Loads model, data, and creates proof
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
@@ -617,7 +686,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
@@ -633,25 +702,23 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Encodes a proof into evm calldata
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the Solidity code
/// The path to save the calldata to
#[arg(long, default_value = DEFAULT_CALLDATA, value_hint = clap::ValueHint::FilePath)]
calldata_path: Option<PathBuf>,
/// The path to the verification key address (only used if the vk is rendered as a separate contract)
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -666,17 +733,14 @@ pub enum Commands {
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
/// Whether to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
reusable: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -692,8 +756,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -712,13 +775,15 @@ pub enum Commands {
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
@@ -736,11 +801,9 @@ pub enum Commands {
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
/// Whether to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
reusable: Option<bool>,
},
/// Verifies a proof, returning accept or reject
Verify {
@@ -753,7 +816,7 @@ pub enum Commands {
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
@@ -771,7 +834,7 @@ pub enum Commands {
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
@@ -781,9 +844,8 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVerifier {
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
DeployEvm {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
@@ -799,28 +861,11 @@ pub enum Commands {
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
/// Contract type to be deployed
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
contract: ContractType,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -845,8 +890,7 @@ pub enum Commands {
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
@@ -865,6 +909,13 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(feature = "no-update"))]
/// Updates ezkl binary to version specified (or latest if not specified)
Update {
/// The version to update to
#[arg(value_hint = clap::ValueHint::Other, short='v', long)]
version: Option<String>,
},
}
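A minimal sketch of the clap derive attributes used throughout this enum, assuming clap v4 with the derive feature enabled; the struct and field names here are toy stand-ins, not ezkl's:

use clap::{Parser, Subcommand};
use std::path::PathBuf;

#[derive(Parser)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    /// Doc comments become the subcommand's --help text
    #[command(name = "create-evm-vka")]
    CreateEvmVKArtifact {
        /// Omitting --srs-path leaves this as None
        #[arg(long)]
        srs_path: Option<PathBuf>,
        /// A SetTrue flag: false by default, true when --reusable is passed
        #[arg(long, action = clap::ArgAction::SetTrue)]
        reusable: bool,
    },
}

fn main() {
    // e.g. `prog create-evm-vka --reusable`
    let _cli = Cli::parse();
}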


@@ -1,7 +1,6 @@
use crate::graph::input::{CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::graph::DataSource;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::GraphSettings;
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::Snark;
@@ -11,12 +10,11 @@ use alloy::core::primitives::Bytes;
use alloy::core::primitives::U256;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
use alloy::dyn_abi::abi::TokenSeq;
#[cfg(target_arch = "wasm32")]
use alloy::prelude::Wallet;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::node_bindings::Anvil;
use alloy::primitives::{B256, I256};
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
@@ -25,10 +23,13 @@ use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
use alloy::signers::wallet::LocalWallet;
use alloy::signers::k256::ecdsa;
use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -36,7 +37,6 @@ use halo2curves::group::ff::PrimeField;
use itertools::Itertools;
use log::{debug, info, warn};
use reqwest::Client;
use std::error::Error;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@@ -213,6 +213,57 @@ abigen!(
}
);
#[derive(Debug, thiserror::Error)]
pub enum EthError {
#[error("a transport error occurred: {0}")]
Transport(#[from] RpcError<TransportErrorKind>),
#[error("a contract error occurred: {0}")]
Contract(#[from] alloy::contract::Error),
#[error("a wallet error occurred: {0}")]
Wallet(#[from] WalletError),
#[error("failed to parse url {0}")]
UrlParse(String),
#[error("evm verification error: {0}")]
EvmVerification(#[from] EvmVerificationError),
#[error("Private key must be in hex format, 64 chars, without 0x prefix")]
PrivateKeyFormat,
#[error("failed to parse hex: {0}")]
HexParse(#[from] hex::FromHexError),
#[error("ecdsa error: {0}")]
Ecdsa(#[from] ecdsa::Error),
#[error("failed to load graph data")]
GraphData,
#[error("failed to load graph settings")]
GraphSettings,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("Data source for either input_data or output_data must be OnChain")]
OnChainDataSource,
#[error("failed to parse signed integer: {0}")]
SignedIntegerParse(#[from] ParseSignedError),
#[error("failed to parse unsigned integer: {0}")]
UnSignedIntegerParse(#[from] ParseError),
#[error("updateAccountCalls should have failed")]
UpdateAccountCalls,
#[error("ethabi error: {0}")]
EthAbi(#[from] ethabi::Error),
#[error("conversion error: {0}")]
Conversion(#[from] std::convert::Infallible),
// Constructor arguments provided but no constructor found
#[error("constructor arguments provided but no constructor found")]
NoConstructor,
#[error("contract not found at path: {0}")]
ContractNotFound(String),
#[error("solc error: {0}")]
Solc(#[from] SolcError),
#[error("solc io error: {0}")]
SolcIo(#[from] SolcIoError),
#[error("svm error: {0}")]
Svm(String),
#[error("no contract output found")]
NoContractOutput,
}
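The #[from] attributes above are what make the migration away from Box<dyn Error> ergonomic: thiserror generates a From impl per variant, so the ? operator converts library errors into EthError automatically. A self-contained sketch with toy types (not ezkl's):

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("failed to parse hex: {0}")]
    HexParse(#[from] hex::FromHexError),
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}

fn read_and_decode(path: &str) -> Result<Vec<u8>, DemoError> {
    let hex_string = std::fs::read_to_string(path)?; // io::Error -> DemoError::Io
    let bytes = hex::decode(hex_string.trim())?;     // FromHexError -> DemoError::HexParse
    Ok(bytes)
}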
// we have to generate these two contracts differently because they are generated dynamically, and hence the static compilation from above does not suit them
const ATTESTDATA_SOL: &str = include_str!("../contracts/AttestData.sol");
@@ -231,11 +282,10 @@ pub type EthersClient = Arc<
pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
/// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used.
#[cfg(not(target_arch = "wasm32"))]
pub async fn setup_eth_backend(
rpc_url: Option<&str>,
private_key: Option<&str>,
) -> Result<(EthersClient, alloy::primitives::Address), Box<dyn Error>> {
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
// Launch anvil
let endpoint: String;
@@ -257,11 +307,8 @@ pub async fn setup_eth_backend(
let wallet: LocalWallet;
if let Some(private_key) = private_key {
debug!("using private key {}", private_key);
// Sanity checks for private_key
let private_key_format_error =
"Private key must be in hex format, 64 chars, without 0x prefix";
if private_key.len() != 64 {
return Err(private_key_format_error.into());
return Err(EthError::PrivateKeyFormat);
}
let private_key_buffer = hex::decode(private_key)?;
wallet = LocalWallet::from_slice(&private_key_buffer)?;
@@ -276,7 +323,7 @@ pub async fn setup_eth_backend(
ProviderBuilder::new()
.with_recommended_fillers()
.signer(EthereumSigner::from(wallet))
.on_http(endpoint.parse()?),
.on_http(endpoint.parse().map_err(|_| EthError::UrlParse(endpoint))?),
);
let chain_id = client.get_chain_id().await?;
@@ -292,15 +339,14 @@ pub async fn deploy_contract_via_solidity(
runs: usize,
private_key: Option<&str>,
contract_name: &str,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
// anvil instance must be alive at least until the factory completes the deploy
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, contract_name, runs).await?;
let factory =
get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone(), None::<()>)?;
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client, None::<()>)?;
let contract = factory.deploy().await?;
Ok(contract)
@@ -314,12 +360,12 @@ pub async fn deploy_da_verifier_via_solidity(
rpc_url: Option<&str>,
runs: usize,
private_key: Option<&str>,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;
let input = GraphData::from_path(input)?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
let settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;
let mut scales: Vec<u32> = vec![];
// The data that will be stored in the test contracts that will eventually be read from.
@@ -339,7 +385,7 @@ pub async fn deploy_da_verifier_via_solidity(
}
if settings.run_args.param_visibility.is_hashed() {
return Err(Box::new(EvmVerificationError::InvalidVisibility));
return Err(EvmVerificationError::InvalidVisibility.into());
}
if settings.run_args.output_visibility.is_hashed() {
@@ -397,20 +443,30 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
// if calls_to_accounts is empty then we need to check that there is at least kzg (polycommit) visibility in the settings file
let kzg_visibility = settings.run_args.input_visibility.is_polycommit()
|| settings.run_args.output_visibility.is_polycommit()
|| settings.run_args.param_visibility.is_polycommit();
if !kzg_visibility {
return Err(EthError::OnChainDataSource);
}
let factory =
get_sol_contract_factory::<_, ()>(abi, bytecode, runtime_bytecode, client, None)?;
let contract = factory.deploy().await?;
return Ok(contract);
};
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let factory = get_sol_contract_factory(
abi,
bytecode,
runtime_bytecode,
client.clone(),
client,
Some((
// address[] memory _contractAddresses,
DynSeqToken(
@@ -451,7 +507,7 @@ pub async fn deploy_da_verifier_via_solidity(
),
// uint8 _instanceOffset,
WordToken(U256::from(contract_instance_offset as u32).into()),
//address _admin
// address _admin
WordToken(client_address.into_word()),
)),
)?;
@@ -469,12 +525,12 @@ type ParsedCallsToAccount = (Vec<H160>, Vec<Vec<Bytes>>, Vec<Vec<U256>>);
fn parse_calls_to_accounts(
calls_to_accounts: Vec<CallsToAccount>,
) -> Result<ParsedCallsToAccount, Box<dyn Error>> {
) -> Result<ParsedCallsToAccount, EthError> {
let mut contract_addresses = vec![];
let mut call_data = vec![];
let mut decimals: Vec<Vec<U256>> = vec![];
for (i, val) in calls_to_accounts.iter().enumerate() {
let contract_address_bytes = hex::decode(val.address.clone())?;
let contract_address_bytes = hex::decode(&val.address)?;
let contract_address = H160::from_slice(&contract_address_bytes);
contract_addresses.push(contract_address);
call_data.push(vec![]);
@@ -492,8 +548,8 @@ pub async fn update_account_calls(
addr: H160,
input: PathBuf,
rpc_url: Option<&str>,
) -> Result<(), Box<dyn Error>> {
let input = GraphData::from_path(input)?;
) -> Result<(), EthError> {
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
// The data that will be stored in the test contracts that will eventually be read from.
let mut calls_to_accounts = vec![];
@@ -513,12 +569,12 @@ pub async fn update_account_calls(
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
return Err(EthError::OnChainDataSource);
};
let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
let contract = DataAttestation::new(addr, client.clone());
let contract = DataAttestation::new(addr, &client);
info!("contract_addresses: {:#?}", contract_addresses);
@@ -547,20 +603,19 @@ pub async fn update_account_calls(
{
info!("updateAccountCalls failed as expected");
} else {
return Err("updateAccountCalls should have failed".into());
return Err(EthError::UpdateAccountCalls);
}
Ok(())
}
/// Verify a proof using a Solidity verifier contract
#[cfg(not(target_arch = "wasm32"))]
pub async fn verify_proof_via_solidity(
proof: Snark<Fr, G1Affine>,
addr: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
let flattened_instances = proof.instances.into_iter().flatten();
let encoded = encode_calldata(
@@ -579,15 +634,15 @@ pub async fn verify_proof_via_solidity(
let result = client.call(&tx).await;
if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result.to_vec());
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}
let gas = client.estimate_gas(&tx).await?;
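The comparison against &1u8 above works because a Solidity `returns (bool)` comes back as a single ABI-encoded 32-byte word whose last byte is 0 or 1. A standalone sketch of the same decoding:

// returndata for `returns (bool)` is one 32-byte word:
// 31 zero bytes followed by 0x00 (false) or 0x01 (true)
fn decode_bool_returndata(returndata: &[u8]) -> Option<bool> {
    returndata.last().map(|byte| *byte == 1u8)
}

fn main() {
    let mut word = vec![0u8; 32];
    word[31] = 1;
    assert_eq!(decode_bool_returndata(&word), Some(true));
}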
@@ -626,7 +681,7 @@ fn count_decimal_places(num: f32) -> usize {
pub async fn setup_test_contract<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), Box<dyn Error>> {
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), EthError> {
let mut decimals = vec![];
let mut scaled_by_decimals_data = vec![];
for input in &data[0] {
@@ -656,14 +711,13 @@ pub async fn setup_test_contract<M: 'static + Provider<Http<Client>, Ethereum>>(
/// Verify a proof using a Solidity DataAttestation contract.
/// Used for testing purposes.
#[cfg(not(target_arch = "wasm32"))]
pub async fn verify_proof_with_data_attestation(
proof: Snark<Fr, G1Affine>,
addr_verifier: H160,
addr_da: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
use ethabi::{Function, Param, ParamType, StateMutability, Token};
let mut public_inputs: Vec<U256> = vec![];
@@ -671,7 +725,7 @@ pub async fn verify_proof_with_data_attestation(
for val in flattened_instances.clone() {
let bytes = val.to_repr();
let u = U256::from_le_slice(bytes.as_slice());
let u = U256::from_le_slice(bytes.inner().as_slice());
public_inputs.push(u);
}
@@ -728,15 +782,15 @@ pub async fn verify_proof_with_data_attestation(
);
let result = client.call(&tx).await;
if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result);
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}
Ok(true)
@@ -748,8 +802,8 @@ pub async fn verify_proof_with_data_attestation(
pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
) -> Result<Vec<CallsToAccount>, EthError> {
let (contract, decimals) = setup_test_contract(client, data).await?;
// Get the encoded call data for each input
let mut calldata = vec![];
@@ -769,22 +823,21 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
}
/// Reads on-chain inputs, returning the raw encoded data returned from making all the calls in on_chain_input_data
#[cfg(not(target_arch = "wasm32"))]
pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
address: H160,
data: &Vec<CallsToAccount>,
) -> Result<(Vec<Bytes>, Vec<u8>), Box<dyn Error>> {
) -> Result<(Vec<Bytes>, Vec<u8>), EthError> {
// Iterate over all on-chain inputs
let mut fetched_inputs = vec![];
let mut decimals = vec![];
for on_chain_data in data {
// Construct the address
let contract_address_bytes = hex::decode(on_chain_data.address.clone())?;
let contract_address_bytes = hex::decode(&on_chain_data.address)?;
let contract_address = H160::from_slice(&contract_address_bytes);
for (call_data, decimal) in &on_chain_data.call_data {
let call_data_bytes = hex::decode(call_data.clone())?;
let call_data_bytes = hex::decode(call_data)?;
let input: TransactionInput = call_data_bytes.into();
let tx = TransactionRequest::default()
@@ -803,18 +856,15 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
}
///
#[cfg(not(target_arch = "wasm32"))]
pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
scales: Vec<crate::Scale>,
data: &(Vec<Bytes>, Vec<u8>),
) -> Result<Vec<Fr>, Box<dyn Error>> {
use alloy::primitives::ParseSignedError;
) -> Result<Vec<Fr>, EthError> {
let contract = QuantizeData::deploy(&client).await?;
let fetched_inputs = data.0.clone();
let decimals = data.1.clone();
let fetched_inputs = &data.0;
let decimals = &data.1;
let fetched_inputs = fetched_inputs
.iter()
@@ -870,7 +920,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
runtime_bytecode: Bytes,
client: Arc<M>,
params: Option<T>,
) -> Result<ContractFactory<M>, Box<dyn Error>> {
) -> Result<ContractFactory<M>, EthError> {
const MAX_RUNTIME_BYTECODE_SIZE: usize = 24577;
let size = runtime_bytecode.len();
debug!("runtime bytecode size: {:#?}", size);
@@ -888,9 +938,9 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
// Encode the constructor args & concatenate with the bytecode if necessary
let data: Bytes = match (abi.constructor(), params.is_none()) {
(None, false) => {
return Err("Constructor arguments provided but no constructor found".into())
return Err(EthError::NoConstructor);
}
(None, true) => bytecode.clone(),
(None, true) => bytecode,
(Some(_), _) => {
let mut data = bytecode.to_vec();
@@ -902,16 +952,15 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
}
};
Ok(CallBuilder::new_raw_deploy(client.clone(), data))
Ok(CallBuilder::new_raw_deploy(client, data))
}
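The `(Some(_), _)` arm relies on the standard EVM deployment layout: creation bytecode followed immediately by the ABI-encoded constructor arguments. A hypothetical helper (not in the diff) illustrating that concatenation:

// deployment data = creation bytecode ++ abi_encode(constructor args)
fn build_deploy_data(bytecode: &[u8], encoded_args: &[u8]) -> Vec<u8> {
    let mut data = bytecode.to_vec();
    data.extend_from_slice(encoded_args);
    data
}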
/// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode
#[cfg(not(target_arch = "wasm32"))]
pub async fn get_contract_artifacts(
sol_code_path: PathBuf,
contract_name: &str,
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), Box<dyn Error>> {
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
@@ -919,7 +968,9 @@ pub async fn get_contract_artifacts(
};
if !sol_code_path.exists() {
return Err(format!("file not found: {:#?}", sol_code_path).into());
return Err(EthError::ContractNotFound(
sol_code_path.to_string_lossy().to_string(),
));
}
let settings = SolcSettings {
@@ -946,7 +997,9 @@ pub async fn get_contract_artifacts(
Some(solc) => solc,
None => {
info!("required solc version is missing ... installing");
Solc::install(&SHANGHAI_SOLC).await?
Solc::install(&SHANGHAI_SOLC)
.await
.map_err(|e| EthError::Svm(e.to_string()))?
}
};
@@ -955,7 +1008,7 @@ pub async fn get_contract_artifacts(
let (abi, bytecode, runtime_bytecode) = match compiled.find(contract_name) {
Some(c) => c.into_parts_or_default(),
None => {
return Err("could not find contract".into());
return Err(EthError::ContractNotFound(contract_name.to_string()));
}
};
@@ -966,13 +1019,14 @@ pub async fn get_contract_artifacts(
pub fn fix_da_sol(
input_data: Option<Vec<CallsToAccount>>,
output_data: Option<Vec<CallsToAccount>>,
) -> Result<String, Box<dyn Error>> {
commitment_bytes: Option<Vec<u8>>,
) -> Result<String, EthError> {
let mut accounts_len = 0;
let mut contract = ATTESTDATA_SOL.to_string();
// fill in the quantization params and total calls
// as constants to the contract to save on gas
if let Some(input_data) = input_data {
if let Some(input_data) = &input_data {
let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum();
accounts_len = input_data.len();
contract = contract.replace(
@@ -980,7 +1034,7 @@ pub fn fix_da_sol(
&format!("uint256 constant INPUT_CALLS = {};", input_calls),
);
}
if let Some(output_data) = output_data {
if let Some(output_data) = &output_data {
let output_calls: usize = output_data.iter().map(|v| v.call_data.len()).sum();
accounts_len += output_data.len();
contract = contract.replace(
@@ -990,5 +1044,61 @@ pub fn fix_da_sol(
}
contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len));
// The case where a combination of on-chain data source + kzg commit is provided.
if commitment_bytes.is_some() && !commitment_bytes.as_ref().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.as_ref().unwrap();
let hex_string = hex::encode(commitment_bytes);
contract = contract.replace(
"bytes constant COMMITMENT_KZG = hex\"\";",
&format!("bytes constant COMMITMENT_KZG = hex\"{}\";", hex_string),
);
} else {
// Remove the SwapProofCommitments inheritance and the checkKzgCommits function call if no commitment is provided
contract = contract.replace(", SwapProofCommitments", "");
contract = contract.replace(
"require(checkKzgCommits(encoded), \"Invalid KZG commitments\");",
"",
);
}
// if both input and output data are none then we will only deploy the DataAttestation contract, adding in the verifyWithDataAttestation function
if input_data.is_none()
&& output_data.is_none()
&& commitment_bytes.as_ref().is_some()
&& !commitment_bytes.as_ref().unwrap().is_empty()
{
contract = contract.replace(
"contract SwapProofCommitments {",
"contract DataAttestation {",
);
// Remove everything past the end of the checkKzgCommits function
if let Some(pos) = contract.find(" } /// end checkKzgCommits") {
contract.truncate(pos);
contract.push('}');
}
// Add the Solidity function below checkKzgCommits
contract.push_str(
r#"
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}"#,
);
}
Ok(contract)
}
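fix_da_sol is textual templating over a vendored Solidity source: constants are baked into the contract string before compilation to save gas. A stripped-down sketch of the same pattern, with hypothetical placeholder strings standing in for the real AttestData.sol:

fn fill_template(mut contract: String, input_calls: usize, commitment: Option<&[u8]>) -> String {
    contract = contract.replace(
        "uint256 constant INPUT_CALLS = 0;",
        &format!("uint256 constant INPUT_CALLS = {};", input_calls),
    );
    if let Some(bytes) = commitment.filter(|b| !b.is_empty()) {
        contract = contract.replace(
            "bytes constant COMMITMENT_KZG = hex\"\";",
            &format!("bytes constant COMMITMENT_KZG = hex\"{}\";", hex::encode(bytes)),
        );
    }
    contract
}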

File diff suppressed because it is too large


@@ -2,42 +2,21 @@ use halo2_proofs::arithmetic::Field;
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
use halo2curves::ff::PrimeField;
/// Converts an i32 to a PrimeField element.
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
if x >= 0 {
F::from(x as u64)
} else {
-F::from(x.unsigned_abs() as u64)
}
}
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an i64 to a PrimeField element.
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
-F::from_u128((-x) as u128)
}
}
/// Converts a PrimeField element to an i32.
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
if x > F::from(i32::MAX as u64) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
-(lower_32 as i32)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
lower_32 as i32
-F::from_u128(x.saturating_neg() as u128)
}
}
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
if x > F::from_u128(i64::MAX as u128) {
if x > F::from_u128(IntegerRep::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -51,17 +30,17 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}
/// Converts a PrimeField element to an i64.
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
if x > F::from_u128(i64::MAX as u128) {
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
-(lower_128 as i64)
-(lower_128 as IntegerRep)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
lower_128 as i64
lower_128 as IntegerRep
}
}
@@ -73,33 +52,24 @@ mod test {
#[test]
fn test_conv() {
let res: F = i32_to_felt(-15i32);
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
let res: F = i32_to_felt(2_i32.pow(17));
let res: F = integer_rep_to_felt(2_i128.pow(17));
assert_eq!(res, F::from(131072));
let res: F = i64_to_felt(-15i64);
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
let res: F = i64_to_felt(2_i64.pow(17));
let res: F = integer_rep_to_felt(2_i128.pow(17));
assert_eq!(res, F::from(131072));
}
#[test]
fn felttoi32() {
for x in -(2i32.pow(16))..(2i32.pow(16)) {
let fieldx: F = i32_to_felt::<F>(x);
let xf: i32 = felt_to_i32::<F>(fieldx);
assert_eq!(x, xf);
}
}
#[test]
fn felttoi64() {
for x in -(2i64.pow(20))..(2i64.pow(20)) {
let fieldx: F = i64_to_felt::<F>(x);
let xf: i64 = felt_to_i64::<F>(fieldx);
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
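The round-trip works because a negative x is stored as the field negation p - |x|, so any element above IntegerRep::MAX is read back as negative. A toy sketch of the same mapping over a small prime (the real code uses the bn256 scalar field):

const P: i128 = 97; // toy prime

fn int_to_toy_felt(x: i128) -> i128 {
    x.rem_euclid(P) // -15 -> 82 == 97 - 15
}

fn toy_felt_to_int(f: i128) -> i128 {
    if f > P / 2 { f - P } else { f } // 82 -> -15
}

fn main() {
    for x in -48..=48 {
        assert_eq!(toy_felt_to_int(int_to_toy_felt(x)), x);
    }
}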

src/graph/errors.rs (new file, 146 lines)

@@ -0,0 +1,146 @@
use std::convert::Infallible;
use thiserror::Error;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// This operation is unsupported
#[error("unsupported datatype in graph node {0} ({1})")]
UnsupportedDataType(usize, String),
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Reading a file failed
#[error("[io] ({0}) {1}")]
ReadWriteFileError(String, String),
/// Model serialization error
#[error("failed to ser/deser model: {0}")]
ModelSerialize(#[from] bincode::Error),
/// Tract error
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tract] {0}")]
TractError(#[from] tract_onnx::prelude::TractError),
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
/// Public visibility for params is deprecated
#[error("public visibility for params is deprecated, please use `fixed` instead")]
ParamsPublicVisibility,
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Circuit error
#[error("[circuit] {0}")]
CircuitError(#[from] crate::circuit::CircuitError),
/// Halo2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
/// System time error
#[error("[system time] {0}")]
SystemTimeError(#[from] std::time::SystemTimeError),
/// Missing Batch Size
#[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
MissingBatchSize,
/// Tokio postgres error
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[eth] {0}")]
EthError(#[from] crate::eth::EthError),
/// Json error
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
/// Missing instances
#[error("missing instances")]
MissingInstances,
/// Missing constants
#[error("missing constants")]
MissingConstants,
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
///
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
///
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
///
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
#[error("missing scale")]
MissingScale,
/// Extended k is too large
#[error("extended k is too large to accommodate the quotient polynomial with logrows {0}")]
ExtendedKTooLarge(u32),
/// Max lookup input is too large
#[error("lookup range {0} is too large")]
LookupRangeTooLarge(usize),
/// Max range check input is too large
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
/// Cannot use on-chain data source as private data
#[error("cannot use an on-chain data source as 1) an output for an on-chain test, 2) private data, or 3) an input when using wasm")]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]
MissingDataSource,
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
}


@@ -1,10 +1,10 @@
use super::errors::GraphError;
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
@@ -20,12 +20,12 @@ use std::io::BufReader;
use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
type Decimals = u8;
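The recurring attribute swap in this file tightens the compilation gate from "not wasm" to "native target and the ezkl feature enabled". A self-contained sketch of how such cfg predicates pair up (mirroring the any(...) counterpart used elsewhere in this diff):

// compiled only on native targets when the "ezkl" feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn native_only() {}

// compiled everywhere else (wasm, or the feature disabled)
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
fn fallback() {}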
@@ -128,7 +128,7 @@ impl FileSourceInner {
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
@@ -150,7 +150,7 @@ impl FileSourceInner {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
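A hedged sketch of the fixed-point step behind to_field, assuming quantize_float rounds f * 2^scale to the nearest integer (the exact rounding and error handling live elsewhere in the crate):

type IntegerRep = i128;

// assumption: scale is a binary fixed-point exponent
fn quantize_float_sketch(f: f64, scale: u32) -> IntegerRep {
    (f * (1u128 << scale) as f64).round() as IntegerRep
}

fn main() {
    // 0.5 at scale 7 becomes 64; -1.25 becomes -160
    assert_eq!(quantize_float_sketch(0.5, 7), 64);
    assert_eq!(quantize_float_sketch(-1.25, 7), -160);
}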
@@ -171,7 +171,7 @@ impl OnChainSource {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Inner elements of inputs/outputs coming from postgres DB
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
@@ -189,7 +189,7 @@ pub struct PostgresSource {
pub port: String,
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Create a new PostgresSource
pub fn new(
@@ -211,9 +211,7 @@ impl PostgresSource {
}
/// Fetch data from postgres
pub async fn fetch(
&self,
) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
@@ -247,9 +245,7 @@ impl PostgresSource {
}
/// Fetch data from postgres and format it as a FileSource
pub async fn fetch_and_format_as_file(
&self,
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
@@ -272,14 +268,14 @@ impl PostgresSource {
}
impl OnChainSource {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Create dummy local on-chain data to test the OnChain data source
pub async fn test_from_file_data(
data: &FileSource,
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
use crate::eth::{
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
};
@@ -363,7 +359,7 @@ pub enum DataSource {
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
OnChain(OnChainSource),
/// Postgres DB
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
@@ -423,7 +419,7 @@ impl<'de> Deserialize<'de> for DataSource {
if let Ok(t) = second_try {
return Ok(DataSource::OnChain(t));
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = third_try {
@@ -449,13 +445,13 @@ impl UnwindSafe for GraphData {}
impl GraphData {
// not wasm
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Convert the input data to tract data
pub fn to_tract_data(
&self,
shapes: &[Vec<usize>],
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
) -> Result<TVec<TValue>, GraphError> {
let mut inputs = TVec::new();
match &self.input_data {
DataSource::File(data) => {
@@ -470,10 +466,10 @@ impl GraphData {
}
}
_ => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
)))
))
}
}
Ok(inputs)
@@ -488,19 +484,26 @@ impl GraphData {
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let reader = std::fs::File::open(path)?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
reader.read_to_string(&mut buf).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, self)?;
Ok(())
}
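Both from_path and save wrap the file handle in a capacity-tuned buffer before handing it to serde_json, avoiding per-token syscalls on large witness files. A standalone sketch of the same pattern (BUF_CAPACITY is a stand-in for *EZKL_BUF_CAPACITY):

use std::fs::File;
use std::io::{BufReader, BufWriter};

const BUF_CAPACITY: usize = 8 * 1024;

fn save_json<T: serde::Serialize>(value: &T, path: &std::path::Path) -> std::io::Result<()> {
    let writer = BufWriter::with_capacity(BUF_CAPACITY, File::create(path)?);
    serde_json::to_writer(writer, value).map_err(std::io::Error::from)
}

fn load_json<T: serde::de::DeserializeOwned>(path: &std::path::Path) -> std::io::Result<T> {
    let reader = BufReader::with_capacity(BUF_CAPACITY, File::open(path)?);
    serde_json::from_reader(reader).map_err(std::io::Error::from)
}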
@@ -509,7 +512,7 @@ impl GraphData {
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
) -> Result<Vec<Self>, GraphError> {
// split input data into batches
let mut batched_inputs = vec![];
@@ -522,12 +525,12 @@ impl GraphData {
input_data: DataSource::OnChain(_),
output_data: _,
} => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
)))
))
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
@@ -539,11 +542,11 @@ impl GraphData {
let input_size = shape.clone().iter().product::<usize>();
let input = &iterable[i];
if input.len() % input_size != 0 {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
)));
));
}
let mut batches = vec![];
for batch in input.chunks(input_size) {
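The divisibility check above is what makes the chunking sound: every batch must be a whole multiple of the model's input size. A sketch of the shape logic with plain vectors:

fn split_into_batches_sketch(input: Vec<i64>, input_size: usize) -> Result<Vec<Vec<i64>>, String> {
    if input.len() % input_size != 0 {
        return Err("data length must be evenly divisible by input_size".to_string());
    }
    // e.g. 6 values with input_size 3 -> two batches of 3
    Ok(input.chunks(input_size).map(|c| c.to_vec()).collect())
}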


@@ -7,33 +7,38 @@ pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
pub mod vars;
#[cfg(not(target_arch = "wasm32"))]
/// errors for the graph
pub mod errors;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
use gag::Gag;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::CommitmentScheme;
pub use input::DataSource;
use itertools::Itertools;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
#[cfg(not(target_arch = "wasm32"))]
use self::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use self::input::OnChainSource;
use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
@@ -44,7 +49,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
@@ -58,7 +63,6 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;
@@ -66,15 +70,16 @@ pub use vars::*;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i64 = 2;
pub const RANGE_MULTIPLIER: IntegerRep = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: IntegerRep =
(MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
lazy_static! {
/// Max circuit area
pub static ref EZKL_MAX_CIRCUIT_AREA: Option<usize> =
@@ -85,65 +90,9 @@ lazy_static! {
};
}
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// This operation is unsupported
#[error("unsupported datatype in graph")]
UnsupportedDataType,
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Error when attempting to load a model
#[error("failed to load")]
ModelLoad,
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
}
///
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
@@ -179,11 +128,13 @@ pub struct GraphWitness {
/// Any hashes of outputs generated during the forward pass
pub processed_outputs: Option<ModuleForwardResult>,
/// max lookup input
pub max_lookup_inputs: i64,
pub max_lookup_inputs: IntegerRep,
/// min lookup input
pub min_lookup_inputs: i64,
pub min_lookup_inputs: IntegerRep,
/// max range check size
pub max_range_size: i64,
pub max_range_size: IntegerRep,
/// (optional) version of ezkl used
pub version: Option<String>,
}
impl GraphWitness {
@@ -212,6 +163,7 @@ impl GraphWitness {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
version: None,
}
}
@@ -310,30 +262,31 @@ impl GraphWitness {
}
/// Export the ezkl witness as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load {}", path.display()))?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let file = std::fs::File::open(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
@@ -435,7 +388,7 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Resu
#[cfg(feature = "python-bindings")]
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
use crate::python::PyG1Affine;
use crate::bindings::python::PyG1Affine;
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
.iter()
.map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect())
@@ -458,6 +411,8 @@ pub struct GraphSettings {
pub total_const_size: usize,
/// total dynamic column size
pub total_dynamic_col_size: usize,
/// max dynamic column input length
pub max_dynamic_input_len: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// number of shuffles
@@ -502,6 +457,18 @@ impl GraphSettings {
.ceil() as u32
}
/// Calc the number of rows required for the range checks
pub fn range_check_log_rows_with_blinding(&self) -> u32 {
let max_range = self
.required_range_checks
.iter()
.map(|x| x.1 - x.0)
.max()
.unwrap_or(0);
(max_range as f32).log2().ceil() as u32
}
fn model_constraint_logrows_with_blinding(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
@@ -523,6 +490,13 @@ impl GraphSettings {
.ceil() as u32
}
/// calculate the number of rows required for the dynamic lookup and shuffle
pub fn min_dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
(self.max_dynamic_input_len as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
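These helpers all compute the minimum k such that 2^k rows fit the workload plus blinding. A worked sketch (the blinding constant here is an assumption; the real one is derived from ASSUMED_BLINDING_FACTORS and the reserved padding in the crate):

const RESERVED_BLINDING_ROWS: f64 = 6.0; // assumed stand-in

fn min_logrows_with_blinding(num_rows: usize) -> u32 {
    (num_rows as f64 + RESERVED_BLINDING_ROWS).log2().ceil() as u32
}

fn main() {
    // 1000 rows + blinding needs 2^10 = 1024 rows
    assert_eq!(min_logrows_with_blinding(1000), 10);
}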
@@ -595,11 +569,11 @@ impl GraphSettings {
}
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
return Err(e.into());
}
};
Ok(serialized)
@@ -695,17 +669,21 @@ impl GraphCircuit {
&self.core.model
}
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
///
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let f = std::fs::File::open(path)?;
let f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
@@ -732,6 +710,7 @@ impl std::fmt::Display for TestDataSource {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for TestDataSource {}
impl From<String> for TestDataSource {
@@ -770,10 +749,7 @@ pub struct TestOnChainData {
impl GraphCircuit {
///
pub fn new(
model: Model,
run_args: &RunArgs,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
pub fn new(model: Model, run_args: &RunArgs) -> Result<GraphCircuit, GraphError> {
// // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -820,7 +796,7 @@ impl GraphCircuit {
model: Model,
mut settings: GraphSettings,
check_mode: CheckMode,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
) -> Result<GraphCircuit, GraphError> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -844,35 +820,37 @@ impl GraphCircuit {
}
/// load inputs and outputs for the model
pub fn load_graph_witness(
&mut self,
data: &GraphWitness,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn load_graph_witness(&mut self, data: &GraphWitness) -> Result<(), GraphError> {
self.graph_witness = data.clone();
// load the module settings
Ok(())
}
/// Prepare the public inputs for the circuit.
pub fn prepare_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Vec<Fp>, Box<dyn std::error::Error>> {
pub fn prepare_public_inputs(&self, data: &GraphWitness) -> Result<Vec<Fp>, GraphError> {
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
let mut public_inputs: Vec<Fp> = vec![];
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
} else if let Some(processed_inputs) = &data.processed_inputs {
// we first process the inputs
if let Some(processed_inputs) = &data.processed_inputs {
public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
}
// we then process the params
if let Some(processed_params) = &data.processed_params {
public_inputs.extend(processed_params.get_instances().into_iter().flatten());
}
// if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
}
// if the outputs are public, we add them to the public inputs
if self.settings().run_args.output_visibility.is_public() {
public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
// if the outputs are processed, we add the processed outputs to the public inputs
} else if let Some(processed_outputs) = &data.processed_outputs {
public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
}
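For orientation, the reordering above reduces to the following sketch (a simplified stand-in, not the real GraphWitness API; plain u64 vectors take the place of field elements and commitment types). Commitments to processed inputs come first, then processed params, then raw public inputs, and finally the (possibly processed) outputs, matching the order in which the instance columns are configured.
fn assemble_instances(
    processed_inputs: Option<Vec<u64>>,
    processed_params: Option<Vec<u64>>,
    public_inputs: Option<Vec<u64>>,
    public_outputs: Option<Vec<u64>>,
) -> Vec<u64> {
    let mut instances = Vec::new();
    // commitments to the inputs, if any, come first
    if let Some(v) = processed_inputs {
        instances.extend(v);
    }
    // then commitments to the params
    if let Some(v) = processed_params {
        instances.extend(v);
    }
    // raw public inputs follow the processed params
    if let Some(v) = public_inputs {
        instances.extend(v);
    }
    // outputs (or their processed form) always come last
    if let Some(v) = public_outputs {
        instances.extend(v);
    }
    instances
}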
@@ -890,7 +868,7 @@ impl GraphCircuit {
pub fn pretty_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Option<PrettyElements>, Box<dyn std::error::Error>> {
) -> Result<Option<PrettyElements>, GraphError> {
// dequantize the supplied data using the provided scale.
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
@@ -921,7 +899,7 @@ impl GraphCircuit {
public_inputs.processed_outputs = elements.processed_outputs.clone();
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
debug!(
"rescaled and processed public inputs: {}",
serde_json::to_string(&public_inputs)?.to_colored_json_auto()?
@@ -931,11 +909,8 @@ impl GraphCircuit {
}
///
#[cfg(target_arch = "wasm32")]
pub fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -946,7 +921,7 @@ impl GraphCircuit {
pub fn load_graph_from_file_exclusively(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -956,16 +931,16 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => Err("Cannot use non-file data source as input for this method.".into()),
_ => unreachable!("cannot load from on-chain data"),
}
}
///
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -975,7 +950,7 @@ impl GraphCircuit {
.await
}
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
&mut self,
@@ -983,18 +958,16 @@ impl GraphCircuit {
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::OnChain(_) => {
Err("Cannot use on-chain data source as input for this method.".into())
}
DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Process the data source for the model
async fn process_data_source(
&mut self,
@@ -1002,7 +975,7 @@ impl GraphCircuit {
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::OnChain(source) => {
let mut per_item_scale = vec![];
@@ -1024,13 +997,13 @@ impl GraphCircuit {
}
/// Prepare on chain test data
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn load_on_chain_data(
&mut self,
source: OnChainSource,
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
@@ -1054,7 +1027,7 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (((d, shape), scale), input_type) in file_data
@@ -1085,7 +1058,7 @@ impl GraphCircuit {
&mut self,
file_data: &[Vec<Fp>],
shapes: &[Vec<usize>],
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (d, shape) in file_data.iter().zip(shapes) {
@@ -1096,14 +1069,14 @@ impl GraphCircuit {
Ok(data)
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
)
}
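A worked sketch of the fractional safety margin above, assuming a signed-integer IntegerRep and Range = (IntegerRep, IntegerRep) as elsewhere in this diff; the point is that the bounds always round outward, so the widened range never undershoots.
type IntegerRep = i128; // stand-in width; the crate defines its own
type Range = (IntegerRep, IntegerRep);

fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
    (
        (lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
        (lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
    )
}

fn main() {
    // a margin of 1.5 widens (-101, 101) to (-152, 152):
    // floor(-151.5) = -152 and ceil(151.5) = 152, i.e. outward rounding.
    assert_eq!(calc_safe_lookup_range((-101, 101), 1.5), (-152, 152));
}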
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
@@ -1111,8 +1084,8 @@ impl GraphCircuit {
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i64,
) -> Result<u32, Box<dyn std::error::Error>> {
max_range_size: IntegerRep,
) -> Result<u32, GraphError> {
// pick whichever of safe_lookup_range and max_range_size has the larger absolute size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
@@ -1130,10 +1103,10 @@ impl GraphCircuit {
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: i64,
max_range_size: IntegerRep,
max_logrows: Option<u32>,
lookup_safety_margin: i64,
) -> Result<(), Box<dyn std::error::Error>> {
lookup_safety_margin: f64,
) -> Result<(), GraphError> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
@@ -1142,15 +1115,22 @@ impl GraphCircuit {
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if subtraction overflows
let lookup_size =
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
// check if the lookup input range overflows the maximum allowed
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));
}
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
return Err(GraphError::RangeCheckTooLarge(
max_range_size.unsigned_abs() as usize,
));
}
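The two guards above can be exercised in isolation; a hedged sketch follows (the MAX_LOOKUP_ABS value here is illustrative, not the crate's actual constant).
const MAX_LOOKUP_ABS: i128 = 1 << 40; // illustrative only

fn check_sizes(lookup_size: i128, max_range_size: i128, margin: f64) -> Result<(), String> {
    // a fractional margin shrinks the cap proportionally: margin 2.0 halves it
    if lookup_size > (MAX_LOOKUP_ABS as f64 / margin).floor() as i128 {
        return Err(format!("lookup range of size {} is too large", lookup_size));
    }
    if max_range_size.abs() > MAX_LOOKUP_ABS {
        return Err(format!("range check of size {} is too large", max_range_size));
    }
    Ok(())
}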
// These are hard lower limits; we can't overflow the instances or module constraints
@@ -1194,12 +1174,7 @@ impl GraphCircuit {
}
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
return Err(GraphError::ExtendedKTooLarge(max_logrows));
}
let logrows = max_logrows;
@@ -1226,7 +1201,7 @@ impl GraphCircuit {
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i64,
max_range_size: IntegerRep,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1241,12 +1216,12 @@ impl GraphCircuit {
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
// on unix, gag stdout/stderr while configuring to suppress noisy output
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
@@ -1255,9 +1230,9 @@ impl GraphCircuit {
Self::configure_with_params(&mut cs, settings);
// drop the gag
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
drop(_r);
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
drop(_g);
#[cfg(feature = "mv-lookup")]
@@ -1284,9 +1259,8 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
witness_gen: bool,
check_lookup: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
region_settings: RegionSettings,
) -> Result<GraphWitness, GraphError> {
let original_inputs = inputs.to_vec();
let visibility = VarVisibility::from_args(&self.settings().run_args)?;
@@ -1334,7 +1308,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
.forward(inputs, &self.settings().run_args, region_settings)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1379,6 +1353,7 @@ impl GraphCircuit {
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_size: model_results.max_range_size,
version: Some(crate::version().to_string()),
};
witness.generate_rescaled_elements(
@@ -1387,7 +1362,7 @@ impl GraphCircuit {
visibility,
);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
log::trace!(
"witness: \n {}",
&witness.as_json()?.to_colored_json_auto()?
@@ -1397,34 +1372,37 @@ impl GraphCircuit {
}
/// Create a new circuit from a set of input data and [RunArgs].
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn from_run_args(
run_args: &RunArgs,
model_path: &std::path::Path,
) -> Result<Self, Box<dyn std::error::Error>> {
) -> Result<Self, GraphError> {
let model = Model::from_run_args(run_args, model_path)?;
Self::new(model, run_args)
}
/// Create a new circuit from a set of input data and [GraphSettings].
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn from_settings(
params: &GraphSettings,
model_path: &std::path::Path,
check_mode: CheckMode,
) -> Result<Self, Box<dyn std::error::Error>> {
params.run_args.validate()?;
) -> Result<Self, GraphError> {
params
.run_args
.validate()
.map_err(GraphError::InvalidRunArgs)?;
let model = Model::from_run_args(&params.run_args, model_path)?;
Self::new_from_settings(model, params.clone(), check_mode)
}
///
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn populate_on_chain_test_data(
&mut self,
data: &mut GraphData,
test_on_chain_data: TestOnChainData,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), GraphError> {
// Set up local anvil instance for reading on-chain data
let input_scales = self.model().graph.get_input_scales();
@@ -1438,15 +1416,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.input_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
_ => {
return Err("Cannot use non file source as input for on-chain test.
Manually populate on-chain data from file source instead"
.into())
return Err(GraphError::OnChainDataSource);
}
};
// Get the flatten length of input_data
@@ -1467,19 +1443,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.output_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => {
return Err(
"Cannot use on-chain data source as output for on-chain test.
Will manually populate on-chain data from file source instead"
.into(),
)
}
_ => return Err("No output data found".into()),
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
@@ -1520,14 +1490,12 @@ impl CircuitSize {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}
@@ -1610,7 +1578,7 @@ impl Circuit<Fp> for GraphCircuit {
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
debug!(
"circuit size: \n {}",
circuit_size


@@ -1,16 +1,18 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::region::RegionSettings;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::IntegerRep;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -19,9 +21,9 @@ use crate::{
};
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::input::GraphData;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
circuit::{Layouter, Value},
@@ -34,30 +36,29 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::io::Read;
use std::path::PathBuf;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Table;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::{
Framework, Graph, InferenceFact, InferenceModelExt, SymbolValues, TypedFact, TypedOp,
};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::internal::DatumType;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::ops::scan::Scan;
use unzip_n::unzip_n;
unzip_n!(pub 3);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
/// The result of a forward pass.
#[derive(Clone, Debug)]
@@ -65,11 +66,11 @@ pub struct ForwardResult {
/// The outputs of the forward pass.
pub outputs: Vec<Tensor<Fp>>,
/// The maximum value of any input to a lookup operation.
pub max_lookup_inputs: i64,
pub max_lookup_inputs: IntegerRep,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i64,
pub min_lookup_inputs: IntegerRep,
/// The max range check size
pub max_range_size: i64,
pub max_range_size: IntegerRep,
}
impl From<DummyPassRes> for ForwardResult {
@@ -102,6 +103,8 @@ pub struct DummyPassRes {
pub num_rows: usize,
/// num dynamic lookups
pub num_dynamic_lookups: usize,
/// max dynamic lookup input len
pub max_dynamic_input_len: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// num shuffles
@@ -117,11 +120,11 @@ pub struct DummyPassRes {
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i64,
pub max_lookup_inputs: IntegerRep,
/// min lookup inputs
pub min_lookup_inputs: i64,
pub min_lookup_inputs: IntegerRep,
/// min range check
pub max_range_size: i64,
pub max_range_size: IntegerRep,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -359,6 +362,14 @@ impl NodeType {
NodeType::SubGraph { .. } => SupportedOp::Unknown(Unknown),
}
}
/// check if it is a softmax
pub fn is_softmax(&self) -> bool {
match self {
NodeType::Node(n) => n.is_softmax(),
NodeType::SubGraph { .. } => false,
}
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
@@ -396,7 +407,7 @@ impl ParsedNodes {
}
/// Returns shapes of the computational graph's inputs
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
let mut inputs = vec![];
for input in self.inputs.iter() {
@@ -469,8 +480,8 @@ impl Model {
/// # Arguments
/// * `reader` - A reader for an Onnx file.
/// * `run_args` - [RunArgs]
#[cfg(not(target_arch = "wasm32"))]
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, Box<dyn Error>> {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, GraphError> {
let visibility = VarVisibility::from_args(run_args)?;
let graph = Self::load_onnx_model(reader, run_args, &visibility)?;
@@ -483,20 +494,28 @@ impl Model {
}
///
pub fn save(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
let f = std::fs::File::create(path)?;
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let writer = std::io::BufWriter::new(f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
///
pub fn load(path: PathBuf) -> Result<Self, Box<dyn Error>> {
pub fn load(path: PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let mut f = std::fs::File::open(&path)?;
let metadata = fs::metadata(&path)?;
let mut f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let metadata = fs::metadata(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
f.read_exact(&mut buffer).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let result = bincode::deserialize(&buffer)?;
Ok(result)
}
@@ -506,9 +525,9 @@ impl Model {
&self,
run_args: &RunArgs,
check_mode: CheckMode,
) -> Result<GraphSettings, Box<dyn Error>> {
) -> Result<GraphSettings, GraphError> {
let instance_shapes = self.instance_shapes()?;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
debug!(
"{} {} {}",
"model has".blue(),
@@ -536,9 +555,13 @@ impl Model {
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let res = self.dummy_layout(run_args, &inputs, false, false)?;
let res = self.dummy_layout(
run_args,
&inputs,
RegionSettings::all_false(run_args.decomp_base, run_args.decomp_legs),
)?;
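RegionSettings replaces the former witness_gen/check_lookup boolean pair and also carries the decomposition parameters. Only the all_false constructor appears in this diff, so the shape below is an assumption (field names are hypothetical):
struct RegionSettings {
    witness_gen: bool,   // hypothetical field
    check_lookup: bool,  // hypothetical field
    decomp_base: usize,
    decomp_legs: usize,
}

impl RegionSettings {
    // builds a dummy-layout configuration with witness generation and
    // lookup checking both disabled
    fn all_false(decomp_base: usize, decomp_legs: usize) -> Self {
        Self {
            witness_gen: false,
            check_lookup: false,
            decomp_base,
            decomp_legs,
        }
    }
}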
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -549,6 +572,7 @@ impl Model {
num_rows: res.num_rows,
total_assignments: res.linear_coord,
required_lookups: res.lookup_ops.into_iter().collect(),
max_dynamic_input_len: res.max_dynamic_input_len,
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
@@ -561,13 +585,13 @@ impl Model {
version: env!("CARGO_PKG_VERSION").to_string(),
num_blinding_factors: None,
// unix time timestamp
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
timestamp: Some(
instant::SystemTime::now()
.duration_since(instant::SystemTime::UNIX_EPOCH)?
.as_millis(),
),
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
timestamp: None,
})
}
@@ -581,14 +605,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
witness_gen: bool,
check_lookup: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
region_settings: RegionSettings,
) -> Result<ForwardResult, GraphError> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, region_settings)?;
Ok(res.into())
}
@@ -597,19 +620,14 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `scale` - The scale to use for quantization.
/// * `public_params` - Whether to make the params public.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<TractResult, Box<dyn Error>> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
) -> Result<TractResult, GraphError> {
use tract_onnx::tract_hir::internal::GenericFactoid;
let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
error!("Error loading model: {}", e);
GraphError::ModelLoad
})?;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(run_args.variables.clone());
@@ -622,7 +640,7 @@ impl Model {
if matches!(x, GenericFactoid::Any) {
let batch_size = match variables.get("batch_size") {
Some(x) => x,
None => return Err("Unknown dimension batch_size in model inputs, set batch_size in variables".into()),
None => return Err(GraphError::MissingBatchSize),
};
fact.shape
.set_dim(i, tract_onnx::prelude::TDim::Val(*batch_size as i64));
@@ -644,29 +662,11 @@ impl Model {
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
let mut typed_model = model
let typed_model = model
.into_typed()?
.concretize_dims(&symbol_values)?
.into_decluttered()?;
// concretize constants
for node in typed_model.eval_order()? {
let node = typed_model.node_mut(node);
if let Some(op) = node.op_as_mut::<tract_onnx::tract_core::ops::konst::Const>() {
if op.0.datum_type() == DatumType::TDim {
// get inner value to Arc<Tensor>
let mut constant = op.0.as_ref().clone();
// Generally a shape or hyperparam
constant
.as_slice_mut::<tract_onnx::prelude::TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&symbol_values));
op.0 = constant.into_arc_tensor();
}
}
}
Ok((typed_model, symbol_values))
}
@@ -675,17 +675,17 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `scale` - The scale to use for quantization.
/// * `public_params` - Whether to make the params public.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn load_onnx_model(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
visibility: &VarVisibility,
) -> Result<ParsedNodes, Box<dyn Error>> {
) -> Result<ParsedNodes, GraphError> {
let start_time = instant::Instant::now();
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
let scales = VarScales::from_args(run_args)?;
let scales = VarScales::from_args(run_args);
let nodes = Self::nodes_from_graph(
&model,
run_args,
@@ -711,7 +711,7 @@ impl Model {
}
/// Formats nodes (including subgraphs) into tables !
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn table_nodes(&self) -> String {
let mut node_accumulator = vec![];
let mut string = String::new();
@@ -753,7 +753,7 @@ impl Model {
/// * `visibility` - Which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility].
/// * `input_scales` - The scales of the model's inputs.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn nodes_from_graph(
graph: &Graph<TypedFact, Box<dyn TypedOp>>,
run_args: &RunArgs,
@@ -762,7 +762,7 @@ impl Model {
symbol_values: &SymbolValues,
override_input_scales: Option<Vec<crate::Scale>>,
override_output_scales: Option<HashMap<usize, crate::Scale>>,
) -> Result<BTreeMap<usize, NodeType>, Box<dyn Error>> {
) -> Result<BTreeMap<usize, NodeType>, GraphError> {
use crate::graph::node_output_shapes;
let mut nodes = BTreeMap::<usize, NodeType>::new();
@@ -898,16 +898,8 @@ impl Model {
);
}
None => {
let mut n = Node::new(
n.clone(),
&mut nodes,
scales,
&run_args.param_visibility,
i,
symbol_values,
run_args.div_rebasing,
run_args.rebase_frac_zero_constants,
)?;
let mut n =
Node::new(n.clone(), &mut nodes, scales, i, symbol_values, run_args)?;
if let Some(ref scales) = override_input_scales {
if let Some(inp) = n.opkind.get_input() {
let scale = scales[input_idx];
@@ -923,20 +915,9 @@ impl Model {
if scales.contains_key(&i) {
let scale_diff = n.out_scale - scales[&i];
n.opkind = if scale_diff > 0 {
RebaseScale::rebase(
n.opkind,
scales[&i],
n.out_scale,
1,
run_args.div_rebasing,
)
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
} else {
RebaseScale::rebase_up(
n.opkind,
scales[&i],
n.out_scale,
run_args.div_rebasing,
)
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
};
n.out_scale = scales[&i];
}
@@ -950,7 +931,7 @@ impl Model {
Ok(nodes)
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Removes all nodes that are consts with 0 uses
fn remove_unused_nodes(nodes: &mut BTreeMap<usize, NodeType>) {
// remove all nodes that are consts with 0 uses now
@@ -969,21 +950,21 @@ impl Model {
});
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Run tract onnx model on sample data !
pub fn run_onnx_predictions(
run_args: &RunArgs,
model_path: &std::path::Path,
data_chunks: &[GraphData],
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
use tract_onnx::tract_core::internal::IntoArcTensor;
let (model, _) = Model::load_onnx_using_tract(
&mut std::fs::File::open(model_path)
.map_err(|_| format!("failed to load {}", model_path.display()))?,
run_args,
)?;
let mut file = std::fs::File::open(model_path).map_err(|e| {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?
@@ -1010,16 +991,12 @@ impl Model {
/// Creates a `Model` from parsed run_args
/// # Arguments
/// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
#[cfg(not(target_arch = "wasm32"))]
pub fn from_run_args(
run_args: &RunArgs,
model: &std::path::Path,
) -> Result<Self, Box<dyn Error>> {
Model::new(
&mut std::fs::File::open(model)
.map_err(|_| format!("failed to load {}", model.display()))?,
run_args,
)
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
let mut file = std::fs::File::open(model).map_err(|e| {
GraphError::ReadWriteFileError(model.display().to_string(), e.to_string())
})?;
Model::new(&mut file, run_args)
}
/// Configures a model for the circuit
@@ -1031,7 +1008,7 @@ impl Model {
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
) -> Result<PolyConfig<Fp>, GraphError> {
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
@@ -1093,7 +1070,7 @@ impl Model {
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
info!("model layout...");
let start_time = instant::Instant::now();
@@ -1103,7 +1080,11 @@ impl Model {
let input_shapes = self.graph.input_shapes()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
if self.visibility.input.is_public() {
let instance = vars.instance.as_ref().ok_or("no instance")?.clone();
let instance = vars
.instance
.as_ref()
.ok_or(GraphError::MissingInstances)?
.clone();
results.insert(*input_idx, vec![instance]);
vars.increment_instance_idx();
} else {
@@ -1123,7 +1104,14 @@ impl Model {
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
let mut thread_safe_region = RegionCtx::new_with_constants(
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
);
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1147,38 +1135,44 @@ impl Model {
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars.instance.as_ref().ok_or("no instance")?.clone();
let res = vars
.instance
.as_ref()
.ok_or(GraphError::MissingInstances)?
.clone();
vars.increment_instance_idx();
res
} else {
// if witnessed_outputs is of len less than i error
// error if fewer witnessed outputs were provided than the model produces
return Err("you provided insufficient witness values to generate a fixed output".into());
return Err(GraphError::InsufficientWitnessValues);
}
witnessed_outputs[i].clone()
};
config.base.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
)
config
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
)
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>,_>>();
.collect::<Result<Vec<_>, GraphError>>();
res.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
}
// report the number of columns used in the circuit
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
let duration = start_time.elapsed();
@@ -1192,7 +1186,7 @@ impl Model {
config: &mut ModelConfig,
region: &mut RegionCtx<Fp>,
results: &mut BTreeMap<usize, Vec<ValTensor<Fp>>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
// index over results to get original inputs
let orig_inputs: BTreeMap<usize, _> = results
.clone()
@@ -1203,7 +1197,7 @@ impl Model {
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// report the number of columns used in the circuit
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
@@ -1237,7 +1231,10 @@ impl Model {
let res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node.opkind.get_mutable_constant().ok_or("no constant")?;
let c = node
.opkind
.get_mutable_constant()
.ok_or(GraphError::MissingConstants)?;
Some(c.quantized_values.clone().try_into()?)
} else {
config
@@ -1392,9 +1389,8 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
witness_gen: bool,
check_lookup: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
region_settings: RegionSettings,
) -> Result<DummyPassRes, GraphError> {
debug!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
@@ -1412,8 +1408,7 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region =
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, region_settings);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
@@ -1456,7 +1451,7 @@ impl Model {
trace!("dummy model layout took: {:?}", duration);
// report the number of columns used in the circuit
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
let outputs = outputs
@@ -1470,6 +1465,7 @@ impl Model {
let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
max_dynamic_input_len: region.max_dynamic_input_len(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
@@ -1549,7 +1545,7 @@ impl Model {
}
/// Shapes of the computational graph's public inputs (if any)
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
let mut instance_shapes = vec![];
if self.visibility.input.is_public() {
instance_shapes.extend(self.graph.input_shapes()?);


@@ -11,6 +11,7 @@ use halo2curves::bn256::{Fr as Fp, G1Affine};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
@@ -295,7 +296,7 @@ impl GraphModules {
element_visibility: &Visibility,
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
) -> Result<ModuleForwardResult, GraphError> {
let mut poseidon_hash = None;
let mut polycommit = None;


@@ -1,39 +1,41 @@
use super::scale_to_multiplier;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
use serde::Deserialize;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;
use std::error::Error;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
self,
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
if !v.is_empty() {
format!("{:?}", v)
@@ -42,7 +44,7 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
v.as_string()
}
@@ -65,7 +67,7 @@ impl Op<Fp> for Rescaled {
format!("RESCALED INPUT ({})", self.inner.as_string())
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
.zip(self.scale.iter())
@@ -80,11 +82,9 @@ impl Op<Fp> for Rescaled {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(Box::new(TensorError::DimMismatch(
"rescaled inputs".to_string(),
)));
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
@@ -120,11 +120,11 @@ impl RebaseScale {
global_scale: crate::Scale,
op_out_scale: crate::Scale,
scale_rebase_multiplier: u32,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
&& !inner.is_input()
&& !inner.is_identity()
{
let multiplier =
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
@@ -136,7 +136,6 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
})
@@ -147,7 +146,6 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
})
@@ -162,7 +160,6 @@ impl RebaseScale {
inner: SupportedOp,
target_scale: crate::Scale,
op_out_scale: crate::Scale,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
@@ -175,7 +172,6 @@ impl RebaseScale {
original_scale: op.original_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
@@ -186,7 +182,6 @@ impl RebaseScale {
original_scale: op_out_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
@@ -210,7 +205,7 @@ impl Op<Fp> for RebaseScale {
)
}
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}
@@ -219,11 +214,11 @@ impl Op<Fp> for RebaseScale {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or("no inner layout")?;
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
}
@@ -302,11 +297,11 @@ impl SupportedOp {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn homogenous_rescale(
&self,
in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let inputs_to_scale = self.requires_homogenous_input_scales();
// creates a rescaled op if the inputs are not homogenous
let op = self.clone_dyn();
@@ -326,6 +321,19 @@ impl SupportedOp {
SupportedOp::RebaseScale(op) => op,
}
}
/// check if is the identity operation
/// # Returns
/// * `true` if the operation is the identity operation
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }),
SupportedOp::Rescaled(op) => op.inner.is_identity(),
SupportedOp::RebaseScale(op) => op.inner.is_identity(),
_ => false,
}
}
}
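The new predicate unwraps the Rescaled and RebaseScale wrappers before testing, and it feeds the rebase guard earlier in this diff, so a pure pass-through op never picks up a needless division. A toy sketch of the same recursion, with a hypothetical enum standing in for SupportedOp:
enum ToyOp {
    Identity,
    Add,
    Rescaled(Box<ToyOp>),
}

impl ToyOp {
    fn is_identity(&self) -> bool {
        match self {
            ToyOp::Identity => true,
            // unwrap the wrapper and ask the inner op, as SupportedOp does
            ToyOp::Rescaled(inner) => inner.is_identity(),
            _ => false,
        }
    }
}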
impl From<Box<dyn Op<Fp>>> for SupportedOp {
@@ -372,7 +380,7 @@ impl Op<Fp> for SupportedOp {
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}
@@ -400,7 +408,7 @@ impl Op<Fp> for SupportedOp {
self
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
self.as_op().out_scale(in_scales)
}
}
@@ -427,7 +435,7 @@ pub struct Node {
pub num_uses: usize,
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl Tabled for Node {
const LENGTH: usize = 6;
@@ -467,18 +475,16 @@ impl Node {
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[allow(clippy::too_many_arguments)]
pub fn new(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,
scales: &VarScales,
param_visibility: &Visibility,
idx: usize,
symbol_values: &SymbolValues,
div_rebasing: bool,
rebase_frac_zero_constants: bool,
) -> Result<Self, Box<dyn Error>> {
run_args: &crate::RunArgs,
) -> Result<Self, GraphError> {
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -504,19 +510,23 @@ impl Node {
input_ids
.iter()
.map(|(i, _)| {
inputs.push(other_nodes.get(i).ok_or("input not found")?.clone());
inputs.push(
other_nodes
.get(i)
.ok_or(GraphError::MissingInput(idx))?
.clone(),
);
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let (mut opkind, deleted_indices) = new_op_from_onnx(
idx,
scales,
param_visibility,
node.clone(),
&mut inputs,
symbol_values,
rebase_frac_zero_constants,
run_args,
)?; // parses the op name
// we can only take the inputs as mutable once -- so we need to collect them first
@@ -544,10 +554,10 @@ impl Node {
let idx = inputs
.iter()
.position(|x| *idx == x.idx())
.ok_or("input not found")?;
.ok_or(GraphError::MissingInput(*idx))?;
Ok(inputs[idx].out_scales()[*outlet])
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.collect::<Result<Vec<_>, GraphError>>()?;
let homogenous_inputs = opkind.requires_homogenous_input_scales();
// automatically increases a constant's scale if it is only used once and
@@ -558,13 +568,13 @@ impl Node {
if inputs.len() > input {
let input_node = other_nodes
.get_mut(&inputs[input].idx())
.ok_or("input not found")?;
.ok_or(GraphError::MissingInput(idx))?;
let input_opkind = &mut input_node.opkind();
if let Some(constant) = input_opkind.get_mutable_constant() {
rescale_const_with_single_use(
constant,
in_scales.clone(),
param_visibility,
&run_args.param_visibility,
input_node.num_uses(),
)?;
input_node.replace_opkind(constant.clone_dyn().into());
@@ -579,13 +589,7 @@ impl Node {
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points; we select the largest scale (highest precision)
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(
opkind,
global_scale,
out_scale,
scales.rebase_multiplier,
div_rebasing,
);
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
out_scale = opkind.out_scale(in_scales)?;
@@ -607,18 +611,27 @@ impl Node {
num_uses,
})
}
/// check if it is a softmax node
pub fn is_softmax(&self) -> bool {
    matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
constant: &mut Constant<Fp>,
in_scales: Vec<crate::Scale>,
param_visibility: &Visibility,
num_uses: usize,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), GraphError> {
if num_uses == 1 {
let current_scale = constant.out_scale(vec![])?;
let scale_max = in_scales.iter().max().ok_or("no scales")?;
let scale_max = in_scales.iter().max().ok_or(GraphError::MissingScale)?;
if scale_max > &current_scale {
let raw_values = constant.raw_values.clone();
constant.quantized_values =


@@ -1,7 +1,7 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
@@ -150,7 +150,7 @@ impl Config {
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(unix)]
#[cfg(all(not(not(feature = "ezkl")), unix))]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,

File diff suppressed because it is too large


@@ -1,4 +1,3 @@
use std::error::Error;
use std::fmt::Display;
use crate::tensor::TensorType;
@@ -15,8 +14,11 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use self::errors::GraphError;
use super::*;
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
@@ -63,6 +65,7 @@ impl Display for Visibility {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
@@ -261,12 +264,12 @@ impl VarScales {
}
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
Ok(Self {
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
params: args.param_scale,
rebase_multiplier: args.scale_rebase_multiplier,
})
}
}
}
@@ -303,15 +306,13 @@ impl Default for VarVisibility {
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
let output_vis = &args.output_visibility;
if params_vis.is_public() {
return Err(
"public visibility for params is deprecated, please use `fixed` instead".into(),
);
return Err(GraphError::ParamsPublicVisibility);
}
if !output_vis.is_public()
@@ -327,7 +328,7 @@ impl VarVisibility {
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(Box::new(GraphError::Visibility));
return Err(GraphError::Visibility);
}
Ok(Self {
input: input_vis.clone(),
@@ -442,7 +443,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup or shuffle should only have one block");
warn!("dynamic lookup has {} blocks", dynamic_lookup.num_blocks());
};
advices.push(dynamic_lookup);
}


@@ -23,67 +23,156 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(buf_read_has_data_left)]
#![feature(stmt_expr_attributes)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum EZKLError {
#[error("[aggregation] {0}")]
AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[eth] {0}")]
EthError(#[from] eth::EthError),
#[error("[graph] {0}")]
GraphError(#[from] graph::errors::GraphError),
#[error("[pfsys] {0}")]
PfsysError(#[from] pfsys::errors::PfsysError),
#[error("[circuit] {0}")]
CircuitError(#[from] circuit::errors::CircuitError),
#[error("[tensor] {0}")]
TensorError(#[from] tensor::errors::TensorError),
#[error("[module] {0}")]
ModuleError(#[from] circuit::modules::errors::ModuleError),
#[error("[io] {0}")]
IoError(#[from] std::io::Error),
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
#[error("[utf8] {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[reqwest] {0}")]
ReqwestError(#[from] reqwest::Error),
#[error("[fmt] {0}")]
FmtError(#[from] std::fmt::Error),
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
#[error("[Uncategorized] {0}")]
UncategorizedError(String),
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[execute] {0}")]
ExecutionError(#[from] execute::ExecutionError),
#[error("[srs] {0}")]
SrsError(#[from] pfsys::srs::SrsError),
}
impl From<&str> for EZKLError {
fn from(s: &str) -> Self {
EZKLError::UncategorizedError(s.to_string())
}
}
impl From<String> for EZKLError {
fn from(s: String) -> Self {
EZKLError::UncategorizedError(s)
}
}
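A usage sketch for the conversions above: ad-hoc error text promotes into the unified error type via .into(), or implicitly through the ? operator.
fn check_feature(enabled: bool) -> Result<(), EZKLError> {
    if !enabled {
        // lands in EZKLError::UncategorizedError via From<&str>
        return Err("feature must be enabled".into());
    }
    Ok(())
}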
use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::Visibility;
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
use halo2curves::bn256::{Bn256, G1Affine};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
/// The version of the ezkl library
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Get the version of the library
pub fn version() -> &'static str {
match VERSION {
"0.0.0" => "source - no compatibility guaranteed",
_ => VERSION,
}
}
/// Bindings management
#[cfg(any(
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown"),
feature = "python-bindings"
))]
pub mod bindings;
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
pub mod circuit;
/// CLI commands.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod commands;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
// abigen doesn't generate docs for this module
#[allow(missing_docs)]
/// Utility functions for contracts
pub mod eth;
/// Command execution
///
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod execute;
/// Utilities for converting from Halo2 Field types to integers (and vice-versa).
pub mod fieldutils;
/// Methods for loading onnx format models and automatically laying them out in
/// a Halo2 circuit.
#[cfg(feature = "onnx")]
#[cfg(any(feature = "onnx", not(feature = "ezkl")))]
pub mod graph;
/// beautiful logging
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
pub mod logger;
/// Tools for proofs and verification used by cli
pub mod pfsys;
/// Python bindings
#[cfg(feature = "python-bindings")]
pub mod python;
/// srs sha hashes
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;
/// wasm prover and verifier
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
/// The denominator in the fixed point representation used when quantizing inputs
pub type Scale = i32;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
// Buf writer capacity
lazy_static! {
/// The capacity of the buffer used for writing to disk
@@ -95,12 +184,13 @@ lazy_static! {
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
const EZKL_BUF_CAPACITY: &usize = &8000;
#[derive(
@@ -153,6 +243,7 @@ impl std::fmt::Display for Commitments {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Commitments {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
@@ -175,58 +266,79 @@ impl From<String> for Commitments {
}
/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
#[arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// The denominator in the fixed point representation used when quantizing parameters
#[arg(long, default_value = "7", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this is a more advanced parameter; use with caution)
#[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// The number of inner columns
#[arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, e.g. batch_size->1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, fixed, hashed, polycommit
#[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, fixed, hashed, polycommit
#[arg(long, default_value = "public", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Flags whether params are fixed, private, hashed, polycommit
#[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
#[arg(long, default_value = "false")]
/// Rebase the scale using a lookup table for division instead of a range check
pub div_rebasing: bool,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Should constants with 0.0 fraction be rebased to scale 0
#[arg(long, default_value = "false")]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
#[arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// Commitment scheme
#[arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// The base used for decompositions
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// The number of legs used for decompositions
pub decomp_legs: usize,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Use a bounded lookup argument for the log op
pub bounded_log_lookup: bool,
}
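
A Scale of s throughout RunArgs means real values are carried as integers scaled by 2^s, i.e. 2^s is the fixed-point denominator the doc comments refer to; with the default input_scale of 7, 0.5 is encoded as 64. A hedged sketch of that mapping, illustrative only (ezkl's actual quantizer also has to respect lookup_range and its own rounding rules):

// Illustrative fixed-point encode/decode for a given Scale; not ezkl's code.
fn quantize(x: f64, scale: i32) -> i64 {
    (x * f64::powi(2.0, scale)).round() as i64
}

fn dequantize(q: i64, scale: i32) -> f64 {
    q as f64 / f64::powi(2.0, scale)
}

fn main() {
    // With the default input_scale of 7, the multiplier is 2^7 = 128.
    assert_eq!(quantize(0.5, 7), 64);
    assert_eq!(dequantize(64, 7), 0.5);
}
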
impl Default for RunArgs {
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
@@ -238,17 +350,18 @@ impl Default for RunArgs {
input_visibility: Visibility::Private,
output_visibility: Visibility::Public,
param_visibility: Visibility::Private,
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
}
}
}
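
With the defaults above, a witness value is decomposed into decomp_legs = 2 limbs in base decomp_base = 16384 plus a sign, so representable magnitudes run up to 16384^2 - 1 = 2^28 - 1. A hedged sketch of such a decomposition (illustrative of the arithmetic only, not ezkl's circuit layout):

// Split |x| into `legs` big-endian limbs in `base`, keeping the sign aside.
fn decompose(x: i64, base: u64, legs: u32) -> (i64, Vec<u64>) {
    let sign = x.signum();
    let mut mag = x.unsigned_abs();
    let mut limbs = vec![0u64; legs as usize];
    for i in (0..legs as usize).rev() {
        limbs[i] = mag % base;
        mag /= base;
    }
    assert_eq!(mag, 0, "value out of range for base^legs");
    (sign, limbs)
}

fn main() {
    let (sign, limbs) = decompose(-100_000, 16384, 2);
    assert_eq!(sign, -1);
    assert_eq!(limbs, vec![6, 1696]); // 6 * 16384 + 1696 = 100_000
}
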
impl RunArgs {
/// Validate the run args, returning a descriptive error if any setting is invalid
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
pub fn validate(&self) -> Result<(), String> {
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
@@ -290,6 +403,7 @@ impl RunArgs {
}
/// Parse a single key-value pair
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
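
The hunk cuts off before the body of parse_key_val, but the "batch_size->1" and "-32768->32768" defaults above imply it splits on a "->" delimiter and parses both halves. A plausible reconstruction under that assumption (the where-clause bounds are guessed from the signature shown):

use std::error::Error;
use std::str::FromStr;

// Assumed reconstruction: parse "key->value" into a typed pair.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let pos = s
        .find("->")
        .ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
    Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}

fn main() {
    let (key, val): (String, usize) = parse_key_val("batch_size->1").unwrap();
    assert_eq!((key.as_str(), val), ("batch_size", 1));
}
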


@@ -76,7 +76,7 @@ pub fn init_logger() {
prefix_token(&record.level()),
// pretty print UTC time
chrono::Utc::now()
.format("%Y-%m-%d %H:%M:%S")
.format("%Y-%m-%d %H:%M:%S:%3f")
.to_string()
.bright_magenta(),
record.metadata().target(),
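
The only change in this logger hunk is the timestamp format: chrono's %3f specifier appends exactly three subsecond digits (milliseconds, with no leading dot), so log lines gain millisecond precision. A small sketch of the before/after output, assuming the chrono crate the hunk already uses:

fn main() {
    let now = chrono::Utc::now();
    // Old format: whole seconds only, e.g. "2024-11-10 15:56:38".
    println!("{}", now.format("%Y-%m-%d %H:%M:%S"));
    // New format: three fixed subsecond digits, e.g. "2024-11-10 15:56:38:123".
    println!("{}", now.format("%Y-%m-%d %H:%M:%S:%3f"));
}
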

src/pfsys/errors.rs (new file)

@@ -0,0 +1,27 @@
use thiserror::Error;
/// Error type for the pfsys module
#[derive(Error, Debug)]
pub enum PfsysError {
/// Failed to save the proof
#[error("failed to save the proof: {0}")]
SaveProof(String),
/// Failed to load the proof
#[error("failed to load the proof: {0}")]
LoadProof(String),
/// Halo2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
/// Failed to write point to transcript
#[error("failed to write point to transcript: {0}")]
WritePoint(String),
/// Invalid commitment scheme
#[error("invalid commitment scheme")]
InvalidCommitmentScheme,
/// Failed to load vk from file
#[error("failed to load vk from file: {0}")]
LoadVk(String),
/// Failed to load pk from file
#[error("failed to load pk from file: {0}")]
LoadPk(String),
}
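
Because Halo2Error is marked #[from], any halo2_proofs::plonk::Error converts into PfsysError automatically via the ? operator, while the stringly-typed variants carry ad-hoc context. A self-contained sketch of the same two patterns, with std::io::Error standing in for the halo2 error type and DemoError/save_proof as hypothetical names:

use thiserror::Error;

#[derive(Error, Debug)]
enum DemoError {
    // Mirrors PfsysError::SaveProof: a variant carrying free-form context.
    #[error("failed to save the proof: {0}")]
    SaveProof(String),
    // Mirrors the #[from] conversion on Halo2Error.
    #[error("[io] {0}")]
    Io(#[from] std::io::Error),
}

fn save_proof(path: &str, bytes: &[u8]) -> Result<(), DemoError> {
    if bytes.is_empty() {
        return Err(DemoError::SaveProof("refusing to write an empty proof".to_string()));
    }
    // `?` converts std::io::Error into DemoError via the generated From impl.
    std::fs::write(path, bytes)?;
    Ok(())
}

fn main() {
    if let Err(e) = save_proof("/nonexistent/dir/proof.bin", b"proof") {
        eprintln!("{e}"); // Display text comes from the #[error(...)] strings.
    }
}
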


@@ -1,7 +1,7 @@
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
@@ -20,7 +20,7 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
@@ -200,7 +200,7 @@ impl AggregationConfig {
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);


@@ -10,17 +10,14 @@ pub enum EvmVerificationError {
#[error("Solidity verifier found the proof invalid")]
InvalidProof,
/// If the Solidity verifier threw an error (e.g. OutOfGas)
#[error("Execution of Solidity code failed")]
SolidityExecution,
/// EVM execution errors
#[error("EVM execution of raw code failed")]
RawExecution,
#[error("Execution of Solidity code failed: {0}")]
SolidityExecution(String),
/// EVM verify errors
#[error("evm verification reverted")]
Reverted,
#[error("evm verification reverted: {0}")]
Reverted(String),
/// EVM deployment errors
#[error("evm deployment failed")]
Deploy,
#[error("evm deployment failed: {0}")]
DeploymentFailed(String),
/// Invalid Visibility
#[error("Invalid visibility")]
InvalidVisibility,

Some files were not shown because too many files have changed in this diff.