Compare commits

...

41 Commits

Author SHA1 Message Date
github-actions[bot]
1da7cac21a ci: update version string in docs 2024-08-04 21:28:53 +00:00
dante
ada45a3197 chore: swap h2 collections hash for rustc hasher (#832) 2024-08-04 17:28:36 -04:00
dante
616b421967 fix: bump compiler to latest to accommodate latest serde diagnostic (#830) 2024-07-25 07:56:21 -04:00
dante
f64f0ecd23 fix: instance order when using processed params (#829) 2024-07-24 07:58:46 -04:00
dante
5be12b7a54 fix: num groups for conv operations should be specified at load time (#828) 2024-07-18 09:58:41 -04:00
dante
2fd877c716 chore: small worm example (#568) 2024-07-15 09:20:37 -04:00
dante
8197340985 chore: const filtering optimizations (#825) 2024-07-12 12:37:02 +01:00
dante
6855ea1947 feat: parallel polynomial reads in halo2 (#826) 2024-07-12 00:33:14 +01:00
dante
2ca57bde2c chore: bump tract (#823) 2024-06-29 01:53:18 +01:00
Ethan Cemer
390de88194 feat: create_evm_vk python bindings (#818) 2024-06-23 22:21:59 -04:00
dante
cd91f0af26 chore: allow for float lookup safety margins (#817) 2024-06-19 09:12:01 -04:00
dante
4771192823 fix: more verbose io / rw errors (#815) 2024-06-14 12:31:50 -04:00
dante
a863ccc868 chore: update cmd feature flag (#814) 2024-06-11 16:07:49 -04:00
Ethan Cemer
8e6ccc863d feat: all file source kzg commit DA (#812) 2024-06-11 09:32:54 -04:00
dante
00d6873f9a fix: should update using bash when possible (#813) 2024-06-10 12:13:59 -04:00
dante
c97ff84198 refactor: rm boxed errors (opaque) (#810) 2024-06-08 22:41:47 -04:00
dante
f5f8ef56f7 chore: ezkl self update (#809) 2024-06-07 10:30:21 -04:00
Ethan Cemer
685487c853 feat: DA swap proof commitments (#807) 2024-06-07 10:19:38 -04:00
dante
33d41c7e49 fix: calldata cmd 2024-05-31 13:19:33 -04:00
dante
e04c959662 fix: sequential dependence between JS releases (#806) 2024-05-31 08:34:51 -04:00
dante
27b1f2e9d4 fix: move banner to command loop 2024-05-30 17:01:40 -04:00
dante
4a172877af feat: add autocompletion tooling to cli (#805) 2024-05-30 16:44:24 -04:00
dante
5a8498894d feat: create encode-evm-calldata in cli and python (#803) 2024-05-29 21:03:16 -04:00
Snoppy
095c0ca5b4 chore: fix typos (#802) 2024-05-29 11:32:11 -04:00
dante
3fa482c9ef fix: remove error messages in calibration loop (confusing) (#800) 2024-05-22 09:32:52 +01:00
dante
6be3b1d663 chore: update alloy (#798) 2024-05-16 12:10:27 +09:00
dante
d5a1d1439c fix: elif syntax in install script 2024-05-14 03:31:26 +09:00
Ethan Cemer
ff8fd01f86 fix: patch invalid ${{ github.ref_name#v }} syntax on verify release script (#791) 2024-05-14 02:25:12 +09:00
dante
e9020f942e fix: python build matrix (#797) 2024-05-13 13:29:51 +09:00
dante
e7f54cb6ac fix: windows build (#796) 2024-05-13 12:47:36 +09:00
dante
ed65e8c090 Revert "fix: revert maturin version (#795)"
This reverts commit d9f2adad99.
2024-05-13 12:41:02 +09:00
dante
d9f2adad99 fix: revert maturin version (#795) 2024-05-13 12:21:35 +09:00
dante
5125aaa090 chore: add aarch64 linux to release pipeline (#788) 2024-05-13 11:29:49 +09:00
Jseam
f1950e6cd0 docs: add polycommit to RunArgs (#794) 2024-05-10 22:19:05 +09:00
dante
998ca22c2a chore: make most command struct args Option (#793) 2024-05-10 22:02:34 +09:00
dante
5c574adc31 chore: logistic regression example (#792) 2024-05-08 20:30:13 +09:00
dante
749e0ba652 chore: update h2 solidity verifier (#787) 2024-05-03 01:25:14 +01:00
dante
d464ddf6b6 chore: medium sized lstm example (#785) 2024-05-01 16:35:11 +01:00
dante
8f6c0aced5 chore: update tract (#784) 2024-04-30 13:31:33 +01:00
dante
860e9700a8 refactor!: swap integer rep to i64 from i128 (#781)
BREAKING CHANGE: may break w/ old compiled circuits
2024-04-26 16:16:55 -04:00
Ethan Cemer
32dd4a854f fix: patch npm package build failure (#782) 2024-04-25 10:38:06 -04:00
127 changed files with 8433 additions and 5083 deletions

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -30,7 +30,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -178,3 +178,58 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"$CLEANED_TAG\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Update pnpm-lock.yaml versions and integrity
run: |
awk -v integrity="$ENGINE_INTEGRITY" -v tag="$CLEANED_TAG" '
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
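
A minimal Python sketch of the tag-cleaning and integrity lookup this job performs (the tag value is a hypothetical example; npm is assumed to be on PATH):

import subprocess

tag = "v12.0.2"  # stands in for github.ref_name; hypothetical example value
cleaned = tag.removeprefix("v")  # mirrors CLEANED_TAG="${CLEANED_TAG#v}"

# mirrors: npm view @ezkljs/engine@$CLEANED_TAG dist.integrity
integrity = subprocess.run(
    ["npm", "view", f"@ezkljs/engine@{cleaned}", "dist.integrity"],
    capture_output=True, text=True, check=True,
).stdout.strip()

print(cleaned, integrity)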

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -40,7 +40,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
@@ -50,6 +50,7 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Install built wheel
if: matrix.target == 'universal2-apple-darwin'
run: |
pip install ezkl --no-index --find-links dist --force-reinstall
python -c "import ezkl"
@@ -85,7 +86,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
@@ -110,7 +111,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64, i686]
target: [x86_64]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Checkout repo
@@ -102,28 +102,32 @@ jobs:
PCRE2_SYS_STATIC: 1
strategy:
matrix:
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu]
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu, linux-aarch64]
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2023-06-27
rust: nightly-2024-07-18
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2023-06-27
rust: nightly-2024-07-18
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2023-06-27
rust: nightly-2024-07-18
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2023-06-27
rust: nightly-2024-07-18
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2023-06-27
rust: nightly-2024-07-18
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-07-18
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
@@ -181,7 +185,7 @@ jobs:
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Strip release binary
if: matrix.build != 'windows-msvc'
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"
- name: Strip release binary (Windows)

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +212,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -229,13 +229,15 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
@@ -258,6 +260,8 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
@@ -286,7 +290,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -304,7 +308,7 @@ jobs:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -326,7 +330,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered separately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -341,6 +345,12 @@ jobs:
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)
@@ -359,7 +369,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -367,7 +377,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -433,11 +443,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -467,7 +477,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -485,7 +495,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -502,7 +512,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -519,7 +529,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -529,7 +539,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -540,7 +550,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -560,7 +570,7 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Install cmake
@@ -570,11 +580,11 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest -vv
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
@@ -586,7 +596,7 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -635,7 +645,7 @@ jobs:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -645,7 +655,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies

View File

@@ -1,54 +0,0 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

2262
Cargo.lock generated

File diff suppressed because it is too large.

View File

@@ -23,6 +23,7 @@ halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curv
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.5.3", features = ["derive"] }
clap_complete = "4.5.2"
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
@@ -38,50 +39,49 @@ snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch =
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
semver = "1.0.22"
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = [
"ethers-solc",
] }
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
ethabi = "18"
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = [
reqwest = { version = "0.12.4", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = [
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.20.2", features = [
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = [
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
@@ -103,8 +103,10 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -175,7 +177,7 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
default = ["ezkl", "mv-lookup", "no-banner", "parallel-poly-read"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -189,6 +191,7 @@ ezkl = [
"colored_json",
"halo2_proofs/circuit-params",
]
parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
@@ -198,13 +201,15 @@ det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd", package = "halo2_proofs" }
[profile.release]

View File

@@ -91,9 +91,9 @@ You can install the library from source
cargo install --locked --path .
```
You will need a functioning installation of `solc` in order to run `ezkl` properly.
[solc-select](https://github.com/crytic/solc-select) is recommended.
Follow the instructions on [solc-select](https://github.com/crytic/solc-select) to activate `solc` in your environment.
`ezkl` now auto-manages solc installation for you.
#### building python bindings

View File

@@ -20,9 +20,9 @@
"name": "quantize_data",
"outputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"stateMutability": "pure",
@@ -31,9 +31,9 @@
{
"inputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"name": "to_field_element",

View File

@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
}),
)
.unwrap();

View File

@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -84,7 +85,7 @@ fn runrelu(c: &mut Criterion) {
};
let input: Tensor<Value<Fr>> =
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),

View File

@@ -1,6 +1,167 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
import './LoadInstances.sol';
contract LoadInstances {
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
mload(add(add(encoded, add(i, 0x24)), instances_offset))
)
}
}
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from calldata
* @param encoded - verifier calldata
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
instances_length := calldataload(
add(add(encoded.offset, 0x04), instances_offset)
)
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
calldataload(
add(add(encoded.offset, add(i, 0x04)), instances_offset)
)
)
}
}
}
}
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At Solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"";
contract SwapProofCommitments {
/**
* @dev Swap the proof commitments
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function checkKzgCommits(
bytes calldata encoded
) internal pure returns (bool equal) {
bytes4 funcSig;
uint256 proof_offset;
uint256 proof_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch proof offset which is 4 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
proof_offset := calldataload(
add(
encoded.offset,
add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
proof_length := calldataload(
add(add(encoded.offset, 0x04), proof_offset)
)
}
// Check the length of the commitment against the proof bytes
if (proof_length < COMMITMENT_KZG.length) {
return false;
}
// Load COMMITMENT_KZG into memory
bytes memory commitment = COMMITMENT_KZG;
// Compare the first N bytes of the proof with COMMITMENT_KZG
uint words = (commitment.length + 31) / 32; // Calculate the number of 32-byte words
assembly {
// Now we compare the commitment with the proof,
// ensuring that the commitments divided up into 32 byte words are all equal.
for {
let i := 0x20
} lt(i, add(mul(words, 0x20), 0x20)) {
i := add(i, 0x20)
} {
let wordProof := calldataload(
add(add(encoded.offset, add(i, 0x04)), proof_offset)
)
let wordCommitment := mload(add(commitment, i))
equal := eq(wordProof, wordCommitment)
if eq(equal, 0) {
return(0, 0)
}
}
}
return equal; // Return true if the commitment comparison passed
} /// end checkKzgCommits
}
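
The calldata layout that the assembly above walks can be mirrored off-chain. A minimal Python sketch of getInstancesCalldata and checkKzgCommits, under the assumptions stated in the comments (selector 0xaf83a18d marks the verifyProof(address,bytes,uint256[]) variant, and the COMMITMENT_KZG blob is a whole number of 32-byte words; function names are illustrative):

def get_instances(encoded: bytes) -> list:
    # selector (4 bytes) then head words: the instances offset word sits at
    # 4 + 32 for verifyProof(bytes,uint256[]), 4 + 64 for the address variant
    has_addr = encoded[:4] == bytes.fromhex("af83a18d")
    off_pos = 4 + (0x40 if has_addr else 0x20)
    instances_offset = int.from_bytes(encoded[off_pos:off_pos + 32], "big")
    len_pos = 4 + instances_offset  # ABI offsets count from just after the selector
    n = int.from_bytes(encoded[len_pos:len_pos + 32], "big")
    data = encoded[len_pos + 32:]
    return [int.from_bytes(data[32 * i:32 * i + 32], "big") for i in range(n)]

def check_kzg_commits(encoded: bytes, commitment: bytes) -> bool:
    # the proof offset word sits one head word earlier than the instances offset
    has_addr = encoded[:4] == bytes.fromhex("af83a18d")
    off_pos = 4 + (0x20 if has_addr else 0x00)
    proof_offset = int.from_bytes(encoded[off_pos:off_pos + 32], "big")
    len_pos = 4 + proof_offset
    proof_length = int.from_bytes(encoded[len_pos:len_pos + 32], "big")
    if proof_length < len(commitment):
        return False
    # compare the proof's leading bytes against the expected commitment
    return encoded[len_pos + 32:len_pos + 32 + len(commitment)] == commitment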
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
@@ -13,9 +174,10 @@ import './LoadInstances.sol';
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestation is LoadInstances {
contract DataAttestation is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
@@ -34,11 +196,14 @@ contract DataAttestation is LoadInstances {
address public admin;
/**
* @notice EZKL P value
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER = uint256(0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001);
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_CALLS = 0;
@@ -69,7 +234,7 @@ contract DataAttestation is LoadInstances {
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if(_admin == address(0)) {
if (_admin == address(0)) {
revert();
}
admin = _admin;
@@ -80,7 +245,7 @@ contract DataAttestation is LoadInstances {
bytes[][] memory _callData,
uint256[][] memory _decimals
) external {
require(msg.sender == admin, "Only admin can update instanceOffset");
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
@@ -111,7 +276,10 @@ contract DataAttestation is LoadInstances {
// count the total number of storage reads across all of the accounts
counter += _callData[i].length;
}
require(counter == INPUT_CALLS + OUTPUT_CALLS, "Invalid number of calls");
require(
counter == INPUT_CALLS + OUTPUT_CALLS,
"Invalid number of calls"
);
}
function mulDiv(
@@ -167,7 +335,7 @@ contract DataAttestation is LoadInstances {
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param data - The data returned from the account calls.
* @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
* @param scale - The scale used to convert the floating point value into a fixed point value.
* @param scale - The scale used to convert the floating point value into a fixed point value.
*/
function quantizeData(
bytes memory data,
@@ -181,7 +349,7 @@ contract DataAttestation is LoadInstances {
if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
output += 1;
}
quantized_data = neg ? -int256(output): int256(output);
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
@@ -211,7 +379,9 @@ contract DataAttestation is LoadInstances {
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(int256 x) internal pure returns (uint256 field_element) {
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges from -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
@@ -251,13 +421,18 @@ contract DataAttestation is LoadInstances {
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0,"Address: call to non-contract");
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
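
Steps 3 and 4 of the flow described above (quantization, then field-element conversion) can also be sketched off-chain. A minimal Python version of quantizeData and toFieldElement, using the same half-up rounding as the Solidity (function names are illustrative):

ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

def quantize(x: int, decimals: int, scale: int) -> int:
    # fixed point: round(x * 2**scale / 10**decimals), rounding half up
    neg = x < 0
    if neg:
        x = -x
    denom = 10 ** decimals
    out = (x * (1 << scale)) // denom
    if (x * (1 << scale)) % denom * 2 >= denom:
        out += 1
    return -out if neg else out

def to_field_element(q: int) -> int:
    # map the signed quantized value into [0, ORDER); safe since |q| is far below ORDER
    return (q + ORDER) % ORDER

# e.g. 1.5 stored with 18 decimals at scale 7: round(1.5 * 128) = 192
assert quantize(1_500_000_000_000_000_000, 18, 7) == 192
assert to_field_element(-1) == ORDER - 1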

View File

@@ -1,92 +0,0 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
contract LoadInstances {
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
mload(add(add(encoded, add(i, 0x24)), instances_offset))
)
}
}
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from calldata
* @param encoded - verifier calldata
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
instances_length := calldataload(add(add(encoded.offset, 0x04), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly{
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
calldataload(
add(add(encoded.offset, add(i, 0x04)), instances_offset)
)
)
}
}
}
}

View File

@@ -1,135 +0,0 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.17;
contract QuantizeData {
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same instance, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
/**
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or denominator == 0
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv)
* with further edits by Uniswap Labs also under MIT license.
*/
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use
// the Chinese Remainder Theorem to reconstruct the 512-bit result. The result is stored in two 256-bit
// variables such that product = prod1 * 2^256 + prod0.
uint256 prod0; // Least significant 256 bits of the product
uint256 prod1; // Most significant 256 bits of the product
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
// Handle non-overflow cases, 256 by 256 division.
if (prod1 == 0) {
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
// The surrounding unchecked block does not change this fact.
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
return prod0 / denominator;
}
// Make sure the result is less than 2^256. Also prevents denominator == 0.
require(denominator > prod1, "Math: mulDiv overflow");
///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////
// Make division exact by subtracting the remainder from [prod1 prod0].
uint256 remainder;
assembly {
// Compute remainder using mulmod.
remainder := mulmod(x, y, denominator)
// Subtract 256 bit number from 512 bit number.
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
// Factor powers of two out of denominator and compute largest power of two divisor of denominator. Always >= 1.
// See https://cs.stackexchange.com/q/138556/92363.
// Does not overflow because the denominator cannot be zero at this stage in the function.
uint256 twos = denominator & (~denominator + 1);
assembly {
// Divide denominator by twos.
denominator := div(denominator, twos)
// Divide [prod1 prod0] by twos.
prod0 := div(prod0, twos)
// Flip twos such that it is 2^256 / twos. If twos is zero, then it becomes one.
twos := add(div(sub(0, twos), twos), 1)
}
// Shift in bits from prod1 into prod0.
prod0 |= prod1 * twos;
// Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
// that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
// four bits. That is, denominator * inv = 1 mod 2^4.
uint256 inverse = (3 * denominator) ^ 2;
// Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
// in modular arithmetic, doubling the correct bits in each step.
inverse *= 2 - denominator * inverse; // inverse mod 2^8
inverse *= 2 - denominator * inverse; // inverse mod 2^16
inverse *= 2 - denominator * inverse; // inverse mod 2^32
inverse *= 2 - denominator * inverse; // inverse mod 2^64
inverse *= 2 - denominator * inverse; // inverse mod 2^128
inverse *= 2 - denominator * inverse; // inverse mod 2^256
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
// This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
// less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
// is no longer required.
result = prod0 * inverse;
return result;
}
}
function quantize_data(
bytes[] memory data,
uint256[] memory decimals,
uint256[] memory scales
) external pure returns (int256[] memory quantized_data) {
quantized_data = new int256[](data.length);
for (uint i; i < data.length; i++) {
int x = abi.decode(data[i], (int256));
bool neg = x < 0;
if (neg) x = -x;
uint denom = 10 ** decimals[i];
uint scale = 1 << scales[i];
uint output = mulDiv(uint256(x), scale, denom);
if (mulmod(uint256(x), scale, denom) * 2 >= denom) {
output += 1;
}
quantized_data[i] = neg ? -int256(output) : int256(output);
}
}
function to_field_element(
int128[] memory quantized_data
) public pure returns (uint256[] memory output) {
output = new uint256[](quantized_data.length);
for (uint i; i < quantized_data.length; i++) {
output[i] = uint256(quantized_data[i] + int(ORDER)) % ORDER;
}
}
}

View File

@@ -1,12 +0,0 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.17;
contract TestReads {
int[] public arr;
constructor(int256[] memory _numbers) {
for (uint256 i = 0; i < _numbers.length; i++) {
arr.push(_numbers[i]);
}
}
}

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==12.0.2
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '12.0.2'
version = release

View File

@@ -2,8 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
@@ -42,8 +41,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -66,8 +65,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -90,8 +89,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -205,6 +204,7 @@ where
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
};
let x = config
.layer_config
@@ -315,7 +315,11 @@ pub fn runconv() {
.test_set_length(10_000)
.finalize();
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
let mut train_data = Tensor::from(
trn_img
.iter()
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
);
train_data.reshape(&[50_000, 28, 28]).unwrap();
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
@@ -343,8 +347,8 @@ pub fn runconv() {
.map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}),
);
@@ -355,7 +359,8 @@ pub fn runconv() {
let l0_kernels = l0_kernels.try_into().unwrap();
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
let mut l0_bias =
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let l0_bias = l0_bias.try_into().unwrap();
@@ -363,8 +368,8 @@ pub fn runconv() {
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
@@ -374,8 +379,8 @@ pub fn runconv() {
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
@@ -401,13 +406,13 @@ pub fn runconv() {
l2_params: [l2_weights, l2_biases],
};
let public_input: Tensor<i32> = vec![
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
let public_input: Tensor<IntegerRep> = vec![
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
]
.into_iter()
.into();
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
println!("MOCK PROVING");
let now = Instant::now();

View File

@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: i128, const LOOKUP_MAX: i128> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;
@@ -215,33 +215,33 @@ pub fn runmlp() {
#[cfg(not(target_arch = "wasm32"))]
env_logger::init();
// parameters
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
// input data, with 1 padding to allow for bias
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
.unwrap()
.into();
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
let circuit = MyCircuit::<4, -8192, 8192> {
@@ -251,12 +251,12 @@ pub fn runmlp() {
_marker: PhantomData,
};
let public_input: Vec<i32> = unsafe {
let public_input: Vec<IntegerRep> = unsafe {
vec![
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
(2849f32 / 128f32).to_int_unchecked::<i32>(),
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
]
};
@@ -265,7 +265,10 @@ pub fn runmlp() {
let prover = MockProver::run(
K as u32,
&circuit,
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
vec![public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect()],
)
.unwrap();
prover.assert_satisfied();

View File

@@ -251,7 +251,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.setup_test_evm_witness(\n",
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
@@ -333,7 +333,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -354,7 +354,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -462,7 +462,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -482,7 +482,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -510,7 +510,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -535,7 +535,7 @@
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -567,7 +567,7 @@
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
},
"orig_nbformat": 4
},
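
The notebook hunks above switch the Python bindings to awaited calls (matching the pyo3-asyncio migration in Cargo.toml). Outside a notebook the same calls need an event loop; a minimal sketch, with placeholder file paths standing in for the ones the notebook defines:

import asyncio
import ezkl

cal_path = "cal_data.json"            # placeholder paths; the notebook defines its own
model_path = "network.onnx"
settings_path = "settings.json"
data_path = "input.json"
compiled_model_path = "network.compiled"
witness_path = "witness.json"

async def main():
    # in a notebook these are awaited directly; a script wraps them in asyncio.run
    await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
    await ezkl.get_srs(settings_path)
    await ezkl.gen_witness(data_path, compiled_model_path, witness_path)

asyncio.run(main())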

View File

@@ -249,7 +249,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -278,7 +278,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -299,7 +299,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n"
]
},
@@ -518,7 +518,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -538,7 +538,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -566,7 +566,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -591,7 +591,7 @@
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -623,7 +623,7 @@
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
@@ -654,4 +654,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -0,0 +1,604 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-kzg-vis\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source and the params and outputs are committed to using kzg-commitments. \n",
"\n",
"In this setup:\n",
"- the inputs and outputs are publicly known to the prover and verifier\n",
"- the on chain inputs will be fetched and then fed directly into the circuit\n",
"- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n",
"- The kzg commitment to the params and inputs will be read from the proof and checked to make sure it matches the expected commitment stored on-chain.\n"
]
},
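To make the quantization point above concrete, here is a minimal sketch of the fixed-point mapping that the EVM-side quantization and the prover must both agree on. This is illustrative only, not ezkl's actual implementation, and the scale value is an assumed example.

```python
# Hedged sketch of fixed-point quantization -- not ezkl's actual code.
# A float x is mapped to the integer round(x * 2**scale); `scale = 7` is
# an assumed example value, not a value taken from this notebook.
def quantize(x: float, scale: int = 7) -> int:
    return round(x * (1 << scale))

assert quantize(0.5) == 64  # 0.5 * 2**7 = 64
```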
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
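A typical, purely hypothetical use of these helpers in this notebook would be:

```python
# Hypothetical usage of the helpers defined above (not part of the original cell).
start_anvil()   # spins up a local anvil node on port 3030
# ... deploy the verifier / DA contracts, prove, and verify ...
stop_anvil()    # tear the node down once the demo is finished
```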
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"polycommitment\" \n",
"- `output_visibility`: \"polycommitment\"\n",
"\n",
"**Note**:\n",
"When we set this to polycommitment, we are saying that the model parameters are committed to using a polynomial commitment scheme. This commitment will be stored on chain as a constant stored in the DA contract, and the proof will contain the commitment to the parameters. The DA verification will then check that the commitment in the proof matches the commitment stored on chain. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"polycommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
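To see what calibration actually changed, you can inspect the generated settings file. A minimal sketch, assuming the settings schema: `model_output_scales` is a field referenced elsewhere in these notebooks, and any other keys should be treated as assumptions.

```python
# Sketch: inspect the calibrated settings file. "model_output_scales" is
# referenced in other notebooks in this repo; treat further keys as assumptions.
import json

with open(settings_path) as f:
    settings = json.load(f)
print(settings.get("model_output_scales"))
```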
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on-chain, since the EVM only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
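As a small decoding sketch for the `decimals` field above: recovering the floating point value from a uint256 return is just a division by a power of ten. This mirrors the schema, not ezkl's internal code.

```python
# Sketch: map a uint256 return value plus its "decimals" hint back to a float.
# Mirrors the graph-input schema above; not ezkl's internal implementation.
def to_float(raw_uint256: int, decimals: int) -> float:
    return raw_uint256 / 10**decimals

assert to_float(1234567, 7) == 0.1234567
```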
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
" data_path,\n",
" input_source=ezkl.PyTestDataSource.OnChain,\n",
" output_source=ezkl.PyTestDataSource.File,\n",
" rpc_url=RPC_URL)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
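Once the witness has been generated, you can peek at its contents. A hedged sketch: the `outputs` key is an assumption about `witness.json`'s schema, made for illustration.

```python
# Sketch: peek at the generated witness file. The "outputs" key is an
# assumption about witness.json's schema, used here for illustration only.
import json

with open(witness_path) as f:
    witness = json.load(f)
print(witness.get("outputs"))
```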
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"When deploying a DA with kzg commitments, we need to make sure to also pass a witness file that contains the commitments to the parameters and inputs. This is because the verifier will need to check that the commitments in the proof match the commitments stored on chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" witness_path = witness_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -150,7 +150,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -170,7 +170,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -192,7 +192,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -303,4 +303,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -352,14 +352,8 @@
"# Specify all the files we need\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.ezkl')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n",
"cal_data_path = os.path.join('cal_data.json')"
"cal_data_path = os.path.join('calibration.json')"
]
},
{
@@ -424,7 +418,7 @@
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"res = ezkl.gen_settings()\n",
"assert res == True\n",
"\n"
]
@@ -443,7 +437,7 @@
"\n",
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
"# You may want to increase the max logrows if accuracy is a concern\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
"res = await ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])"
]
},
{
@@ -463,7 +457,7 @@
},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
@@ -484,7 +478,7 @@
},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs()"
]
},
{
@@ -504,17 +498,10 @@
},
"outputs": [],
"source": [
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"res = ezkl.setup()\n",
"\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
"assert res == True"
]
},
{
@@ -539,7 +526,7 @@
"# now generate the witness file\n",
"witness_path = os.path.join('witness.json')\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness()\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -559,13 +546,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -585,11 +566,7 @@
"source": [
"# verify our proof\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
@@ -664,12 +641,9 @@
"sol_code_path = os.path.join('Verifier.sol')\n",
"abi_path = os.path.join('Verifier.abi')\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path\n",
"res = await ezkl.create_evm_verifier(\n",
" sol_code_path=sol_code_path,\n",
" abi_path=abi_path, \n",
" )\n",
"\n",
"assert res == True\n",
@@ -757,7 +731,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -467,7 +467,7 @@
"outputs": [],
"source": [
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -494,7 +494,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -508,7 +508,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -625,4 +625,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -195,7 +195,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -222,7 +222,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -236,7 +236,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -179,7 +179,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -202,7 +202,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -214,7 +214,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -313,4 +313,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -241,7 +241,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -270,7 +270,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -291,7 +291,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -420,7 +420,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -451,7 +451,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -472,7 +472,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",

View File

@@ -152,7 +152,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -175,7 +175,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs(settings_path = settings_path)"
"res = await ezkl.get_srs(settings_path = settings_path)"
]
},
{
@@ -188,7 +188,7 @@
"# now generate the witness file \n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -284,4 +284,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -155,7 +155,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -178,7 +178,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -190,7 +190,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -289,4 +289,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -233,7 +233,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -262,7 +262,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -315,7 +315,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
]
},
{
@@ -429,7 +429,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -460,7 +460,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -481,7 +481,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",

View File

@@ -193,7 +193,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -216,7 +216,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -228,7 +228,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -347,4 +347,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -142,7 +142,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -165,7 +165,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -177,7 +177,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -276,4 +276,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -347,7 +347,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -370,7 +370,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -383,7 +383,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -490,4 +490,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -0,0 +1,279 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Logistic Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package for a Logistic Regression model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LogisticRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -139,7 +139,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -180,7 +180,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -193,7 +193,7 @@
"# now generate the witness file \n",
"witness_path = \"lstmwitness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -300,4 +300,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -91,11 +91,14 @@
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"subprocess.Popen(command)\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(5)\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
@@ -310,7 +313,7 @@
"\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
@@ -387,7 +390,7 @@
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = ezkl.gen_witness(\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
@@ -425,16 +428,6 @@
"assert os.path.isfile(proof_path)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# kill all shovel process \n",
"os.system(\"pkill -f shovel\")"
]
}
],
"metadata": {

View File

@@ -323,7 +323,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"assert res == True"
]
},
@@ -348,7 +348,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs(settings_path)"
"res = await ezkl.get_srs(settings_path)"
]
},
{
@@ -362,7 +362,7 @@
"# now generate the witness file\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -469,7 +469,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test_1.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -502,7 +502,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -525,7 +525,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -289,7 +289,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
]
},
{
@@ -309,7 +309,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -321,7 +321,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

File diff suppressed because one or more lines are too long

View File

@@ -215,7 +215,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -235,7 +235,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -247,7 +247,7 @@
"# now generate the witness file\n",
"witness_path = \"ae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -451,7 +451,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n",
"print(\"verified\")"
]
@@ -473,7 +473,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -485,7 +485,7 @@
"# now generate the witness file \n",
"witness_path = \"vae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -845,7 +845,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"assert res == True"
]
},
@@ -870,7 +870,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -881,7 +881,7 @@
},
"outputs": [],
"source": [
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -993,4 +993,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -261,7 +261,7 @@
"source": [
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
"\n",
"def setup(i):\n",
"async def setup(i):\n",
" # file names\n",
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
@@ -282,7 +282,7 @@
"\n",
" # generate settings for the current model\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
" res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" assert res == True\n",
"\n",
" # load settings and print them to the console\n",
@@ -303,11 +303,11 @@
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
"\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
"for i in range(2):\n",
" setup(i)\n"
" await setup(i)\n"
]
},
{
@@ -414,7 +414,7 @@
"outputs": [],
"source": [
"for i in range(2):\n",
" setup(i)"
" await setup(i)"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
},
"orig_nbformat": 4
},

View File

@@ -174,7 +174,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -196,7 +196,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -208,7 +208,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -215,7 +215,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -229,7 +229,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -265,7 +265,7 @@
" # Serialize data into file:\n",
"json.dump( data, open(data_path_faulty, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"assert os.path.isfile(witness_path_faulty)"
]
},
@@ -310,7 +310,7 @@
"# Serialize data into file:\n",
"json.dump( data, open(data_path_truthy, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"res = await ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"assert os.path.isfile(witness_path_truthy)"
]
},
@@ -482,7 +482,7 @@
"source": [
"import pytest\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path_faulty,\n",
" settings_path,\n",
@@ -514,7 +514,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,

View File

@@ -193,7 +193,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -205,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -290,7 +290,7 @@
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
"res = await ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{
@@ -374,7 +374,7 @@
"sol_code_path = os.path.join(\"Verifier.sol\")\n",
"abi_path = os.path.join(\"Verifier_ABI.json\")\n",
"\n",
"res = ezkl.create_evm_verifier_aggr(\n",
"res = await ezkl.create_evm_verifier_aggr(\n",
" [settings_path],\n",
" aggregate_vk_path,\n",
" sol_code_path,\n",
@@ -404,4 +404,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -157,6 +157,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b78d3cbf",
"metadata": {},
"outputs": [],
"source": [
@@ -170,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -204,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -298,7 +299,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,

View File

@@ -169,7 +169,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -191,7 +191,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -203,7 +203,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -170,7 +170,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -192,7 +192,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -149,7 +149,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -171,7 +171,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -183,7 +183,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -282,4 +282,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -250,7 +250,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -297,7 +297,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
@@ -411,7 +411,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
@@ -478,12 +478,11 @@
"import pytest\n",
"\n",
"def test_verification():\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
" with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
" ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"# Run the test function\n",
@@ -510,9 +509,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -167,7 +167,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -187,7 +187,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -209,7 +209,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -221,7 +221,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -320,4 +320,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"model = convert(sk_model, \"torch\").model\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,33 +84,6 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -119,14 +92,14 @@
"outputs": [],
"source": [
"\n",
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"data = dict(input_data=[d])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -180,7 +152,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -202,7 +174,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -214,13 +186,13 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -420,7 +392,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
}
@@ -441,7 +413,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -637,7 +637,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -646,7 +646,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.get_srs( settings_path)"
"await ezkl.get_srs( settings_path)"
]
},
{
@@ -683,7 +683,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{

View File

@@ -525,7 +525,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{
@@ -572,7 +572,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{

View File

@@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
@@ -49,7 +49,11 @@
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os"
"import os\n",
"\n",
"import logging\n",
"\n",
"logging.basicConfig(level=logging.INFO)"
]
},
{
@@ -63,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -71,7 +75,15 @@
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"from torch import nn\n",
"import torch\n",
@@ -133,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -141,7 +153,18 @@
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1715422870\n",
"1714818070\n",
"https://api.coingecko.com/api/v3/coins/ethereum/market_chart/range?vs_currency=usd&from=1714818070&to=1715422870\n",
"<Response [200]>\n"
]
}
],
"source": [
"\n",
"def get_url(coin, currency, start, end):\n",
@@ -174,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -183,7 +206,115 @@
"id": "WSj1Uxln65vf",
"outputId": "51422d71-9680-4b51-c4df-e400d20f988b"
},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>time</th>\n",
" <th>prices</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1714820485367</td>\n",
" <td>3146.785806</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1714824033868</td>\n",
" <td>3127.968728</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1714828058243</td>\n",
" <td>3156.141681</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1714831650751</td>\n",
" <td>3124.834064</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1714834972229</td>\n",
" <td>3133.115333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163</th>\n",
" <td>1715407579346</td>\n",
" <td>2918.049749</td>\n",
" </tr>\n",
" <tr>\n",
" <th>164</th>\n",
" <td>1715411090715</td>\n",
" <td>2920.330834</td>\n",
" </tr>\n",
" <tr>\n",
" <th>165</th>\n",
" <td>1715414554830</td>\n",
" <td>2923.986611</td>\n",
" </tr>\n",
" <tr>\n",
" <th>166</th>\n",
" <td>1715418419843</td>\n",
" <td>2910.537671</td>\n",
" </tr>\n",
" <tr>\n",
" <th>167</th>\n",
" <td>1715421675338</td>\n",
" <td>2907.702307</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>168 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" time prices\n",
"0 1714820485367 3146.785806\n",
"1 1714824033868 3127.968728\n",
"2 1714828058243 3156.141681\n",
"3 1714831650751 3124.834064\n",
"4 1714834972229 3133.115333\n",
".. ... ...\n",
"163 1715407579346 2918.049749\n",
"164 1715411090715 2920.330834\n",
"165 1715414554830 2923.986611\n",
"166 1715418419843 2910.537671\n",
"167 1715421675338 2907.702307\n",
"\n",
"[168 rows x 2 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(new_data)\n",
"df\n"
@@ -200,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -225,7 +356,98 @@
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.execute:SRS already exists at that path\n",
"INFO:ezkl.execute:num calibration batches: 1\n",
"INFO:ezkl.execute:read 16777476 bytes from file (vector of len = 16777476)\n",
"WARNING:ezkl.execute:\n",
"\n",
" <------------- Numerical Fidelity Report (input_scale: 4, param_scale: 4, scale_input_multiplier: 10) ------------->\n",
"\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"| mean_error | median_error | max_error | min_error | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"| -727.9929 | -727.9929 | -727.9929 | -727.9929 | 727.9929 | 727.9929 | 727.9929 | 727.9929 | 529973.7 | -0.24999964 | 0.24999964 |\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"\n",
"\n",
"INFO:ezkl.execute:file hash: 41509f380362a8d14401c5ae92073154922fe23e45459ce6f696f58607655db7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"run_args\": {\n",
" \"tolerance\": {\n",
" \"val\": 0.0,\n",
" \"scale\": 1.0\n",
" },\n",
" \"input_scale\": 4,\n",
" \"param_scale\": 4,\n",
" \"scale_rebase_multiplier\": 10,\n",
" \"lookup_range\": [\n",
" 0,\n",
" 0\n",
" ],\n",
" \"logrows\": 6,\n",
" \"num_inner_cols\": 2,\n",
" \"variables\": [\n",
" [\n",
" \"batch_size\",\n",
" 1\n",
" ]\n",
" ],\n",
" \"input_visibility\": \"Private\",\n",
" \"output_visibility\": \"Public\",\n",
" \"param_visibility\": \"Private\",\n",
" \"div_rebasing\": false,\n",
" \"rebase_frac_zero_constants\": false,\n",
" \"check_mode\": \"UNSAFE\",\n",
" \"commitment\": \"KZG\"\n",
" },\n",
" \"num_rows\": 21,\n",
" \"total_assignments\": 42,\n",
" \"total_const_size\": 0,\n",
" \"total_dynamic_col_size\": 0,\n",
" \"num_dynamic_lookups\": 0,\n",
" \"num_shuffles\": 0,\n",
" \"total_shuffle_col_size\": 0,\n",
" \"model_instance_shapes\": [\n",
" [\n",
" 1\n",
" ]\n",
" ],\n",
" \"model_output_scales\": [\n",
" 8\n",
" ],\n",
" \"model_input_scales\": [\n",
" 4\n",
" ],\n",
" \"module_sizes\": {\n",
" \"polycommit\": [],\n",
" \"poseidon\": [\n",
" 0,\n",
" [\n",
" 0\n",
" ]\n",
" ]\n",
" },\n",
" \"required_lookups\": [],\n",
" \"required_range_checks\": [],\n",
" \"check_mode\": \"UNSAFE\",\n",
" \"version\": \"0.0.0\",\n",
" \"num_blinding_factors\": null,\n",
" \"timestamp\": 1715422871248\n",
"}\n"
]
}
],
"source": [
"# generate settings\n",
"onnx_filename = os.path.join('lol.onnx')\n",
@@ -236,9 +458,9 @@
"\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
"await ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
"res = ezkl.get_srs(settings_filename)\n",
"res = await ezkl.get_srs(settings_filename)\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
"\n",
"# show the settings.json\n",
@@ -259,7 +481,24 @@
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:VK took 0.8\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:PK took 0.2\n",
"INFO:ezkl.pfsys:saving verification key 💾\n",
"INFO:ezkl.pfsys:done saving verification key ✅\n",
"INFO:ezkl.pfsys:saving proving key 💾\n",
"INFO:ezkl.pfsys:done saving proving key ✅\n"
]
}
],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
@@ -281,20 +520,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"res = await ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -302,7 +541,34 @@
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys:loading proving key from \"test.pk\"\n",
"INFO:ezkl.pfsys:done loading proving key ✅\n",
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:proof started...\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:proof took 0.15\n",
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
"INFO:ezkl.pfsys:done loading verification key ✅\n",
"INFO:ezkl.execute:verify took 0.2\n",
"INFO:ezkl.execute:verified: true\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"verified\n"
]
}
],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
@@ -351,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -360,7 +626,26 @@
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
"INFO:ezkl.pfsys:done loading verification key ✅\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test.vk\n",
"settings.json\n"
]
}
],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
@@ -370,7 +655,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
@@ -381,9 +666,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.eth:using chain 31337\n",
"INFO:ezkl.execute:Contract deployed at: 0x998abeb3e57409262ae5b751f60747921b33613e\n"
]
}
],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
@@ -391,8 +685,9 @@
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"sol_code_path = 'test.sol'\n",
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -406,16 +701,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.eth:using chain 31337\n",
"INFO:ezkl.eth:estimated verify gas cost: 399775\n",
"INFO:ezkl.execute:Solidity verification result: true\n"
]
}
],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -451,7 +756,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"name":"proof","type":"bytes","internalType":"bytes"},{"name":"instances","type":"uint256[]","internalType":"uint256[]"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"nonpayable"}]

View File

@@ -629,7 +629,7 @@
"source": [
"\n",
"\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -660,7 +660,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs(settings_path)"
"res = await ezkl.get_srs(settings_path)"
]
},
{
@@ -680,7 +680,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -807,7 +807,7 @@
"settings_path = os.path.join('settings.json')\n",
"\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -847,7 +847,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -868,7 +868,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -242,6 +242,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "2007dc77",
"metadata": {},
"outputs": [],
"source": [
@@ -257,6 +258,7 @@
},
{
"cell_type": "markdown",
"id": "ab993958",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
@@ -272,7 +274,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -284,12 +286,13 @@
"source": [
"# now generate the witness file \n",
"\n",
"witness = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"witness = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "markdown",
"id": "ad58432e",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
@@ -317,6 +320,7 @@
},
{
"cell_type": "markdown",
"id": "1746c8d1",
"metadata": {},
"source": [
"We can now create an EVM verifier contract from our circuit. This contract will be deployed to the chain we are using. In this case we are using a local anvil instance."
@@ -325,15 +329,15 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d1920c0f",
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
@@ -344,6 +348,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0fd7f22b",
"metadata": {},
"outputs": [],
"source": [
@@ -351,7 +356,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -362,6 +367,7 @@
},
{
"cell_type": "markdown",
"id": "9c0dffab",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. \n",
@@ -371,6 +377,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cc888848",
"metadata": {},
"outputs": [],
"source": []
@@ -378,6 +385,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c2db14d7",
"metadata": {},
"outputs": [],
"source": [
@@ -385,7 +393,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -396,12 +404,13 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5a018ba6",
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -412,6 +421,7 @@
},
{
"cell_type": "markdown",
"id": "2adad845",
"metadata": {},
"source": [
"Now we can pull in the data from the contract and calculate a new set of coordinates. We then rotate the world by 1 transform and submit the proof to the contract. The contract could then update the world rotation (logic not inserted here). For demo purposes we do this repeatedly, rotating the world by 1 transform. "
@@ -444,6 +454,7 @@
},
{
"cell_type": "markdown",
"id": "90eda56e",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
@@ -528,7 +539,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -193,7 +193,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -215,7 +215,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -227,7 +227,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -346,4 +346,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
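For reference, the data-attestation leg of this notebook gathered in one place. The trailing arguments of `create_evm_data_attestation` and `deploy_da_evm` are cut off in the hunks above, so the ABI path and RPC URL below are assumptions:

```python
import ezkl

# the attestation contract reads the instances from calldata, attests to
# them, calls the vanilla verifier, and returns its result
res = await ezkl.create_evm_data_attestation(
    "input.json", "settings.json", "test.sol", "test.abi"  # abi arg assumed
)

res = await ezkl.deploy_da_evm(
    "addr_da.txt", "input.json", "settings.json",
    "test.sol", "http://127.0.0.1:3030",  # trailing args assumed
)
```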

Binary file not shown.

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,9 @@
{
"input_data": [
[
1.514470100402832, 1.519423007965088, 1.5182757377624512,
1.5262789726257324, 1.5298409461975098
]
],
"output_data": [[-0.1862019]]
}

Binary file not shown.

View File

@@ -104,5 +104,5 @@ json.dump(data, open("input.json", 'w'))
# ezkl.gen_settings("network.onnx", "settings.json")
# !RUST_LOG = full
# res = ezkl.calibrate_settings(
# res = await ezkl.calibrate_settings(
# "input.json", "network.onnx", "settings.json", "resources")

View File

@@ -0,0 +1 @@
network.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,47 @@
## The worm
This is an ONNX file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, a VAE / latent-space representation of the C. elegans connectome.
The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).
In effect this is a generative model of a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given the previous connectome state.
Using ezkl you can create a zk circuit equivalent to the WormVAE model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm state that can be verified on chain.
To do so, you'll first want to fetch the files using git-lfs (as the ONNX file is too large to be stored in git).
```bash
git lfs fetch --all
```
You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof.
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
> Note: the model is large, so we recommend a machine with at least 512GB of RAM to run the above commands. If you're compute-constrained you can always use the lilith service to generate the zk circuit. Message us on Discord or Telegram for more details :)
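For anyone driving this from the Python bindings instead of the CLI, the same loop looks roughly like the sketch below. It is a sketch under assumptions, not a drop-in script: the `PyRunArgs` usage and file names are assumed, the `await` calls need an async context (e.g. a notebook), and the setup/prove/aggregate steps follow the same pattern as the repo's notebooks.

```python
import ezkl

# assumption: PyRunArgs exposes param_visibility; "fixed" moves the
# large parameter set into fixed columns, shrinking the circuit
run_args = ezkl.PyRunArgs()
run_args.param_visibility = "fixed"

ezkl.gen_settings("network.onnx", "settings.json", py_run_args=run_args)
await ezkl.calibrate_settings(
    "calibration.json", "network.onnx", "settings.json", "resources"
)
ezkl.compile_circuit("network.onnx", "network.compiled", "settings.json")
await ezkl.get_srs("settings.json")
await ezkl.gen_witness("input.json", "network.compiled", "witness.json")
```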

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
size 82095882

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -1,6 +1,6 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"version": "v10.4.2",
"publishConfig": {
"access": "public"
},
@@ -27,7 +27,7 @@
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ezkljs/engine": "^9.4.4",
"@ezkljs/engine": "10.4.2",
"ethers": "6.7.1",
"json-bigint": "1.0.0"
},

View File

@@ -27,8 +27,8 @@ dependencies:
specifier: 5.7.0
version: 5.7.0
'@ezkljs/engine':
specifier: ^9.4.4
version: 9.4.4
specifier: "10.4.2"
version: "10.4.2"
ethers:
specifier: 6.7.1
version: 6.7.1
@@ -397,8 +397,8 @@ packages:
'@ethersproject/strings': 5.7.0
dev: false
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
/@ezkljs/engine@10.4.2:
resolution: {integrity: "sha512-1GNB4vChbaQ1ALcYbEbM/AFoh4QWtswpzGCO/g9wL8Ep6NegM2gQP/uWICU7Utl0Lj1DncXomD7PUhFSXhtx8A=="}
dependencies:
'@types/json-bigint': 1.0.2
json-bigint: 1.0.0

View File

@@ -99,6 +99,10 @@ fi
echo "Removing old ezkl binary if it exists"
[ -e file ] && rm file
# echo platform and architecture
echo "Platform: $PLATFORM"
echo "Architecture: $ARCHITECTURE"
# download the release and unpack the right tarball
if [ "$PLATFORM" == "windows-msvc" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
@@ -126,7 +130,6 @@ elif [ "$PLATFORM" == "macos" ]; then
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
else
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
@@ -143,7 +146,7 @@ elif [ "$PLATFORM" == "macos" ]; then
fi
elif [ "$PLATFORM" == "linux" ]; then
if [ "${ARCHITECTURE}" = "amd64" ]; then
if [ "$ARCHITECTURE" == "amd64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
@@ -155,9 +158,20 @@ elif [ "$PLATFORM" == "linux" ]; then
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
elif [ "$ARCHITECTURE" == "aarch64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-aarch64.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
else
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
echo "Non aarch ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
else

View File

@@ -8,6 +8,7 @@ addopts = "-rfEX -p pytester --strict-markers"
testpaths = [
"tests/python/*_tests.py",
]
asyncio_mode = "auto"
[project]
name = "ezkl"

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-02-06"
channel = "nightly-2024-07-18"
components = ["rustfmt", "clippy"]

View File

@@ -1,7 +1,7 @@
// ignore file if compiling for wasm
#[cfg(not(target_arch = "wasm32"))]
use clap::Parser;
use clap::{CommandFactory, Parser};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(not(target_arch = "wasm32"))]
@@ -11,35 +11,49 @@ use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, error, info};
use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
use std::env;
#[cfg(not(target_arch = "wasm32"))]
use std::error::Error;
#[tokio::main(flavor = "current_thread")]
#[cfg(not(target_arch = "wasm32"))]
pub async fn main() -> Result<(), Box<dyn Error>> {
pub async fn main() {
let args = Cli::parse();
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
info!("Running with ICICLE GPU");
if let Some(generator) = args.generator {
ezkl::commands::print_completions(generator, &mut Cli::command());
} else if let Some(command) = args.command {
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
info!("Running with ICICLE GPU");
} else {
info!("Running with CPU");
}
info!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);
let res = run(command).await;
match &res {
Ok(_) => {
info!("succeeded");
}
Err(e) => {
error!("{}", e);
std::process::exit(1)
}
}
} else {
info!("Running with CPU");
init_logger();
error!("No command provided");
std::process::exit(1)
}
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
let res = run(args.command).await;
match &res {
Ok(_) => info!("succeeded"),
Err(e) => error!("failed: {}", e),
};
res.map(|_| ())
}
#[cfg(target_arch = "wasm32")]

View File

@@ -0,0 +1,25 @@
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum ModuleError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Wrong input type for a module
#[error("wrong input type {0} must be {1}")]
WrongInputType(String, String),
/// A constant was not previously assigned
#[error("constant was not previously assigned")]
ConstantNotAssigned,
/// Input length is wrong
#[error("input length is wrong {0}")]
InputWrongLength(usize),
}
impl From<ModuleError> for PlonkError {
fn from(_e: ModuleError) -> PlonkError {
PlonkError::Synthesis
}
}

View File

@@ -6,10 +6,11 @@ pub mod polycommit;
///
pub mod planner;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Error},
};
///
pub mod errors;
use halo2_proofs::{circuit::Layouter, plonk::ConstraintSystem};
use halo2curves::ff::PrimeField;
pub use planner::*;
@@ -35,14 +36,14 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Name
fn name(&self) -> &'static str;
/// Run the operation the module represents
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>>;
fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, errors::ModuleError>;
/// Layout inputs
fn layout_inputs(
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
) -> Result<Self::InputAssignments, errors::ModuleError>;
/// Layout
fn layout(
&self,
@@ -50,7 +51,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
) -> Result<ValTensor<F>, errors::ModuleError>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
/// Number of rows used by the module

View File

@@ -18,6 +18,7 @@ use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the PolyCommit hash function
@@ -110,7 +111,7 @@ impl Module<Fp> for PolyCommitChip {
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
Ok(())
}
@@ -123,28 +124,30 @@ impl Module<Fp> for PolyCommitChip {
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
layouter
.assign_region(
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
.map_err(|e| e.into())
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
Ok(vec![message])
}

View File

@@ -21,6 +21,7 @@ use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::errors::ModuleError;
use super::Module;
/// The number of instance columns used by the Poseidon hash function
@@ -174,7 +175,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
) -> Result<Self::InputAssignments, ModuleError> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
@@ -185,78 +186,82 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, Error> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
match value {
ValType::Value(v) => region.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);
Ok(res)
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
e
);
Err(Error::Synthesis)
}
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect()
}
};
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
let offset = message.len() / WIDTH + 1;
@@ -277,7 +282,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
message.len(),
start_time.elapsed()
);
res
res.map_err(|e| e.into())
}
/// L is the number of inputs to the hash function
@@ -289,7 +294,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
) -> Result<ValTensor<Fp>, ModuleError> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
@@ -301,7 +306,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, Error> = input_cells
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
@@ -332,7 +337,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
hash
})
.collect();
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
@@ -348,7 +354,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) => v,
_ => {
log::error!("wrong input type, must be previously assigned");
return Err(Error::Synthesis);
return Err(Error::Synthesis.into());
}
};
@@ -380,7 +386,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let mut hash_inputs = message;
let len = hash_inputs.len();
@@ -400,7 +406,11 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
block.extend(vec![Fp::ZERO; L - remainder].iter());
}
let message = block.try_into().map_err(|_| Error::Synthesis)?;
let block_len = block.len();
let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
@@ -411,7 +421,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, Error>>()?;
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}

View File

@@ -1,7 +1,5 @@
use std::str::FromStr;
use thiserror::Error;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
@@ -26,31 +24,11 @@ use crate::{
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, Op};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
/// circuit related errors.
#[derive(Debug, Error)]
pub enum CircuitError {
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
}
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -80,14 +58,18 @@ impl ToFlags for CheckMode {
impl From<String> for CheckMode {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
"safe" => CheckMode::SAFE,
"unsafe" => CheckMode::UNSAFE,
_ => {
log::error!("Invalid value for CheckMode");
log::warn!("defaulting to SAFE");
CheckMode::SAFE
}
Self::from_str(value.as_str()).unwrap()
}
}
impl FromStr for CheckMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"safe" => Ok(CheckMode::SAFE),
"unsafe" => Ok(CheckMode::UNSAFE),
_ => Err("Invalid value for CheckMode".to_string()),
}
}
}
@@ -509,18 +491,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
return Err(CircuitError::WrongColumnType(index.name().to_string()));
}
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
if !output.is_advice() {
return Err("wrong input type for lookup output".into());
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -650,19 +632,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
@@ -733,19 +715,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
@@ -818,12 +800,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
) -> Result<(), CircuitError>
where
F: Field,
{
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
return Err(CircuitError::WrongColumnType(input.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
@@ -914,7 +896,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
@@ -935,7 +917,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
@@ -955,7 +937,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)
}
}

src/circuit/ops/errors.rs (94 lines, new file)
View File

@@ -0,0 +1,94 @@
use std::convert::Infallible;
use crate::{fieldutils::IntegerRep, tensor::TensorError};
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum CircuitError {
/// Halo 2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] PlonkError),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] TensorError),
/// Shape mismatch in circuit construction
#[error("dimension mismatch in circuit construction for op: {0}")]
DimMismatch(String),
/// Error when instantiating lookup tables
#[error("failed to instantiate lookup tables")]
LookupInstantiation,
/// A lookup table was already assigned
#[error("attempting to initialize an already instantiated lookup table")]
TableAlreadyAssigned,
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
#[error("failed to flush, linear coord is not aligned with the next row")]
FlushError,
/// Constrain error
#[error("constrain_equal: one of the tensors is assigned and the other is not")]
ConstrainError,
/// Failed to get lookups
#[error("failed to get lookups for op: {0}")]
GetLookupsError(String),
/// Failed to get range checks
#[error("failed to get range checks for op: {0}")]
GetRangeChecksError(String),
/// Failed to get dynamic lookup
#[error("failed to get dynamic lookup for op: {0}")]
GetDynamicLookupError(String),
/// Failed to get shuffle
#[error("failed to get shuffle for op: {0}")]
GetShuffleError(String),
/// Failed to get constants
#[error("failed to get constants for op: {0}")]
GetConstantsError(String),
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Invalid min/max lookup range
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
InvalidMinMaxRange(IntegerRep, IntegerRep),
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
/// Mismatched lookup length
#[error("mismatched lookup lengths: {0} and {1}")]
MismatchedLookupLength(usize, usize),
/// Mismatched shuffle length
#[error("mismatched shuffle lengths: {0} and {1}")]
MismatchedShuffleLength(usize, usize),
/// Mismatched lookup table lengths
#[error("mismatched lookup table lengths: {0} and {1}")]
MismatchedLookupTableLength(usize, usize),
/// Wrong column type for lookup
#[error("wrong column type for lookup: {0}")]
WrongColumnType(String),
/// Wrong column type for dynamic lookup
#[error("wrong column type for dynamic lookup: {0}")]
WrongDynamicColumnType(String),
/// Missing selectors
#[error("missing selectors for op: {0}")]
MissingSelectors(String),
/// Table lookup error
#[error("value ({0}) out of range: ({1}, {2})")]
TableOOR(IntegerRep, IntegerRep, IntegerRep),
/// Lookup not configured
#[error("lookup not configured: {0}")]
LookupNotConfigured(String),
/// Range check not configured
#[error("range check not configured: {0}")]
RangeCheckNotConfigured(String),
/// Missing layout
#[error("missing layout for op: {0}")]
MissingLayout(String),
}

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::i128_to_felt,
fieldutils::integer_rep_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -71,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -155,7 +155,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::SumPool {
padding,
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
i128_to_felt(input_scale.0 as i128),
i128_to_felt(output_scale.0 as i128),
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
)?
} else {
layouts::nonlinearity(
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
integer_rep_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
@@ -287,7 +287,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }

File diff suppressed because it is too large

View File

@@ -1,10 +1,9 @@
use super::*;
use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -132,7 +131,7 @@ impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
let range = range as IntegerRep;
(-range, range)
}
@@ -141,7 +140,7 @@ impl LookupOp {
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
@@ -228,7 +227,7 @@ impl LookupOp {
}
}?;
let output = res.map(|x| i128_to_felt(x));
let output = res.map(|x| integer_rep_to_felt(x));
Ok(ForwardResult { output })
}
@@ -295,7 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,
region,
@@ -305,7 +304,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
}
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];

View File

@@ -1,4 +1,4 @@
use std::{any::Any, error::Error};
use std::any::Any;
use serde::{Deserialize, Serialize};
@@ -15,6 +15,8 @@ pub mod base;
///
pub mod chip;
///
pub mod errors;
///
pub mod hybrid;
/// Layouts for specific functions (composed of base ops)
pub mod layouts;
@@ -25,6 +27,8 @@ pub mod poly;
///
pub mod region;
pub use errors::CircuitError;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
@@ -44,10 +48,10 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>>;
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError>;
/// Do any of the inputs to this op require homogenous input scales?
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
@@ -122,8 +126,8 @@ impl InputType {
*input = T::from_f64(f64_input).unwrap();
}
InputType::Int | InputType::TDim => {
let int_input = input.clone().to_i128().unwrap();
*input = T::from_i128(int_input).unwrap();
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
}
}
@@ -139,7 +143,7 @@ pub struct Input {
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.scale)
}
@@ -156,7 +160,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
match self.datum_type {
@@ -194,7 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
@@ -209,8 +213,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
Err(Box::new(super::CircuitError::UnsupportedOp))
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
@@ -240,7 +244,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box<dyn Error>> {
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
@@ -278,7 +282,7 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
@@ -292,7 +296,7 @@ impl<
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.quantized_values.scale().unwrap())
}
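
The recurring change across these files swaps the opaque `Box<dyn Error>` returns for the concrete `CircuitError` exported from the new `circuit::errors` module. A minimal sketch of the pattern, assuming a thiserror-style enum; `UnsupportedOp` appears in the diff, but the `Tensor` variant name and the stand-in `TensorError` below are illustrative only:

    use thiserror::Error;

    // Stand-in for the crate's tensor error (illustrative, not the real type).
    #[derive(Debug, Error)]
    #[error("tensor dim error: {0}")]
    pub struct TensorError(pub String);

    // A sketch of a concrete error enum replacing `Box<dyn Error>`.
    #[derive(Debug, Error)]
    pub enum CircuitError {
        #[error("unsupported operation in graph")]
        UnsupportedOp,
        // `#[from]` keeps `?` and `.into()` working at existing call sites.
        #[error(transparent)]
        Tensor(#[from] TensorError),
    }

    fn pad_check(n_inputs: usize) -> Result<(), CircuitError> {
        if n_inputs != 1 {
            return Err(TensorError("Pad operation requires a single input".into()).into());
        }
        Ok(())
    }

    fn main() {
        assert!(pad_check(2).is_err());
        assert!(pad_check(1).is_ok());
    }

Typed errors make the `Err(TensorError::DimError(...).into())` rewrites mechanical, and `Err(super::CircuitError::UnsupportedOp)` no longer needs a heap allocation.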

View File

@@ -33,6 +33,7 @@ pub enum PolyOp {
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
},
Downsample {
axis: usize,
@@ -43,6 +44,7 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
},
Add,
Sub,
@@ -97,7 +99,8 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
+ for<'de> Deserialize<'de>
,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
@@ -147,17 +150,25 @@ impl<
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
PolyOp::Conv {
stride,
padding,
group,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
group,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -178,7 +189,7 @@ impl<
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
@@ -211,9 +222,18 @@ impl<
PolyOp::Prod { axes, .. } => {
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::Conv {
padding,
stride,
group,
} => layouts::conv(
config,
region,
values[..].try_into()?,
padding,
stride,
*group,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -260,6 +280,7 @@ impl<
padding,
output_padding,
stride,
group,
} => layouts::deconv(
config,
region,
@@ -267,6 +288,7 @@ impl<
padding,
output_padding,
stride,
*group,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -277,9 +299,10 @@ impl<
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
return Err(Box::new(TensorError::DimError(
return Err(TensorError::DimError(
"Pad operation requires a single input".to_string(),
)));
)
.into());
}
let mut input = values[0].clone();
input.pad(p.clone(), 0)?;
@@ -296,7 +319,7 @@ impl<
}))
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
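
`Conv` and `DeConv` now carry a `group` count fixed when the model is loaded rather than inferred later. A hedged sketch of the shape constraint grouping implies, assuming standard grouped-convolution semantics; the helper below is illustrative, not crate code:

    // Grouped convolution splits the channel dimension into `group`
    // independent slices: each group convolves c_in/group input channels
    // into c_out/group output channels.
    fn check_grouped_conv_dims(c_in: usize, c_out: usize, group: usize) -> Result<(), String> {
        if group == 0 || c_in % group != 0 || c_out % group != 0 {
            return Err(format!(
                "{c_in} in / {c_out} out channels do not divide evenly into {group} groups"
            ));
        }
        Ok(())
    }

    fn main() {
        // group = 1 (the value the updated tests pass) is ordinary convolution.
        assert!(check_grouped_conv_dims(8, 16, 1).is_ok());
        // one group per input channel is depthwise-style.
        assert!(check_grouped_conv_dims(8, 8, 8).is_ok());
        assert!(check_grouped_conv_dims(8, 16, 3).is_err());
    }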

View File

@@ -1,6 +1,7 @@
use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
fieldutils::IntegerRep,
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -9,7 +10,8 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
@@ -19,7 +21,7 @@ use std::{
},
};
use super::lookup::LookupOp;
use super::{lookup::LookupOp, CircuitError};
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
@@ -84,43 +86,77 @@ impl ShuffleIndex {
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
/// wrap other regions
#[error("Wrapped region: {0}")]
Wrapped(String),
#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
/// whether we are in witness generation mode
pub witness_gen: bool,
/// whether we should check range checks for validity
pub check_range: bool,
}
impl From<String> for RegionError {
fn from(e: String) -> Self {
Self::Wrapped(e)
#[allow(unsafe_code)]
unsafe impl Sync for RegionSettings {}
#[allow(unsafe_code)]
unsafe impl Send for RegionSettings {}
impl RegionSettings {
/// Create a new region settings
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
RegionSettings {
witness_gen,
check_range,
}
}
/// Create a new region settings with all true
pub fn all_true() -> RegionSettings {
RegionSettings {
witness_gen: true,
check_range: true,
}
}
/// Create a new region settings with all false
pub fn all_false() -> RegionSettings {
RegionSettings {
witness_gen: false,
check_range: false,
}
}
}
impl From<&str> for RegionError {
fn from(e: &str) -> Self {
Self::Wrapped(e.to_string())
#[derive(Debug, Default, Clone)]
/// Region statistics
pub struct RegionStatistics {
/// the current maximum value of the lookup inputs
pub max_lookup_inputs: IntegerRep,
/// the current minimum value of the lookup inputs
pub min_lookup_inputs: IntegerRep,
/// the current maximum value of the range size
pub max_range_size: IntegerRep,
/// the current set of used lookups
pub used_lookups: HashSet<LookupOp>,
/// the current set of used range checks
pub used_range_checks: HashSet<Range>,
}
impl RegionStatistics {
/// update the statistics with another set of statistics
pub fn update(&mut self, other: &RegionStatistics) {
self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
self.max_range_size = self.max_range_size.max(other.max_range_size);
self.used_lookups.extend(other.used_lookups.clone());
self.used_range_checks
.extend(other.used_range_checks.clone());
}
}
impl From<TensorError> for RegionError {
fn from(e: TensorError) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Error> for RegionError {
fn from(e: Error) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
impl From<Box<dyn std::error::Error>> for RegionError {
fn from(e: Box<dyn std::error::Error>) -> Self {
Self::Wrapped(format!("{:?}", e))
}
}
#[allow(unsafe_code)]
unsafe impl Sync for RegionStatistics {}
#[allow(unsafe_code)]
unsafe impl Send for RegionStatistics {}
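
The per-region bookkeeping that used to live as loose fields on `RegionCtx` is folded into two structs: `RegionSettings` (phase flags) and `RegionStatistics` (lookup and range-check accounting). The payoff is that statistics merge component-wise, which is what lets the parallel dummy loop further down combine per-thread locals. A reduced sketch, with `IntegerRep` assumed to be the crate's post-migration i64 alias and a `String` standing in for `LookupOp`:

    use std::collections::HashSet;

    type IntegerRep = i64; // crate alias after the i128 -> i64 swap

    #[derive(Debug, Default, Clone)]
    struct RegionStatistics {
        max_lookup_inputs: IntegerRep,
        min_lookup_inputs: IntegerRep,
        used_lookups: HashSet<String>, // LookupOp stand-in for the sketch
    }

    impl RegionStatistics {
        // Merging is order-independent: max of maxes, min of mins, set union.
        fn update(&mut self, other: &RegionStatistics) {
            self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
            self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
            self.used_lookups.extend(other.used_lookups.iter().cloned());
        }
    }

    fn main() {
        let mut a = RegionStatistics { max_lookup_inputs: 5, ..Default::default() };
        let b = RegionStatistics { min_lookup_inputs: -9, max_lookup_inputs: 7, ..Default::default() };
        a.update(&b);
        assert_eq!((a.min_lookup_inputs, a.max_lookup_inputs), (-9, 7));
    }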
#[derive(Debug)]
/// A context for a region
@@ -131,12 +167,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
witness_gen: bool,
statistics: RegionStatistics,
settings: RegionSettings,
assigned_constants: ConstantsMap<F>,
}
@@ -188,7 +220,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn witness_gen(&self) -> bool {
self.witness_gen
self.settings.witness_gen
}
///
pub fn check_range(&self) -> bool {
self.settings.check_range
}
///
pub fn statistics(&self) -> &RegionStatistics {
&self.statistics
}
/// Create a new region context
@@ -203,12 +245,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: true,
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(),
assigned_constants: HashMap::new(),
}
}
@@ -224,34 +262,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
region,
num_inner_cols,
linear_coord,
row,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
settings: RegionSettings,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -262,12 +279,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
}
}
@@ -277,7 +290,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row: usize,
linear_coord: usize,
num_inner_cols: usize,
witness_gen: bool,
settings: RegionSettings,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -287,12 +300,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
}
}
@@ -301,10 +310,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn apply_in_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
if self.is_dummy() {
self.dummy_loop(output, inner_loop_function)?;
} else {
@@ -317,8 +326,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn real_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>,
) -> Result<(), RegionError> {
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>,
) -> Result<(), CircuitError> {
output
.iter_mut()
.enumerate()
@@ -326,7 +335,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
*o = inner_loop_function(i, self)?;
Ok(())
})
.collect::<Result<Vec<_>, RegionError>>()?;
.collect::<Result<Vec<_>, CircuitError>>()?;
Ok(())
}
@@ -337,114 +346,71 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
) -> Result<(), RegionError> {
) -> Result<(), CircuitError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
*output = output.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.witness_gen,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.settings.clone(),
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
linear_coord.fetch_add(
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
// update the lookups
let mut statistics = statistics.lock().unwrap();
statistics.update(local_reg.statistics());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
res
})?;
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
self.statistics = Arc::try_unwrap(statistics)
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
})?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?;
Ok(())
}
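
`dummy_loop` now hands each parallel worker its own local `RegionCtx`, merges the locals back through a handful of `Arc<Mutex<…>>` accumulators, and recovers ownership with `Arc::try_unwrap` once the pool drains, with each failure mapped to a dedicated `CircuitError` variant instead of a formatted `RegionError` string. A reduced sketch of that shape, using rayon directly where the crate goes through its maybe_rayon wrapper:

    use rayon::prelude::*;
    use std::sync::{Arc, Mutex};

    fn main() {
        // One shared accumulator; workers merge local results under the lock.
        let stats = Arc::new(Mutex::new(Vec::<i64>::new()));

        (0..8i64).into_par_iter().for_each(|i| {
            let local = i * i; // per-thread work on a local value
            stats.lock().unwrap().push(local);
        });

        // After the pool finishes there is exactly one strong reference,
        // so try_unwrap + into_inner hands ownership back without a clone.
        let mut merged = Arc::try_unwrap(stats)
            .expect("no outstanding references")
            .into_inner()
            .expect("lock not poisoned");
        merged.sort();
        assert_eq!(merged.len(), 8);
    }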
@@ -453,29 +419,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
return Err(CircuitError::InvalidMinMaxRange(range.0, range.1));
}
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
Ok(())
}
@@ -489,14 +452,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
self.used_lookups.insert(lookup);
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
self.used_range_checks.insert(range);
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
self.statistics.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
@@ -537,27 +500,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
self.statistics.used_lookups.clone()
}
/// get used range checks
pub fn used_range_checks(&self) -> HashSet<Range> {
self.used_range_checks.clone()
self.statistics.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
self.max_lookup_inputs
pub fn max_lookup_inputs(&self) -> IntegerRep {
self.statistics.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
self.min_lookup_inputs
pub fn min_lookup_inputs(&self) -> IntegerRep {
self.statistics.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i128 {
self.max_range_size
pub fn max_range_size(&self) -> IntegerRep {
self.statistics.max_range_size
}
/// Assign a valtensor to a vartensor
@@ -565,18 +528,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -592,18 +555,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -614,7 +577,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
self.assign_dynamic_lookup(var, values)
}
@@ -623,27 +586,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign_with_omissions(
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
)?)
} else {
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
@@ -690,7 +650,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// constrain equal
pub fn constrain_equal(&mut self, a: &ValTensor<F>, b: &ValTensor<F>) -> Result<(), Error> {
pub fn constrain_equal(
&mut self,
a: &ValTensor<F>,
b: &ValTensor<F>,
) -> Result<(), CircuitError> {
if let Some(region) = &self.region {
let a = a.get_inner_tensor().unwrap();
let b = b.get_inner_tensor().unwrap();
@@ -700,12 +664,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let b = b.get_prev_assigned();
// if they're both assigned, we can constrain them
if let (Some(a), Some(b)) = (&a, &b) {
region.borrow_mut().constrain_equal(a.cell(), b.cell())
region
.borrow_mut()
.constrain_equal(a.cell(), b.cell())
.map_err(|e| e.into())
} else if a.is_some() || b.is_some() {
log::error!(
"constrain_equal: one of the tensors is assigned and the other is not"
);
return Err(Error::Synthesis);
return Err(CircuitError::ConstrainError);
} else {
Ok(())
}
@@ -731,7 +695,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// flush row to the next row
pub fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
pub fn flush(&mut self) -> Result<(), CircuitError> {
// increment by the difference between the current linear coord and the next row
let remainder = self.linear_coord % self.num_inner_cols;
if remainder != 0 {
@@ -739,7 +703,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.increment(diff);
}
if self.linear_coord % self.num_inner_cols != 0 {
return Err("flush: linear coord is not aligned with the next row".into());
return Err(CircuitError::FlushError);
}
Ok(())
}
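
`flush` pads the linear coordinate forward to the next row boundary and now reports misalignment as a typed `CircuitError::FlushError` rather than a string. The alignment arithmetic, spelled out as a sketch (not the crate's code):

    // Round `linear_coord` up to the next multiple of `num_inner_cols`.
    fn flushed(linear_coord: usize, num_inner_cols: usize) -> usize {
        let remainder = linear_coord % num_inner_cols;
        if remainder != 0 {
            // increment by the difference to the next row, as above
            linear_coord + (num_inner_cols - remainder)
        } else {
            linear_coord
        }
    }

    fn main() {
        assert_eq!(flushed(10, 4), 12); // mid-row: advance to the boundary
        assert_eq!(flushed(12, 4), 12); // already aligned: no-op
    }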

View File

@@ -1,4 +1,4 @@
use std::{error::Error, marker::PhantomData};
use std::marker::PhantomData;
use halo2curves::ff::PrimeField;
@@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
circuit::CircuitError,
fieldutils::i128_to_felt,
fieldutils::{integer_rep_to_felt, IntegerRep},
tensor::{Tensor, TensorType},
};
use crate::circuit::lookup::LookupOp;
/// The range of the lookup table.
pub type Range = (i128, i128);
pub type Range = (IntegerRep, IntegerRep);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: IntegerRep = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
@@ -100,17 +100,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
i128_to_felt(chunk)
integer_rep_to_felt(chunk)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
let chunk = chunk as i128;
let chunk = chunk as IntegerRep;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let first_element =
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
@@ -130,9 +131,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}
///
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
(range_len / (col_size as IntegerRep)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
@@ -194,15 +195,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
&mut self,
layouter: &mut impl Layouter<F>,
preassigned_input: bool,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
@@ -275,9 +276,9 @@ pub struct RangeCheck<F: PrimeField> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
let chunk = chunk as IntegerRep;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}
///
@@ -293,10 +294,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
i128_to_felt(chunk)
integer_rep_to_felt(chunk)
}
}
@@ -342,15 +343,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
return Err(CircuitError::TableAlreadyAssigned);
}
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
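
With `Range` now `(IntegerRep, IntegerRep)`, the table arithmetic is unchanged aside from the narrower type: offset the input by the range start, then divide by the column size. A plain-integer sketch of `get_col_index` and `num_cols_required`, leaving the field-element conversions (`felt_to_integer_rep` / `integer_rep_to_felt`) out:

    type IntegerRep = i64; // crate alias after the i128 -> i64 migration

    // Which column chunk of the lookup table does `input` land in?
    fn col_index(input: IntegerRep, range_start: IntegerRep, col_size: usize) -> IntegerRep {
        (input - range_start).abs() / (col_size as IntegerRep)
    }

    // How many columns are needed to hold the whole range?
    fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
        (range_len / (col_size as IntegerRep)) as usize + 1
    }

    fn main() {
        // range starting at -8 with col_size 4: -8..=-5 land in col 0, -4..=-1 in col 1.
        assert_eq!(col_index(-8, -8, 4), 0);
        assert_eq!(col_index(-3, -8, 4), 1);
        assert_eq!(num_cols_required(16, 4), 5);
    }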

View File

@@ -1050,6 +1050,7 @@ mod conv {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1,6 +1,7 @@
use clap::{Parser, Subcommand};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
@@ -10,7 +11,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
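
The CLI's Ethereum types move from ethers' `H160` to alloy's `Address`, aliased so downstream code keeps compiling. A tiny parsing sketch against the alloy `primitives` module imported above:

    use alloy::primitives::Address;
    use std::str::FromStr;

    fn main() {
        // lowercase or EIP-55 checksummed hex both parse
        let addr = Address::from_str("0x00000000219ab540356cbb839cbe05303d7705fa")
            .expect("valid 20-byte hex address");
        assert_eq!(addr.to_string().len(), 42); // "0x" + 40 hex chars
    }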
@@ -52,6 +53,8 @@ pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
pub const DEFAULT_VERIFIER_DA_ABI: &str = "verifier_da_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
/// Default calldata path
pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default solidity code for data attestation
@@ -78,7 +81,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk seperately
/// Default render vk separately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
@@ -253,33 +256,66 @@ lazy_static! {
};
}
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone, Deserialize, Serialize)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION)]
pub struct Cli {
#[command(subcommand)]
#[allow(missing_docs)]
pub command: Commands,
/// Get the styles for the CLI
pub fn get_styles() -> clap::builder::Styles {
clap::builder::Styles::styled()
.usage(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.header(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.literal(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
)
.invalid(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.error(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.valid(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
)
.placeholder(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
)
}
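
`get_styles` is applied once on the top-level parser (see `styles = get_styles()` in the `#[clap(...)]` attribute below), coloring clap's generated help. A reduced, runnable sketch of the same builder pattern, assuming clap 4 with its default color feature:

    use clap::builder::styling::{AnsiColor, Color, Style, Styles};

    fn styles() -> Styles {
        Styles::styled()
            .header(Style::new().bold().underline()
                .fg_color(Some(Color::Ansi(AnsiColor::Cyan))))
            .literal(Style::new().fg_color(Some(Color::Ansi(AnsiColor::Magenta))))
            .error(Style::new().bold().fg_color(Some(Color::Ansi(AnsiColor::Red))))
    }

    fn main() {
        // `--help` output now renders headers cyan, literals magenta, errors red.
        let cmd = clap::Command::new("demo").styles(styles());
        cmd.debug_assert();
    }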
impl Cli {
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
}
/// Parse an ezkl configuration from a json
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
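
A hedged sketch of how the new `--generate` flag and `print_completions` fit together in a main function, assuming `clap::CommandFactory` provides `Cli::command()`:

    use clap::{CommandFactory, Parser};
    use clap_complete::{generate, Generator, Shell};

    #[derive(Parser)]
    struct Cli {
        /// If provided, outputs the completion file for the given shell
        #[clap(long = "generate", value_parser)]
        generator: Option<Shell>,
    }

    fn print_completions<G: Generator>(gen: G, cmd: &mut clap::Command) {
        generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
    }

    fn main() {
        let cli = Cli::parse();
        if let Some(shell) = cli.generator {
            // e.g. `ezkl --generate bash > ezkl.bash`
            let mut cmd = Cli::command();
            print_completions(shell, &mut cmd);
        }
    }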
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
pub struct Cli {
/// If provided, outputs the completion file for given shell
#[clap(long = "generate", value_parser)]
pub generator: Option<Shell>,
#[command(subcommand)]
#[allow(missing_docs)]
pub command: Option<Commands>,
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -289,8 +325,8 @@ pub enum Commands {
/// Loads model and prints model table
Table {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
@@ -299,30 +335,30 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// Path to output the witness .json file
#[arg(short = 'O', long, default_value = DEFAULT_WITNESS)]
output: PathBuf,
#[arg(short = 'O', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
output: Option<PathBuf>,
/// Path to the verification key file (optional - solely used to generate kzg commits)
#[arg(short = 'V', long)]
#[arg(short = 'V', long, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// Path to the srs file (optional - solely used to generate kzg commits)
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
},
/// Produces the proving hyperparameters, from run-args
GenSettings {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to generate the circuit settings .json file to
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
@@ -332,51 +368,52 @@ pub enum Commands {
#[cfg(not(target_arch = "wasm32"))]
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to load circuit settings .json file AND overwrite (generated using the gen-settings command).
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)]
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
/// Target for calibration. Set to "resources" to optimize for computational resources. Otherwise, set to "accuracy" to optimize for accuracy.
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
lookup_safety_margin: i128,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be ceil(2^k * lookup_safety_margin). larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
lookup_safety_margin: f64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
scales: Option<Vec<crate::Scale>>,
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
#[arg(
long,
value_delimiter = ',',
allow_hyphen_values = true,
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS,
value_hint = clap::ValueHint::Other
)]
scale_rebase_multiplier: Vec<u32>,
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
only_range_check_rebase: Option<bool>,
},
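
`lookup_safety_margin` above widens from `i128` to `f64`, and the doc string changes with it: the margined bound is now `ceil(2^k * margin)`, so fractional safety margins become expressible. The arithmetic, as a sketch:

    // New semantics: fractional margins allowed, rounded up.
    fn margined_max(max_lookup: i64, lookup_safety_margin: f64) -> i64 {
        (max_lookup as f64 * lookup_safety_margin).ceil() as i64
    }

    fn main() {
        assert_eq!(margined_max(1 << 10, 2.0), 2048); // old integer behaviour preserved
        assert_eq!(margined_max(1 << 10, 1.5), 1536); // newly expressible
    }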
/// Generates a dummy SRS
#[command(name = "gen-srs", arg_required_else_help = true)]
GenSrs {
/// The path to output the generated SRS
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: PathBuf,
/// number of logrows to use for srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
logrows: usize,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -384,79 +421,79 @@ pub enum Commands {
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
#[arg(long)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Path to the circuit settings .json file to read in logrows from. Overridden by logrows if specified.
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Commitment used
#[arg(long, default_value = None)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
},
/// Mock aggregate proofs
MockAggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated proofs are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
},
/// setup aggregation circuit :)
/// Setup aggregation circuit and generate pk and vk
SetupAggregate {
/// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
sample_snarks: Vec<PathBuf>,
/// The path to save the desired verification key file to
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated proofs are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
/// Aggregates proofs
Aggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// The path to load the desired proving key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -465,78 +502,79 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum
value_enum,
value_hint = clap::ValueHint::Other
)]
transcript: TranscriptType,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check_mode: CheckMode,
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
/// whether the accumulated proofs are segments of a larger circuit
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
},
/// Creates pk and vk
Setup {
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to output the verification key file to
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the proving key file to
#[arg(long, default_value = DEFAULT_PK)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The graph witness (optional - used to override fixed values in the circuit)
#[arg(short = 'W', long)]
#[arg(short = 'W', long, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information
/// derived from the file information in the data .json file.
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long)]
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// where the input data comes from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
/// where the output data comes from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -544,121 +582,139 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the witness file
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness_path: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to load the desired proving key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_PK)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check_mode: CheckMode,
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to save the calldata to
#[arg(long, default_value = DEFAULT_CALLDATA, value_hint = clap::ValueHint::FilePath)]
calldata_path: Option<PathBuf>,
/// The path to the verification key address (only used if the vk is rendered as a separate contract)
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// The path to the .json data file, which should
/// contain the necessary calldata and account addresses
/// needed to read from all the on-chain
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -666,104 +722,104 @@ pub enum Commands {
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_settings: Vec<PathBuf>,
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
},
/// Verifies a proof, returning accept or reject
Verify {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVerifier {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -771,25 +827,25 @@ pub enum Commands {
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution)
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -797,19 +853,39 @@ pub enum Commands {
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// does the verifier use data attestation?
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,
// is the vk rendered separately; if so, specify an address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(feature = "no-update"))]
/// Updates ezkl binary to version specified (or latest if not specified)
Update {
/// The version to update to
#[arg(value_hint = clap::ValueHint::Other, short='v', long)]
version: Option<String>,
},
}
impl Commands {
/// Converts the commands to a json string
pub fn as_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
/// Converts a json string to a Commands struct
pub fn from_json(json: &str) -> Self {
serde_json::from_str(json).unwrap()
}
}
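A minimal round-trip sketch of the two helpers above, assuming `Commands` derives serde's `Serialize`/`Deserialize` (which the `serde_json` calls imply):
// Serialize a command to JSON and parse it back; both helpers unwrap
// internally, so malformed JSON panics rather than returning an error.
fn roundtrip(cmd: &Commands) -> Commands {
    let json = cmd.as_json();
    Commands::from_json(&json)
}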

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

src/fieldutils.rs

@@ -2,42 +2,21 @@ use halo2_proofs::arithmetic::Field;
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
use halo2curves::ff::PrimeField;
/// Converts an i32 to a PrimeField element.
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
if x >= 0 {
F::from(x as u64)
} else {
-F::from(x.unsigned_abs() as u64)
}
}
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an i128 to a PrimeField element.
pub fn i128_to_felt<F: PrimeField>(x: i128) -> F {
/// Converts an IntegerRep to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
-F::from_u128((-x) as u128)
}
}
/// Converts a PrimeField element to an i32.
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
if x > F::from(i32::MAX as u64) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
-(lower_32 as i32)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
lower_32 as i32
-F::from_u128(x.saturating_neg() as u128)
}
}
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
if x > F::from_u128(i128::MAX as u128) {
if x > F::from_u128(IntegerRep::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -50,18 +29,18 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}
}
/// Converts a PrimeField element to an i128.
pub fn felt_to_i128<F: PrimeField + PartialOrd + Field>(x: F) -> i128 {
if x > F::from_u128(i128::MAX as u128) {
/// Converts a PrimeField element to an IntegerRep.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
-(lower_128 as i128)
-(lower_128 as IntegerRep)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
lower_128 as i128
lower_128 as IntegerRep
}
}
@@ -73,33 +52,24 @@ mod test {
#[test]
fn test_conv() {
let res: F = i32_to_felt(-15i32);
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
let res: F = i32_to_felt(2_i32.pow(17));
let res: F = integer_rep_to_felt(2_i128.pow(17));
assert_eq!(res, F::from(131072));
let res: F = i128_to_felt(-15i128);
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
let res: F = i128_to_felt(2_i128.pow(17));
let res: F = integer_rep_to_felt(2_i128.pow(17));
assert_eq!(res, F::from(131072));
}
#[test]
fn felttoi32() {
for x in -(2i32.pow(16))..(2i32.pow(16)) {
let fieldx: F = i32_to_felt::<F>(x);
let xf: i32 = felt_to_i32::<F>(fieldx);
assert_eq!(x, xf);
}
}
#[test]
fn felttoi128() {
for x in -(2i128.pow(20))..(2i128.pow(20)) {
let fieldx: F = i128_to_felt::<F>(x);
let xf: i128 = felt_to_i128::<F>(fieldx);
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
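A test-style sketch of the signed encoding above (assuming the same imports as the test module): a negative IntegerRep x is stored as the field negation of |x|, i.e. p - |x| for modulus p, which is why felt_to_integer_rep treats any element larger than IntegerRep::MAX as negative.
#[test]
fn signed_roundtrip_sketch() {
    use halo2curves::bn256::Fr as F;
    // -15 is encoded as the field negation of 15, i.e. p - 15 ...
    let felt = integer_rep_to_felt::<F>(-15);
    assert_eq!(felt, -F::from(15));
    // ... and decodes back to -15 via the "greater than MAX" check.
    assert_eq!(felt_to_integer_rep::<F>(felt), -15);
}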

src/graph/errors.rs (new file, 137 lines)

@@ -0,0 +1,137 @@
use std::convert::Infallible;
use thiserror::Error;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// This datatype is unsupported
#[error("unsupported datatype in graph node {0} ({1})")]
UnsupportedDataType(usize, String),
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has misformed parameters
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Reading a file failed
#[error("[io] ({0}) {1}")]
ReadWriteFileError(String, String),
/// Model serialization error
#[error("failed to ser/deser model: {0}")]
ModelSerialize(#[from] bincode::Error),
/// Tract error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tract] {0}")]
TractError(#[from] tract_onnx::prelude::TractError),
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
/// Public visibility for params is deprecated
#[error("public visibility for params is deprecated, please use `fixed` instead")]
ParamsPublicVisibility,
/// Slice length mismatch
#[error("slice length mismatch: {0}")]
SliceLengthMismatch(#[from] std::array::TryFromSliceError),
/// Bad conversion
#[error("invalid conversion: {0}")]
InvalidConversion(#[from] Infallible),
/// Circuit error
#[error("[circuit] {0}")]
CircuitError(#[from] crate::circuit::CircuitError),
/// Halo2 error
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
/// System time error
#[error("[system time] {0}")]
SystemTimeError(#[from] std::time::SystemTimeError),
/// Missing Batch Size
#[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
MissingBatchSize,
/// Tokio postgres error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[eth] {0}")]
EthError(#[from] crate::eth::EthError),
/// Json error
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
/// Missing instances
#[error("missing instances")]
MissingInstances,
/// Missing constants
#[error("missing constants")]
MissingConstants,
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
/// Range ops only support constant inputs in a zk circuit
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
/// Trilu only supports constant diagonals in a zk circuit
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
/// Not enough witness values to generate a fixed output
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
#[error("missing scale")]
MissingScale,
/// Extended k is too large
#[error("extended k is too large to accommodate the quotient polynomial with logrows {0}")]
ExtendedKTooLarge(u32),
/// Max lookup input is too large
#[error("lookup range {0} is too large")]
LookupRangeTooLarge(usize),
/// Max range check input is too large
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
/// Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]
MissingDataSource,
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
}
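Most variants above carry a #[from] conversion, so ? promotes the source error into GraphError without an explicit map_err; a minimal sketch (hypothetical helper, using only the JsonError conversion declared above):
fn parse_json_value(s: &str) -> Result<serde_json::Value, GraphError> {
    // serde_json::Error -> GraphError::JsonError via #[from]
    let value = serde_json::from_str(s)?;
    Ok(value)
}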

src/graph/input.rs

@@ -1,13 +1,13 @@
use super::errors::GraphError;
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::postgres::Client;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
@@ -128,7 +128,7 @@ impl FileSourceInner {
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
@@ -150,7 +150,7 @@ impl FileSourceInner {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64,
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
@@ -211,7 +211,7 @@ impl PostgresSource {
}
/// Fetch data from postgres
pub fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
@@ -232,10 +232,10 @@ impl PostgresSource {
)
};
let mut client = Client::connect(&config, NoTls)?;
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[])? {
for row in client.query(&query, &[]).await? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
@@ -245,11 +245,10 @@ impl PostgresSource {
}
/// Fetch data from postgres and format it as a FileSource
pub fn fetch_and_format_as_file(
&self,
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()?
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
@@ -276,14 +275,14 @@ impl OnChainSource {
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data};
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
use crate::eth::{
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
};
use log::debug;
// Set up local anvil instance for reading on-chain data
let (anvil, client) = crate::eth::setup_eth_backend(rpc, None).await?;
let address = client.address();
let (client, client_address) = crate::eth::setup_eth_backend(rpc, None).await?;
let mut scales = scales;
// set scales to 1 where data is a field element
@@ -296,7 +295,8 @@ impl OnChainSource {
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
debug!("Calls to accounts: {:?}", calls_to_accounts);
let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?;
let inputs =
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
debug!("Inputs: {:?}", inputs);
let mut quantized_evm_inputs = vec![];
@@ -325,7 +325,7 @@ impl OnChainSource {
inputs.push(t);
}
let used_rpc = rpc.unwrap_or(&anvil.endpoint()).to_string();
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
// Fill the input_data field of the GraphData struct
Ok((
@@ -451,7 +451,7 @@ impl GraphData {
&self,
shapes: &[Vec<usize>],
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
) -> Result<TVec<TValue>, GraphError> {
let mut inputs = TVec::new();
match &self.input_data {
DataSource::File(data) => {
@@ -466,10 +466,10 @@ impl GraphData {
}
}
_ => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
)))
))
}
}
Ok(inputs)
@@ -484,28 +484,35 @@ impl GraphData {
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let reader = std::fs::File::open(path)?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
reader.read_to_string(&mut buf).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, self)?;
Ok(())
}
/// Split the input data into batches matching the given input shapes
pub fn split_into_batches(
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
) -> Result<Vec<Self>, GraphError> {
// split input data into batches
let mut batched_inputs = vec![];
@@ -518,16 +525,16 @@ impl GraphData {
input_data: DataSource::OnChain(_),
output_data: _,
} => {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
)))
))
}
#[cfg(not(target_arch = "wasm32"))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file()?,
} => data.fetch_and_format_as_file().await?,
};
for (i, shape) in input_shapes.iter().enumerate() {
@@ -535,11 +542,11 @@ impl GraphData {
let input_size = shape.clone().iter().product::<usize>();
let input = &iterable[i];
if input.len() % input_size != 0 {
return Err(Box::new(GraphError::InvalidDims(
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
)));
));
}
let mut batches = vec![];
for batch in input.chunks(input_size) {

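Since fetch and fetch_and_format_as_file are now async, callers await the Postgres path directly instead of spinning up a runtime; a hedged sketch (hypothetical helper, assuming an ambient tokio runtime):
async fn load_db_inputs(
    pg: &PostgresSource,
) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
    pg.fetch_and_format_as_file().await
}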
View File

@@ -6,10 +6,17 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(not(target_arch = "wasm32"))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
pub mod vars;
/// errors for the graph
pub mod errors;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
@@ -20,16 +27,17 @@ pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
use self::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
@@ -54,7 +62,6 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;
@@ -62,13 +69,14 @@ pub use vars::*;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: IntegerRep = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: IntegerRep =
(MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -84,62 +92,6 @@ lazy_static! {
#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;
/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// The wrong inputs were passed to a lookup node
#[error("invalid inputs for a lookup node")]
InvalidLookupInputs,
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
/// A requested node is missing in the graph
#[error("a requested node is missing in the graph: {0}")]
MissingNode(usize),
/// The wrong method was called on an operation
#[error("an unsupported method was called on node {0} ({1})")]
OpMismatch(usize, String),
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// This datatype is unsupported
#[error("unsupported datatype in graph")]
UnsupportedDataType,
/// A node has missing parameters
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has misformed parameters
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
Visibility,
/// Ezkl only supports divisions by constants
#[error("ezkl currently only supports division by constants")]
NonConstantDiv,
/// Ezkl only supports constant powers
#[error("ezkl currently only supports constant exponents")]
NonConstantPower,
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Error when attempting to load a model
#[error("failed to load")]
ModelLoad,
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,
/// Invalid Input Types
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
}
/// Assumed number of blinding factors
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
@@ -175,11 +127,11 @@ pub struct GraphWitness {
/// Any hashes of outputs generated during the forward pass
pub processed_outputs: Option<ModuleForwardResult>,
/// max lookup input
pub max_lookup_inputs: i128,
pub max_lookup_inputs: IntegerRep,
/// min lookup input
pub min_lookup_inputs: i128,
pub min_lookup_inputs: IntegerRep,
/// max range check size
pub max_range_size: i128,
pub max_range_size: IntegerRep,
}
impl GraphWitness {
@@ -306,30 +258,31 @@ impl GraphWitness {
}
/// Export the ezkl witness as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load {}", path.display()))?;
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let file = std::fs::File::open(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
@@ -591,11 +544,11 @@ impl GraphSettings {
}
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
return Err(e.into());
}
};
Ok(serialized)
@@ -691,17 +644,21 @@ impl GraphCircuit {
&self.core.model
}
/// Serialize the circuit to a file (bincode)
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
/// Deserialize a circuit from a file (bincode)
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let f = std::fs::File::open(path)?;
let f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
@@ -766,10 +723,7 @@ pub struct TestOnChainData {
impl GraphCircuit {
/// Create a new circuit from a model and run args
pub fn new(
model: Model,
run_args: &RunArgs,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
pub fn new(model: Model, run_args: &RunArgs) -> Result<GraphCircuit, GraphError> {
// // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -816,7 +770,7 @@ impl GraphCircuit {
model: Model,
mut settings: GraphSettings,
check_mode: CheckMode,
) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
) -> Result<GraphCircuit, GraphError> {
// placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
let mut inputs: Vec<Vec<Fp>> = vec![];
for shape in model.graph.input_shapes()? {
@@ -840,35 +794,37 @@ impl GraphCircuit {
}
/// load inputs and outputs for the model
pub fn load_graph_witness(
&mut self,
data: &GraphWitness,
) -> Result<(), Box<dyn std::error::Error>> {
pub fn load_graph_witness(&mut self, data: &GraphWitness) -> Result<(), GraphError> {
self.graph_witness = data.clone();
// load the module settings
Ok(())
}
/// Prepare the public inputs for the circuit.
pub fn prepare_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Vec<Fp>, Box<dyn std::error::Error>> {
pub fn prepare_public_inputs(&self, data: &GraphWitness) -> Result<Vec<Fp>, GraphError> {
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
let mut public_inputs: Vec<Fp> = vec![];
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
} else if let Some(processed_inputs) = &data.processed_inputs {
// we first process the inputs
if let Some(processed_inputs) = &data.processed_inputs {
public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
}
// we then process the params
if let Some(processed_params) = &data.processed_params {
public_inputs.extend(processed_params.get_instances().into_iter().flatten());
}
// if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
}
// if the outputs are public, we add them to the public inputs
if self.settings().run_args.output_visibility.is_public() {
public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
// if the outputs are processed, we add the processed outputs to the public inputs
} else if let Some(processed_outputs) = &data.processed_outputs {
public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
}
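// Resulting instance ordering (my summary of the branches above):
//   [ processed_inputs | processed_params | public inputs | public or processed outputs ]
// Public inputs now land AFTER the processed params, matching the order in
// which the Column<Instance> regions are configured.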
@@ -886,7 +842,7 @@ impl GraphCircuit {
pub fn pretty_public_inputs(
&self,
data: &GraphWitness,
) -> Result<Option<PrettyElements>, Box<dyn std::error::Error>> {
) -> Result<Option<PrettyElements>, GraphError> {
// dequantize the supplied data using the provided scale.
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
@@ -928,10 +884,7 @@ impl GraphCircuit {
/// Load the graph input from a data source
#[cfg(target_arch = "wasm32")]
pub fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -942,7 +895,7 @@ impl GraphCircuit {
pub fn load_graph_from_file_exclusively(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
@@ -952,22 +905,23 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => Err("Cannot use non-file data source as input for this method.".into()),
_ => unreachable!("cannot load from on-chain data"),
}
}
/// Load the graph input from a data source
#[cfg(not(target_arch = "wasm32"))]
pub fn load_graph_input(
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(target_arch = "wasm32")]
@@ -978,26 +932,24 @@ impl GraphCircuit {
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::OnChain(_) => {
Err("Cannot use on-chain data source as input for this method.".into())
}
DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
async fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::OnChain(source) => {
let mut per_item_scale = vec![];
@@ -1005,21 +957,14 @@ impl GraphCircuit {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
// start runtime and fetch data
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async {
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
})
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file()?;
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
@@ -1032,10 +977,10 @@ impl GraphCircuit {
source: OnChainSource,
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (_, client) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
// quantize the supplied data using the provided scale + QuantizeData.sol
let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
@@ -1056,7 +1001,7 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (((d, shape), scale), input_type) in file_data
@@ -1087,7 +1032,7 @@ impl GraphCircuit {
&mut self,
file_data: &[Vec<Fp>],
shapes: &[Vec<usize>],
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Tensor<Fp>>, GraphError> {
// quantize the supplied data using the provided scale.
let mut data: Vec<Tensor<Fp>> = vec![];
for (d, shape) in file_data.iter().zip(shapes) {
@@ -1098,14 +1043,14 @@ impl GraphCircuit {
Ok(data)
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
)
}
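// Worked sketch of the float margin above: with min_max_lookup = (-101, 250)
// and lookup_safety_margin = 1.5, the safe range is
// (floor(1.5 * -101), ceil(1.5 * 250)) = (-152, 375); floor/ceil round
// outward, so truncation can only widen the range, never shrink it.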
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
@@ -1113,8 +1058,8 @@ impl GraphCircuit {
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
) -> Result<u32, Box<dyn std::error::Error>> {
max_range_size: IntegerRep,
) -> Result<u32, GraphError> {
// pick the range with the largest absolute size: safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
@@ -1132,10 +1077,10 @@ impl GraphCircuit {
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: i128,
max_range_size: IntegerRep,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
lookup_safety_margin: f64,
) -> Result<(), GraphError> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
@@ -1144,15 +1089,22 @@ impl GraphCircuit {
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if subtraction overflows
let lookup_size =
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
// check if the max lookup input has overflowed
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));
}
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
return Err(GraphError::RangeCheckTooLarge(
max_range_size.unsigned_abs() as usize,
));
}
// These are hard lower limits, we can't overflow instances or modules constraints
@@ -1196,12 +1148,7 @@ impl GraphCircuit {
}
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
return Err(GraphError::ExtendedKTooLarge(max_logrows));
}
let logrows = max_logrows;
@@ -1228,7 +1175,7 @@ impl GraphCircuit {
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
max_range_size: IntegerRep,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1286,8 +1233,8 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
witness_gen: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
region_settings: RegionSettings,
) -> Result<GraphWitness, GraphError> {
let original_inputs = inputs.to_vec();
let visibility = VarVisibility::from_args(&self.settings().run_args)?;
@@ -1335,7 +1282,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, witness_gen)?;
.forward(inputs, &self.settings().run_args, region_settings)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1402,7 +1349,7 @@ impl GraphCircuit {
pub fn from_run_args(
run_args: &RunArgs,
model_path: &std::path::Path,
) -> Result<Self, Box<dyn std::error::Error>> {
) -> Result<Self, GraphError> {
let model = Model::from_run_args(run_args, model_path)?;
Self::new(model, run_args)
}
@@ -1413,8 +1360,11 @@ impl GraphCircuit {
params: &GraphSettings,
model_path: &std::path::Path,
check_mode: CheckMode,
) -> Result<Self, Box<dyn std::error::Error>> {
params.run_args.validate()?;
) -> Result<Self, GraphError> {
params
.run_args
.validate()
.map_err(GraphError::InvalidRunArgs)?;
let model = Model::from_run_args(&params.run_args, model_path)?;
Self::new_from_settings(model, params.clone(), check_mode)
}
@@ -1425,7 +1375,7 @@ impl GraphCircuit {
&mut self,
data: &mut GraphData,
test_on_chain_data: TestOnChainData,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), GraphError> {
// Set up local anvil instance for reading on-chain data
let input_scales = self.model().graph.get_input_scales();
@@ -1439,15 +1389,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.input_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
_ => {
return Err("Cannot use non file source as input for on-chain test.
Manually populate on-chain data from file source instead"
.into())
return Err(GraphError::OnChainDataSource);
}
};
// Get the flatten length of input_data
@@ -1468,19 +1416,13 @@ impl GraphCircuit {
) {
// if not public then fail
if self.settings().run_args.output_visibility.is_private() {
return Err("Cannot use on-chain data source as private data".into());
return Err(GraphError::OnChainDataSource);
}
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => {
return Err(
"Cannot use on-chain data source as output for on-chain test.
Will manually populate on-chain data from file source instead"
.into(),
)
}
_ => return Err("No output data found".into()),
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
@@ -1523,12 +1465,10 @@ impl CircuitSize {
#[cfg(not(target_arch = "wasm32"))]
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
pub fn as_json(&self) -> Result<String, GraphError> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
Err(e) => return Err(e.into()),
};
Ok(serialized)
}

Some files were not shown because too many files have changed in this diff.