Compare commits

...

38 Commits

Author SHA1 Message Date
dante
4a172877af feat: add autocompletion tooling to cli (#805) 2024-05-30 16:44:24 -04:00
dante
5a8498894d feat: create encode-evm-calldata in cli and python (#803) 2024-05-29 21:03:16 -04:00
Snoppy
095c0ca5b4 chore: fix typos (#802) 2024-05-29 11:32:11 -04:00
dante
3fa482c9ef fix: remove error messages in calibration loop (confusing) (#800) 2024-05-22 09:32:52 +01:00
dante
6be3b1d663 chore: update alloy (#798) 2024-05-16 12:10:27 +09:00
dante
d5a1d1439c fix: elif syntax in install script 2024-05-14 03:31:26 +09:00
Ethan Cemer
ff8fd01f86 fix: patch invalid ${{ github.ref_name#v }} syntax on verify release script (#791) 2024-05-14 02:25:12 +09:00
dante
e9020f942e fix: python build matrix (#797) 2024-05-13 13:29:51 +09:00
dante
e7f54cb6ac fix: windows build (#796) 2024-05-13 12:47:36 +09:00
dante
ed65e8c090 Revert "fix: revert maturin version (#795)"
This reverts commit d9f2adad99.
2024-05-13 12:41:02 +09:00
dante
d9f2adad99 fix: revert maturin version (#795) 2024-05-13 12:21:35 +09:00
dante
5125aaa090 chore: add aarch64 linux to release pipeline (#788) 2024-05-13 11:29:49 +09:00
Jseam
f1950e6cd0 docs: add polycommit to RunArgs (#794) 2024-05-10 22:19:05 +09:00
dante
998ca22c2a chore: make most command struct args Option (#793) 2024-05-10 22:02:34 +09:00
dante
5c574adc31 chore: logistic regression example (#792) 2024-05-08 20:30:13 +09:00
dante
749e0ba652 chore: update h2 solidity verifier (#787) 2024-05-03 01:25:14 +01:00
dante
d464ddf6b6 chore: medium sized lstm example (#785) 2024-05-01 16:35:11 +01:00
dante
8f6c0aced5 chore: update tract (#784) 2024-04-30 13:31:33 +01:00
dante
860e9700a8 refactor!: swap integer rep to i64 from i128 (#781)
BREAKING CHANGE: may break w/ old compiled circuits
2024-04-26 16:16:55 -04:00
Ethan Cemer
32dd4a854f fix: patch npm package build failure (#782) 2024-04-25 10:38:06 -04:00
dante
924f7c0420 fix: simplify kzg-commit (#780) 2024-04-24 11:57:20 -04:00
dante
ae03b6515b fix: update vis settings on help (#779) 2024-04-22 16:23:19 -04:00
Ethan Cemer
bae2e9e22b feat: kzgCommit wasm method (#778) 2024-04-18 19:22:11 -04:00
dante
4a93d31869 fix: accomodate modules in col-overflow (#777) 2024-04-18 17:13:31 -04:00
dante
88dd83dbe5 fix: default compiled model paths in python (#776) 2024-04-15 12:01:21 -04:00
Ethan Cemer
f05f83481e chore: update eth postgres (#769)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-04-13 08:08:09 -04:00
Ethan Cemer
8aaf518b5e fix: fix @ezkljs/verify etherumjs deps (#765) 2024-04-12 18:24:59 -04:00
katsumata
1b7b43e073 fix: Improve EZKL installation script reliability (#774) 2024-04-09 16:07:39 -04:00
dante
f78618ec59 feat: full ND conv and pool (#770) 2024-04-06 23:29:30 +01:00
Jseam
0943e534ee docs: automated sphinx documentation for python bindings (#714)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
Co-authored-by: Ethan Cemer <tylercemer@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-04-05 18:33:06 +01:00
dante
316a9a3b40 chore: update tract (#766) 2024-04-04 18:07:08 +01:00
dante
5389012b68 fix: patch large batch ex (#763) 2024-04-03 02:33:57 +01:00
dante
48223cca11 fix: make commitment optional for backwards compat (#762) 2024-04-03 02:26:50 +01:00
dante
32c3a5e159 fix: hold stacked outputs in a separate map 2024-04-02 21:37:20 +01:00
dante
ff563e93a7 fix: bump python version (#761) 2024-04-02 17:08:26 +01:00
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
124 changed files with 11501 additions and 6118 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
name: Build and Publish EZKL Engine npm package
on:
workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
@@ -79,6 +79,10 @@ jobs:
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
- name: Add serialize and deserialize methods to nodejs bundle
run: |
echo '
@@ -174,40 +178,3 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -40,7 +40,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
@@ -50,6 +50,7 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Install built wheel
if: matrix.target == 'universal2-apple-darwin'
run: |
pip install ezkl --no-index --find-links dist --force-reinstall
python -c "import ezkl"
@@ -70,7 +71,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -85,7 +86,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
@@ -110,12 +111,12 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64, i686]
target: [x86_64]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -128,6 +129,7 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -176,7 +178,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -228,7 +230,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -263,7 +265,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -287,7 +289,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -359,3 +361,17 @@ jobs:
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}

View File

@@ -102,28 +102,32 @@ jobs:
PCRE2_SYS_STATIC: 1
strategy:
matrix:
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu]
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu, linux-aarch64]
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2023-06-27
rust: nightly-2024-02-06
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2023-06-27
rust: nightly-2024-02-06
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2023-06-27
rust: nightly-2024-02-06
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2023-06-27
rust: nightly-2024-02-06
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2023-06-27
rust: nightly-2024-02-06
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-02-06
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
@@ -181,7 +185,7 @@ jobs:
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Strip release binary
if: matrix.build != 'windows-msvc'
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"
- name: Strip release binary (Windows)

View File

@@ -184,7 +184,7 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -207,7 +207,7 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -224,7 +224,7 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -281,7 +281,7 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -307,8 +307,8 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
pnpm install --frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -354,7 +354,7 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -380,7 +380,7 @@ jobs:
cache: "pnpm"
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
pnpm install --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -394,14 +394,18 @@ jobs:
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (hashed inputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
- name: IPA prove and verify tests
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests single inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
- name: KZG prove and verify tests triple inner col
@@ -412,8 +416,6 @@ jobs:
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public inputs)
@@ -460,7 +462,7 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -495,7 +497,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -512,7 +514,7 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -545,8 +547,6 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: Download MNIST
run: sh data.sh
- name: Examples
run: cargo nextest run --release tests_examples
@@ -557,31 +557,33 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest -vv
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -592,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -608,11 +610,29 @@ jobs:
python-integration-tests:
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -626,10 +646,16 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +671,3 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

View File

@@ -14,6 +14,40 @@ jobs:
- uses: actions/checkout@v4
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.1
uses: mathieudutour/github-tag-action@v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- name: Set Cargo.toml version to match github tag for docs
shell: bash
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
mv docs/python/src/conf.py docs/python/src/conf.py.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
rm docs/python/src/conf.py.orig
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
rm docs/python/requirements-docs.txt.orig
- name: Commit files and create tag
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git fetch --tags
git checkout -b release-$RELEASE_TAG
git add .
git commit -m "ci: update version string in docs"
git tag -d $RELEASE_TAG
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:
branch: release-${{ steps.tag_version.outputs.new_tag }}
force: true
tags: true

68
.github/workflows/verify.yml vendored Normal file
View File

@@ -0,0 +1,68 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"$CLEANED_TAG\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Update pnpm-lock.yaml versions and integrity
run: |
awk -v integrity="$ENGINE_INTEGRITY" -v tag="$CLEANED_TAG" '
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

5
.gitignore vendored
View File

@@ -1,6 +1,5 @@
target
pkg
data
*.csv
!examples/notebooks/eth_price.csv
*.ipynb_checkpoints
@@ -48,4 +47,6 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12.1

26
.readthedocs.yaml Normal file
View File

@@ -0,0 +1,26 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: ./docs/python/src/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: ./docs/python/requirements-docs.txt

1953
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,7 @@ halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curv
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.5.3", features = ["derive"] }
clap_complete = "4.5.2"
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
@@ -43,46 +44,49 @@ unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = [
"ethers-solc",
] }
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
ethabi = "18"
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = [
reqwest = { version = "0.12.4", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = [
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.20.2", features = [
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = [
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true }
@@ -95,10 +99,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
@@ -175,7 +179,7 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
default = ["ezkl", "mv-lookup", "no-banner"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -198,6 +202,7 @@ det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']

View File

@@ -91,9 +91,9 @@ You can install the library from source
cargo install --locked --path .
```
You will need a functioning installation of `solc` in order to run `ezkl` properly.
[solc-select](https://github.com/crytic/solc-select) is recommended.
Follow the instructions on [solc-select](https://github.com/crytic/solc-select) to activate `solc` in your environment.
`ezkl` now auto-manages solc installation for you.
#### building python bindings

View File

@@ -20,9 +20,9 @@
"name": "quantize_data",
"outputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"stateMutability": "pure",
@@ -31,9 +31,9 @@
{
"inputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"name": "to_field_element",

View File

@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: [(0, 0); 2],
stride: (1, 1),
padding: vec![(0, 0)],
stride: vec![1; 2],
}),
)
.unwrap();

View File

@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: [(0, 0); 2],
stride: (1, 1),
kernel_shape: (2, 2),
padding: vec![(0, 0); 2],
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
}),
)

View File

@@ -1,6 +1,97 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
import './LoadInstances.sol';
contract LoadInstances {
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
mload(add(add(encoded, add(i, 0x24)), instances_offset))
)
}
}
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from calldata
* @param encoded - verifier calldata
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
instances_length := calldataload(
add(add(encoded.offset, 0x04), instances_offset)
)
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
calldataload(
add(add(encoded.offset, add(i, 0x04)), instances_offset)
)
)
}
}
}
}
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
@@ -34,11 +125,14 @@ contract DataAttestation is LoadInstances {
address public admin;
/**
* @notice EZKL P value
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER = uint256(0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001);
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_CALLS = 0;
@@ -69,7 +163,7 @@ contract DataAttestation is LoadInstances {
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if(_admin == address(0)) {
if (_admin == address(0)) {
revert();
}
admin = _admin;
@@ -80,7 +174,7 @@ contract DataAttestation is LoadInstances {
bytes[][] memory _callData,
uint256[][] memory _decimals
) external {
require(msg.sender == admin, "Only admin can update instanceOffset");
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
@@ -111,7 +205,10 @@ contract DataAttestation is LoadInstances {
// count the total number of storage reads across all of the accounts
counter += _callData[i].length;
}
require(counter == INPUT_CALLS + OUTPUT_CALLS, "Invalid number of calls");
require(
counter == INPUT_CALLS + OUTPUT_CALLS,
"Invalid number of calls"
);
}
function mulDiv(
@@ -167,7 +264,7 @@ contract DataAttestation is LoadInstances {
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param data - The data returned from the account calls.
* @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
* @param scale - The scale used to convert the floating point value into a fixed point value.
* @param scale - The scale used to convert the floating point value into a fixed point value.
*/
function quantizeData(
bytes memory data,
@@ -181,7 +278,7 @@ contract DataAttestation is LoadInstances {
if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
output += 1;
}
quantized_data = neg ? -int256(output): int256(output);
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
@@ -211,7 +308,9 @@ contract DataAttestation is LoadInstances {
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(int256 x) internal pure returns (uint256 field_element) {
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
@@ -251,12 +350,11 @@ contract DataAttestation is LoadInstances {
}
}
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0,"Address: call to non-contract");
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);

View File

@@ -1,92 +0,0 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
contract LoadInstances {
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
mload(add(add(encoded, add(i, 0x24)), instances_offset))
)
}
}
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from calldata
* @param encoded - verifier calldata
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
instances_length := calldataload(add(add(encoded.offset, 0x04), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly{
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
calldataload(
add(add(encoded.offset, add(i, 0x04)), instances_offset)
)
)
}
}
}
}

View File

@@ -1,135 +0,0 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.17;
contract QuantizeData {
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two version of the same instance, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
/**
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or denominator == 0
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv)
* with further edits by Uniswap Labs also under MIT license.
*/
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use
// use the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
// variables such that product = prod1 * 2^256 + prod0.
uint256 prod0; // Least significant 256 bits of the product
uint256 prod1; // Most significant 256 bits of the product
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
// Handle non-overflow cases, 256 by 256 division.
if (prod1 == 0) {
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
// The surrounding unchecked block does not change this fact.
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
return prod0 / denominator;
}
// Make sure the result is less than 2^256. Also prevents denominator == 0.
require(denominator > prod1, "Math: mulDiv overflow");
///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////
// Make division exact by subtracting the remainder from [prod1 prod0].
uint256 remainder;
assembly {
// Compute remainder using mulmod.
remainder := mulmod(x, y, denominator)
// Subtract 256 bit number from 512 bit number.
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
// Factor powers of two out of denominator and compute largest power of two divisor of denominator. Always >= 1.
// See https://cs.stackexchange.com/q/138556/92363.
// Does not overflow because the denominator cannot be zero at this stage in the function.
uint256 twos = denominator & (~denominator + 1);
assembly {
// Divide denominator by twos.
denominator := div(denominator, twos)
// Divide [prod1 prod0] by twos.
prod0 := div(prod0, twos)
// Flip twos such that it is 2^256 / twos. If twos is zero, then it becomes one.
twos := add(div(sub(0, twos), twos), 1)
}
// Shift in bits from prod1 into prod0.
prod0 |= prod1 * twos;
// Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
// that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
// four bits. That is, denominator * inv = 1 mod 2^4.
uint256 inverse = (3 * denominator) ^ 2;
// Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
// in modular arithmetic, doubling the correct bits in each step.
inverse *= 2 - denominator * inverse; // inverse mod 2^8
inverse *= 2 - denominator * inverse; // inverse mod 2^16
inverse *= 2 - denominator * inverse; // inverse mod 2^32
inverse *= 2 - denominator * inverse; // inverse mod 2^64
inverse *= 2 - denominator * inverse; // inverse mod 2^128
inverse *= 2 - denominator * inverse; // inverse mod 2^256
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
// This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
// less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
// is no longer required.
result = prod0 * inverse;
return result;
}
}
function quantize_data(
bytes[] memory data,
uint256[] memory decimals,
uint256[] memory scales
) external pure returns (int256[] memory quantized_data) {
quantized_data = new int256[](data.length);
for (uint i; i < data.length; i++) {
int x = abi.decode(data[i], (int256));
bool neg = x < 0;
if (neg) x = -x;
uint denom = 10 ** decimals[i];
uint scale = 1 << scales[i];
uint output = mulDiv(uint256(x), scale, denom);
if (mulmod(uint256(x), scale, denom) * 2 >= denom) {
output += 1;
}
quantized_data[i] = neg ? -int256(output) : int256(output);
}
}
function to_field_element(
int128[] memory quantized_data
) public pure returns (uint256[] memory output) {
output = new uint256[](quantized_data.length);
for (uint i; i < quantized_data.length; i++) {
output[i] = uint256(quantized_data[i] + int(ORDER)) % ORDER;
}
}
}

View File

@@ -1,12 +0,0 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.17;
contract TestReads {
int[] public arr;
constructor(int256[] memory _numbers) {
for (uint256 i = 0; i < _numbers.length; i++) {
arr.push(_numbers[i]);
}
}
}

11
data.sh
View File

@@ -1,11 +0,0 @@
#! /bin/bash
mkdir data
cd data
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gzip -d *.gz

2
docs/python/build.sh Executable file
View File

@@ -0,0 +1,2 @@
#!/bin/sh
sphinx-build ./src build

View File

@@ -0,0 +1,4 @@
ezkl==0.0.0
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

29
docs/python/src/conf.py Normal file
View File

@@ -0,0 +1,29 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
version = release
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
]
autosummary_generate = True
autosummary_imported_members = True
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

11
docs/python/src/index.rst Normal file
View File

@@ -0,0 +1,11 @@
.. extension documentation master file, created by
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
ezkl python bindings
================================================
.. automodule:: ezkl
:members:
:undoc-members:

View File

@@ -42,8 +42,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -66,8 +66,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -90,8 +90,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -203,8 +203,8 @@ where
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let op = PolyOp::Conv {
padding: [(PADDING, PADDING); 2],
stride: (STRIDE, STRIDE),
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
};
let x = config
.layer_config
@@ -308,6 +308,7 @@ pub fn runconv() {
tst_lbl: _,
..
} = MnistBuilder::new()
.base_path("examples/data")
.label_format_digit()
.training_set_length(50_000)
.validation_set_length(10_000)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: i128, const LOOKUP_MAX: i128> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;

View File

@@ -251,7 +251,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.setup_test_evm_witness(\n",
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
@@ -333,7 +333,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -354,7 +354,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -462,7 +462,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -482,7 +482,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -510,7 +510,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -535,7 +535,7 @@
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -567,7 +567,7 @@
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
},
"orig_nbformat": 4
},

View File

@@ -249,7 +249,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -278,7 +278,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -299,7 +299,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n"
]
},
@@ -518,7 +518,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -538,7 +538,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -566,7 +566,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -591,7 +591,7 @@
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -623,7 +623,7 @@
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
@@ -654,4 +654,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -150,7 +150,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -170,7 +170,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -192,7 +192,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -303,4 +303,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -352,14 +352,8 @@
"# Specify all the files we need\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.ezkl')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n",
"cal_data_path = os.path.join('cal_data.json')"
"cal_data_path = os.path.join('calibration.json')"
]
},
{
@@ -424,7 +418,7 @@
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"res = ezkl.gen_settings()\n",
"assert res == True\n",
"\n"
]
@@ -443,7 +437,7 @@
"\n",
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
"# You may want to increase the max logrows if accuracy is a concern\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
"res = await ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])"
]
},
{
@@ -463,7 +457,7 @@
},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
@@ -484,7 +478,7 @@
},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs()"
]
},
{
@@ -504,17 +498,10 @@
},
"outputs": [],
"source": [
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"res = ezkl.setup()\n",
"\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
"assert res == True"
]
},
{
@@ -539,7 +526,7 @@
"# now generate the witness file\n",
"witness_path = os.path.join('witness.json')\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness()\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -559,13 +546,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -585,11 +566,7 @@
"source": [
"# verify our proof\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
@@ -664,12 +641,9 @@
"sol_code_path = os.path.join('Verifier.sol')\n",
"abi_path = os.path.join('Verifier.abi')\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path\n",
"res = await ezkl.create_evm_verifier(\n",
" sol_code_path=sol_code_path,\n",
" abi_path=abi_path, \n",
" )\n",
"\n",
"assert res == True\n",
@@ -757,7 +731,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -467,7 +467,7 @@
"outputs": [],
"source": [
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -494,7 +494,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -508,7 +508,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -625,4 +625,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -195,7 +195,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -222,7 +222,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -236,7 +236,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -179,7 +179,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -202,7 +202,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -214,7 +214,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -313,4 +313,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -241,7 +241,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -270,7 +270,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -291,7 +291,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -420,7 +420,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -451,7 +451,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -472,7 +472,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
@@ -151,7 +152,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -174,7 +175,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs(settings_path = settings_path)"
"res = await ezkl.get_srs(settings_path = settings_path)"
]
},
{
@@ -187,7 +188,7 @@
"# now generate the witness file \n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -283,4 +284,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -155,7 +155,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -178,7 +178,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -190,7 +190,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -289,4 +289,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -233,7 +233,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -262,7 +262,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -315,7 +315,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
]
},
{
@@ -429,7 +429,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -460,7 +460,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -481,7 +481,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",

View File

@@ -193,7 +193,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -216,7 +216,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -228,7 +228,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -347,4 +347,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -142,7 +142,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -165,7 +165,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -177,7 +177,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -276,4 +276,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -347,7 +347,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -370,7 +370,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -383,7 +383,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -490,4 +490,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -0,0 +1,279 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Logistic Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package for a Logistic Regression model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LogisticRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -139,7 +139,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -180,7 +180,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -193,7 +193,7 @@
"# now generate the witness file \n",
"witness_path = \"lstmwitness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -300,4 +300,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -7,9 +7,9 @@
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://postgresapp.com/. \n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
@@ -21,23 +21,84 @@
"source": [
"import os\n",
"import getpass\n",
"\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
"os.system(\"chmod +x e2pg\")\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"os.system(\"echo $RLPS_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
@@ -79,11 +140,13 @@
"import json\n",
"import os\n",
"\n",
"# import logging\n",
"# # # uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
@@ -176,6 +239,7 @@
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
@@ -183,9 +247,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -194,7 +258,7 @@
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
@@ -210,9 +274,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -229,22 +293,6 @@
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -253,10 +301,21 @@
},
"outputs": [],
"source": [
"# setup kzg params\n",
"params_path = os.path.join('kzg.params')\n",
"import subprocess\n",
"import os\n",
"\n",
"res = ezkl.get_srs(params_path, settings_filename)"
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
},
{
@@ -306,16 +365,13 @@
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"params_path = os.path.join('kzg.params')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path,\n",
" params_path,\n",
" settings_filename,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
@@ -331,11 +387,14 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
@@ -360,126 +419,14 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" params_path,\n",
" \"single\",\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n",
"# verify\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_filename,\n",
" vk_path,\n",
" params_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7tAa-DFAtvS"
},
"source": [
"# Part 2 (Using the ZK Computational Graph Onchain!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Ym91kaVAIB6"
},
"source": [
"**Now How Do We Do It Onchain?????**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 339
},
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
"print(params_path)\n",
"print(settings_filename)\n",
"\n",
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" params_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.system(\"killall -9 e2pg\");"
"\n"
]
}
],
@@ -501,7 +448,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -323,7 +323,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"assert res == True"
]
},
@@ -348,7 +348,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs(settings_path)"
"res = await ezkl.get_srs(settings_path)"
]
},
{
@@ -362,7 +362,7 @@
"# now generate the witness file\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -469,7 +469,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test_1.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -502,7 +502,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -525,7 +525,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -284,7 +289,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
]
},
{
@@ -304,7 +309,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -316,7 +321,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -215,7 +215,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -235,7 +235,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -247,7 +247,7 @@
"# now generate the witness file\n",
"witness_path = \"ae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",
@@ -449,7 +451,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n",
"print(\"verified\")"
]
@@ -471,7 +473,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -483,7 +485,7 @@
"# now generate the witness file \n",
"witness_path = \"vae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -845,7 +845,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"assert res == True"
]
},
@@ -870,7 +870,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -881,7 +881,7 @@
},
"outputs": [],
"source": [
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -993,4 +993,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -261,7 +261,7 @@
"source": [
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
"\n",
"def setup(i):\n",
"async def setup(i):\n",
" # file names\n",
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
@@ -282,7 +282,7 @@
"\n",
" # generate settings for the current model\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
" res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" assert res == True\n",
"\n",
" # load settings and print them to the console\n",
@@ -303,11 +303,11 @@
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
"\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
"for i in range(2):\n",
" setup(i)\n"
" await setup(i)\n"
]
},
{
@@ -414,7 +414,7 @@
"outputs": [],
"source": [
"for i in range(2):\n",
" setup(i)"
" await setup(i)"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
},
"orig_nbformat": 4
},

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -193,7 +174,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -215,7 +196,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -227,7 +208,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -215,7 +215,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -229,7 +229,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -265,7 +265,7 @@
" # Serialize data into file:\n",
"json.dump( data, open(data_path_faulty, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"assert os.path.isfile(witness_path_faulty)"
]
},
@@ -310,7 +310,7 @@
"# Serialize data into file:\n",
"json.dump( data, open(data_path_truthy, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"res = await ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"assert os.path.isfile(witness_path_truthy)"
]
},
@@ -519,4 +519,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -193,7 +193,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -205,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -290,7 +290,7 @@
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
"res = await ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{
@@ -374,7 +374,7 @@
"sol_code_path = os.path.join(\"Verifier.sol\")\n",
"abi_path = os.path.join(\"Verifier_ABI.json\")\n",
"\n",
"res = ezkl.create_evm_verifier_aggr(\n",
"res = await ezkl.create_evm_verifier_aggr(\n",
" [settings_path],\n",
" aggregate_vk_path,\n",
" sol_code_path,\n",
@@ -404,4 +404,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -157,6 +157,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b78d3cbf",
"metadata": {},
"outputs": [],
"source": [
@@ -170,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -204,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -298,7 +299,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,

View File

@@ -169,7 +169,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -191,7 +191,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -203,7 +203,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -170,7 +170,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -192,7 +192,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -149,7 +149,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -171,7 +171,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -183,7 +183,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -282,4 +282,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -250,7 +250,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -297,7 +297,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
@@ -411,7 +411,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",

View File

@@ -167,7 +167,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -187,7 +187,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -209,7 +209,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -221,7 +221,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -320,4 +320,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -180,7 +180,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -202,7 +202,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -214,7 +214,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -420,7 +420,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
}
@@ -446,4 +446,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -633,7 +637,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -642,7 +646,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.get_srs( settings_path)"
"await ezkl.get_srs( settings_path)"
]
},
{
@@ -679,7 +683,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -520,7 +525,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{
@@ -567,7 +572,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
@@ -49,7 +49,11 @@
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os"
"import os\n",
"\n",
"import logging\n",
"\n",
"logging.basicConfig(level=logging.INFO)"
]
},
{
@@ -63,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -71,7 +75,15 @@
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"from torch import nn\n",
"import torch\n",
@@ -133,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -141,7 +153,18 @@
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1715422870\n",
"1714818070\n",
"https://api.coingecko.com/api/v3/coins/ethereum/market_chart/range?vs_currency=usd&from=1714818070&to=1715422870\n",
"<Response [200]>\n"
]
}
],
"source": [
"\n",
"def get_url(coin, currency, start, end):\n",
@@ -174,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -183,7 +206,115 @@
"id": "WSj1Uxln65vf",
"outputId": "51422d71-9680-4b51-c4df-e400d20f988b"
},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>time</th>\n",
" <th>prices</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1714820485367</td>\n",
" <td>3146.785806</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1714824033868</td>\n",
" <td>3127.968728</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1714828058243</td>\n",
" <td>3156.141681</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1714831650751</td>\n",
" <td>3124.834064</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1714834972229</td>\n",
" <td>3133.115333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>163</th>\n",
" <td>1715407579346</td>\n",
" <td>2918.049749</td>\n",
" </tr>\n",
" <tr>\n",
" <th>164</th>\n",
" <td>1715411090715</td>\n",
" <td>2920.330834</td>\n",
" </tr>\n",
" <tr>\n",
" <th>165</th>\n",
" <td>1715414554830</td>\n",
" <td>2923.986611</td>\n",
" </tr>\n",
" <tr>\n",
" <th>166</th>\n",
" <td>1715418419843</td>\n",
" <td>2910.537671</td>\n",
" </tr>\n",
" <tr>\n",
" <th>167</th>\n",
" <td>1715421675338</td>\n",
" <td>2907.702307</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>168 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" time prices\n",
"0 1714820485367 3146.785806\n",
"1 1714824033868 3127.968728\n",
"2 1714828058243 3156.141681\n",
"3 1714831650751 3124.834064\n",
"4 1714834972229 3133.115333\n",
".. ... ...\n",
"163 1715407579346 2918.049749\n",
"164 1715411090715 2920.330834\n",
"165 1715414554830 2923.986611\n",
"166 1715418419843 2910.537671\n",
"167 1715421675338 2907.702307\n",
"\n",
"[168 rows x 2 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(new_data)\n",
"df\n"
@@ -200,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -225,7 +356,98 @@
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.execute:SRS already exists at that path\n",
"INFO:ezkl.execute:num calibration batches: 1\n",
"INFO:ezkl.execute:read 16777476 bytes from file (vector of len = 16777476)\n",
"WARNING:ezkl.execute:\n",
"\n",
" <------------- Numerical Fidelity Report (input_scale: 4, param_scale: 4, scale_input_multiplier: 10) ------------->\n",
"\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"| mean_error | median_error | max_error | min_error | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"| -727.9929 | -727.9929 | -727.9929 | -727.9929 | 727.9929 | 727.9929 | 727.9929 | 727.9929 | 529973.7 | -0.24999964 | 0.24999964 |\n",
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
"\n",
"\n",
"INFO:ezkl.execute:file hash: 41509f380362a8d14401c5ae92073154922fe23e45459ce6f696f58607655db7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"run_args\": {\n",
" \"tolerance\": {\n",
" \"val\": 0.0,\n",
" \"scale\": 1.0\n",
" },\n",
" \"input_scale\": 4,\n",
" \"param_scale\": 4,\n",
" \"scale_rebase_multiplier\": 10,\n",
" \"lookup_range\": [\n",
" 0,\n",
" 0\n",
" ],\n",
" \"logrows\": 6,\n",
" \"num_inner_cols\": 2,\n",
" \"variables\": [\n",
" [\n",
" \"batch_size\",\n",
" 1\n",
" ]\n",
" ],\n",
" \"input_visibility\": \"Private\",\n",
" \"output_visibility\": \"Public\",\n",
" \"param_visibility\": \"Private\",\n",
" \"div_rebasing\": false,\n",
" \"rebase_frac_zero_constants\": false,\n",
" \"check_mode\": \"UNSAFE\",\n",
" \"commitment\": \"KZG\"\n",
" },\n",
" \"num_rows\": 21,\n",
" \"total_assignments\": 42,\n",
" \"total_const_size\": 0,\n",
" \"total_dynamic_col_size\": 0,\n",
" \"num_dynamic_lookups\": 0,\n",
" \"num_shuffles\": 0,\n",
" \"total_shuffle_col_size\": 0,\n",
" \"model_instance_shapes\": [\n",
" [\n",
" 1\n",
" ]\n",
" ],\n",
" \"model_output_scales\": [\n",
" 8\n",
" ],\n",
" \"model_input_scales\": [\n",
" 4\n",
" ],\n",
" \"module_sizes\": {\n",
" \"polycommit\": [],\n",
" \"poseidon\": [\n",
" 0,\n",
" [\n",
" 0\n",
" ]\n",
" ]\n",
" },\n",
" \"required_lookups\": [],\n",
" \"required_range_checks\": [],\n",
" \"check_mode\": \"UNSAFE\",\n",
" \"version\": \"0.0.0\",\n",
" \"num_blinding_factors\": null,\n",
" \"timestamp\": 1715422871248\n",
"}\n"
]
}
],
"source": [
"# generate settings\n",
"onnx_filename = os.path.join('lol.onnx')\n",
@@ -236,9 +458,9 @@
"\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
"await ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
"res = ezkl.get_srs(settings_filename)\n",
"res = await ezkl.get_srs(settings_filename)\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
"\n",
"# show the settings.json\n",
@@ -259,7 +481,24 @@
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:VK took 0.8\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:PK took 0.2\n",
"INFO:ezkl.pfsys:saving verification key 💾\n",
"INFO:ezkl.pfsys:done saving verification key ✅\n",
"INFO:ezkl.pfsys:saving proving key 💾\n",
"INFO:ezkl.pfsys:done saving proving key ✅\n"
]
}
],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
@@ -281,20 +520,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"res = await ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -302,7 +541,34 @@
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys:loading proving key from \"test.pk\"\n",
"INFO:ezkl.pfsys:done loading proving key ✅\n",
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:proof started...\n",
"INFO:ezkl.graph.model:model layout...\n",
"INFO:ezkl.pfsys:proof took 0.15\n",
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
"INFO:ezkl.pfsys:done loading verification key ✅\n",
"INFO:ezkl.execute:verify took 0.2\n",
"INFO:ezkl.execute:verified: true\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"verified\n"
]
}
],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
@@ -351,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -360,7 +626,26 @@
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
"INFO:ezkl.pfsys:done loading verification key ✅\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test.vk\n",
"settings.json\n"
]
}
],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
@@ -370,7 +655,7 @@
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
@@ -381,9 +666,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.eth:using chain 31337\n",
"INFO:ezkl.execute:Contract deployed at: 0x998abeb3e57409262ae5b751f60747921b33613e\n"
]
}
],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
@@ -391,8 +685,9 @@
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"sol_code_path = 'test.sol'\n",
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -406,16 +701,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:ezkl.eth:using chain 31337\n",
"INFO:ezkl.eth:estimated verify gas cost: 399775\n",
"INFO:ezkl.execute:Solidity verification result: true\n"
]
}
],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -451,7 +756,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"name":"proof","type":"bytes","internalType":"bytes"},{"name":"instances","type":"uint256[]","internalType":"uint256[]"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"nonpayable"}]

View File

@@ -629,7 +629,7 @@
"source": [
"\n",
"\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -660,7 +660,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.get_srs(settings_path)"
"res = await ezkl.get_srs(settings_path)"
]
},
{
@@ -680,7 +680,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -807,7 +807,7 @@
"settings_path = os.path.join('settings.json')\n",
"\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
@@ -847,7 +847,7 @@
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -868,7 +868,7 @@
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -242,6 +242,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "2007dc77",
"metadata": {},
"outputs": [],
"source": [
@@ -257,6 +258,7 @@
},
{
"cell_type": "markdown",
"id": "ab993958",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
@@ -272,7 +274,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -284,12 +286,13 @@
"source": [
"# now generate the witness file \n",
"\n",
"witness = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"witness = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "markdown",
"id": "ad58432e",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
@@ -317,6 +320,7 @@
},
{
"cell_type": "markdown",
"id": "1746c8d1",
"metadata": {},
"source": [
"We can now create an EVM verifier contract from our circuit. This contract will be deployed to the chain we are using. In this case we are using a local anvil instance."
@@ -325,15 +329,15 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d1920c0f",
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
@@ -344,6 +348,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0fd7f22b",
"metadata": {},
"outputs": [],
"source": [
@@ -351,7 +356,7 @@
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = ezkl.deploy_evm(\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
@@ -362,6 +367,7 @@
},
{
"cell_type": "markdown",
"id": "9c0dffab",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. \n",
@@ -371,6 +377,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cc888848",
"metadata": {},
"outputs": [],
"source": []
@@ -378,6 +385,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c2db14d7",
"metadata": {},
"outputs": [],
"source": [
@@ -385,7 +393,7 @@
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = ezkl.create_evm_data_attestation(\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
@@ -396,12 +404,13 @@
{
"cell_type": "code",
"execution_count": null,
"id": "5a018ba6",
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = ezkl.deploy_da_evm(\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
@@ -412,6 +421,7 @@
},
{
"cell_type": "markdown",
"id": "2adad845",
"metadata": {},
"source": [
"Now we can pull in the data from the contract and calculate a new set of coordinates. We then rotate the world by 1 transform and submit the proof to the contract. The contract could then update the world rotation (logic not inserted here). For demo purposes we do this repeatedly, rotating the world by 1 transform. "
@@ -444,6 +454,7 @@
},
{
"cell_type": "markdown",
"id": "90eda56e",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
@@ -528,7 +539,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -193,7 +193,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -215,7 +215,7 @@
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -227,7 +227,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -346,4 +346,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.

View File

@@ -0,0 +1,9 @@
{
"input_data": [
[
1.514470100402832, 1.519423007965088, 1.5182757377624512,
1.5262789726257324, 1.5298409461975098
]
],
"output_data": [[-0.1862019]]
}

Binary file not shown.

View File

@@ -104,5 +104,5 @@ json.dump(data, open("input.json", 'w'))
# ezkl.gen_settings("network.onnx", "settings.json")
# !RUST_LOG = full
# res = ezkl.calibrate_settings(
# res = await ezkl.calibrate_settings(
# "input.json", "network.onnx", "settings.json", "resources")

View File

@@ -1,6 +1,6 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"version": "v10.4.2",
"publishConfig": {
"access": "public"
},
@@ -20,16 +20,16 @@
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
"@ethereumjs/common": "4.0.0",
"@ethereumjs/evm": "2.0.0",
"@ethereumjs/statemanager": "2.0.0",
"@ethereumjs/tx": "5.0.0",
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ezkljs/engine": "10.4.2",
"ethers": "6.7.1",
"json-bigint": "1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",

View File

@@ -6,34 +6,34 @@ settings:
dependencies:
'@ethereumjs/common':
specifier: ^4.0.0
specifier: 4.0.0
version: 4.0.0
'@ethereumjs/evm':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/statemanager':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/tx':
specifier: ^5.0.0
specifier: 5.0.0
version: 5.0.0
'@ethereumjs/util':
specifier: ^9.0.0
specifier: 9.0.0
version: 9.0.0
'@ethereumjs/vm':
specifier: ^7.0.0
specifier: 7.0.0
version: 7.0.0
'@ethersproject/abi':
specifier: ^5.7.0
specifier: 5.7.0
version: 5.7.0
'@ezkljs/engine':
specifier: ^9.4.4
version: 9.4.4
specifier: "10.4.2"
version: "10.4.2"
ethers:
specifier: ^6.7.1
specifier: 6.7.1
version: 6.7.1
json-bigint:
specifier: ^1.0.0
specifier: 1.0.0
version: 1.0.0
devDependencies:
@@ -397,8 +397,8 @@ packages:
'@ethersproject/strings': 5.7.0
dev: false
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
/@ezkljs/engine@10.4.2:
resolution: {integrity: "sha512-1GNB4vChbaQ1ALcYbEbM/AFoh4QWtswpzGCO/g9wL8Ep6NegM2gQP/uWICU7Utl0Lj1DncXomD7PUhFSXhtx8A=="}
dependencies:
'@types/json-bigint': 1.0.2
json-bigint: 1.0.0

View File

@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
exit 1
fi
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi
@@ -126,7 +126,6 @@ elif [ "$PLATFORM" == "macos" ]; then
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
else
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
@@ -143,7 +142,7 @@ elif [ "$PLATFORM" == "macos" ]; then
fi
elif [ "$PLATFORM" == "linux" ]; then
if [ "${ARCHITECTURE}" = "amd64" ]; then
if [ "$ARCHITECTURE" == "amd64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
@@ -155,9 +154,20 @@ elif [ "$PLATFORM" == "linux" ]; then
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
elif [ "$ARCHITECTURE" == "aarch64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-aarch64.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
else
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
echo "Non aarch ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
else

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.14,<0.15"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.pytest.ini_options]
@@ -8,6 +8,7 @@ addopts = "-rfEX -p pytester --strict-markers"
testpaths = [
"tests/python/*_tests.py",
]
asyncio_mode = "auto"
[project]
name = "ezkl"

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.1
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -1,7 +1,7 @@
// ignore file if compiling for wasm
#[cfg(not(target_arch = "wasm32"))]
use clap::Parser;
use clap::{CommandFactory, Parser};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(not(target_arch = "wasm32"))]
@@ -11,7 +11,7 @@ use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, error, info};
use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
@@ -27,19 +27,27 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
info!("Running with ICICLE GPU");
if let Some(generator) = args.generator {
ezkl::commands::print_completions(generator, &mut Cli::command());
Ok(())
} else if let Some(command) = args.command {
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
info!("Running with ICICLE GPU");
} else {
info!("Running with CPU");
}
info!("command: \n {}", &command.as_json().to_colored_json_auto()?);
let res = run(command).await;
match &res {
Ok(_) => info!("succeeded"),
Err(e) => error!("failed: {}", e),
};
res.map(|_| ())
} else {
info!("Running with CPU");
Err("No command provided".into())
}
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
let res = run(args.command).await;
match &res {
Ok(_) => info!("succeeded"),
Err(e) => error!("failed: {}", e),
};
res.map(|_| ())
}
#[cfg(target_arch = "wasm32")]

View File

@@ -44,12 +44,11 @@ impl PolyCommitChip {
/// Commit to the message using the KZG commitment scheme
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
message: Vec<Scheme::Scalar>,
degree: u32,
num_unusable_rows: u32,
params: &Scheme::ParamsProver,
) -> Vec<G1Affine> {
let k = params.k();
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
let domain = halo2_proofs::poly::EvaluationDomain::new(2, k);
let n = 2_u64.pow(k) - num_unusable_rows as u64;
let num_poly = (message.len() / n as usize) + 1;
let mut poly = vec![domain.empty_lagrange(); num_poly];

View File

@@ -24,7 +24,7 @@ use crate::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
@@ -80,14 +80,18 @@ impl ToFlags for CheckMode {
impl From<String> for CheckMode {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
"safe" => CheckMode::SAFE,
"unsafe" => CheckMode::UNSAFE,
_ => {
log::error!("Invalid value for CheckMode");
log::warn!("defaulting to SAFE");
CheckMode::SAFE
}
Self::from_str(value.as_str()).unwrap()
}
}
impl FromStr for CheckMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"safe" => Ok(CheckMode::SAFE),
"unsafe" => Ok(CheckMode::UNSAFE),
_ => Err("Invalid value for CheckMode".to_string()),
}
}
}
@@ -345,7 +349,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {
@@ -956,20 +960,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
let res = op.layout(self, region, values)?;
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
if let Some(claimed_output) = &res {
// during key generation this will be unknown vals so we use this as a flag to check
let mut is_assigned = !claimed_output.any_unknowns()?;
for val in values.iter() {
is_assigned = is_assigned && !val.any_unknowns()?;
}
if is_assigned {
op.safe_mode_check(claimed_output, values)?;
}
}
};
Ok(res)
op.layout(self, region, values)
}
}

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::i64_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
dim: usize,
},
SumPool {
padding: [(usize, usize); 2],
stride: (usize, usize),
kernel_shape: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
},
MaxPool2d {
padding: [(usize, usize); 2],
stride: (usize, usize),
pool_dims: (usize, usize),
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
},
ReduceMin {
axes: Vec<usize>,
@@ -71,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -85,93 +85,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => tensor::ops::nonlinearities::softmax_axes(
&x,
input_scale.into(),
output_scale.into(),
axes,
),
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater(&x, &y)?
}
HybridOp::GreaterEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater_equal(&x, &y)?
}
HybridOp::Less => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less(&x, &y)?
}
HybridOp::LessEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less_equal(&x, &y)?
}
HybridOp::Equals => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::equals(&x, &y)?
}
};
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
match self {
@@ -201,12 +114,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => format!(
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
@@ -253,9 +166,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
*padding,
*stride,
*kernel_shape,
padding,
stride,
kernel_shape,
*normalized,
)?,
HybridOp::Recip {
@@ -271,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
i128_to_felt(input_scale.0 as i128),
i128_to_felt(output_scale.0 as i128),
i64_to_felt(input_scale.0 as i64),
i64_to_felt(output_scale.0 as i64),
)?
} else {
layouts::nonlinearity(
@@ -296,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
i64_to_felt(denom.0 as i64),
)?
} else {
layouts::nonlinearity(
@@ -315,17 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
}
}
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => layouts::max_pool2d(
} => layouts::max_pool(
config,
region,
values[..].try_into()?,
*padding,
*stride,
*pool_dims,
padding,
stride,
pool_dims,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,9 @@ use std::error::Error;
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::{felt_to_i64, i64_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
};
use super::Op;
@@ -132,19 +132,16 @@ impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
let range = range as i64;
(-range, range)
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i64(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
@@ -231,10 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
}
}?;
let output = res.map(|x| i128_to_felt(x));
let output = res.map(|x| i64_to_felt(x));
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the name of the operation
fn as_string(&self) -> String {

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -27,16 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
fn as_string(&self) -> String;
@@ -71,36 +69,9 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any;
/// Safe mode output checl
fn safe_mode_check(
&self,
claimed_output: &ValTensor<F>,
original_values: &[ValTensor<F>],
) -> Result<(), TensorError> {
let felt_evals = original_values
.iter()
.map(|v| {
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
evals.reshape(v.dims())?;
Ok(evals)
})
.collect::<Result<Vec<_>, _>>()?;
let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
let mut output = claimed_output
.get_felt_evals()
.map_err(|_| TensorError::FeltError)?;
output.reshape(claimed_output.dims())?;
assert_eq!(output, ref_op);
Ok(())
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -151,8 +122,8 @@ impl InputType {
*input = T::from_f64(f64_input).unwrap();
}
InputType::Int | InputType::TDim => {
let int_input = input.clone().to_i128().unwrap();
*input = T::from_i128(int_input).unwrap();
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
}
}
@@ -167,7 +138,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -176,12 +147,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
self
}
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
})
}
fn as_string(&self) -> String {
"Input".into()
}
@@ -228,16 +193,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}
fn as_string(&self) -> String {
"Unknown".into()
@@ -258,7 +220,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
///
pub quantized_values: Tensor<F>,
///
@@ -268,7 +230,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -301,17 +263,13 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
+ for<'de> Deserialize<'de>
+ IntoI64,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
format!("CONST (scale={})", self.quantized_values.scale().unwrap())

View File

@@ -1,6 +1,5 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -32,8 +31,8 @@ pub enum PolyOp {
equation: String,
},
Conv {
padding: [(usize, usize); 2],
stride: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
},
Downsample {
axis: usize,
@@ -41,9 +40,9 @@ pub enum PolyOp {
modulo: usize,
},
DeConv {
padding: [(usize, usize); 2],
output_padding: (usize, usize),
stride: (usize, usize),
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
},
Add,
Sub,
@@ -58,10 +57,13 @@ pub enum PolyOp {
destination: usize,
},
Flatten(Vec<usize>),
Pad([(usize, usize); 2]),
Pad(Vec<(usize, usize)>),
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -95,7 +97,8 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
+ for<'de> Deserialize<'de>
+ IntoI64,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
@@ -105,10 +108,28 @@ impl<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -120,15 +141,26 @@ impl<
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -142,146 +174,6 @@ impl<
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
"multibroadcastto inputs".to_string(),
));
}
inputs[0].expand(shape)
}
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
PolyOp::Not => tensor::ops::not(&inputs[0]),
PolyOp::Downsample {
axis,
stride,
modulo,
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::MoveAxis {
source,
destination,
} => inputs[0].move_axis(*source, *destination),
PolyOp::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::Pad(p) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pad inputs".to_string()));
}
tensor::ops::pad(&inputs[0], *p)
}
PolyOp::Add => tensor::ops::add(&inputs),
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult => tensor::ops::mult(&inputs),
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
PolyOp::DeConv {
padding,
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
inputs[0].pow(*u)
}
PolyOp::Sum { axes } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
tensor::ops::sum_axes(&inputs[0], axes)
}
PolyOp::Prod { axes, .. } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("prod inputs".to_string()));
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
PolyOp::Slice { axis, start, end } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;
Ok(ForwardResult { output: res })
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
@@ -292,6 +184,9 @@ impl<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -318,7 +213,7 @@ impl<
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -370,9 +265,9 @@ impl<
config,
region,
values[..].try_into()?,
*padding,
*output_padding,
*stride,
padding,
output_padding,
stride,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -388,7 +283,7 @@ impl<
)));
}
let mut input = values[0].clone();
input.pad(*p)?;
input.pad(p.clone(), 0)?;
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -404,6 +299,7 @@ impl<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

View File

@@ -9,7 +9,7 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
@@ -133,10 +133,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
max_lookup_inputs: i64,
min_lookup_inputs: i64,
max_range_size: i64,
witness_gen: bool,
check_lookup_range: bool,
assigned_constants: ConstantsMap<F>,
}
@@ -163,7 +164,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants.into_iter());
self.assigned_constants.extend(constants);
}
///
@@ -191,6 +192,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.witness_gen
}
///
pub fn check_lookup_range(&self) -> bool {
self.check_lookup_range
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
@@ -209,6 +215,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: true,
check_lookup_range: true,
assigned_constants: HashMap::new(),
}
}
@@ -246,12 +253,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: false,
check_lookup_range: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -268,6 +281,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
@@ -278,6 +292,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord: usize,
num_inner_cols: usize,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -293,6 +308,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
@@ -364,6 +380,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
starting_linear_coord,
self.num_inner_cols,
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -389,7 +406,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants.into_iter());
constants.extend(local_reg.assigned_constants);
res
})
@@ -546,17 +563,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
pub fn max_lookup_inputs(&self) -> i64 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
pub fn min_lookup_inputs(&self) -> i64 {
self.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i128 {
pub fn max_range_size(&self) -> i64 {
self.max_range_size
}
@@ -574,8 +591,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -599,8 +618,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -630,9 +651,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let mut values_map = values.create_constants_map();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {

View File

@@ -11,19 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
circuit::CircuitError,
fieldutils::i128_to_felt,
tensor::{Tensor, TensorType},
fieldutils::i64_to_felt,
tensor::{IntoI64, Tensor, TensorType},
};
use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
pub type Range = (i64, i64);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
@@ -98,26 +96,25 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
i128_to_felt(chunk)
i64_to_felt(chunk)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
let chunk = chunk as i128;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let op_f = Op::<F>::f(
&self.nonlinearity,
&[Tensor::from(vec![first_element].into_iter())],
)
.unwrap();
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
.unwrap();
(first_element, op_f.output[0])
}
@@ -133,12 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}
///
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
(range_len / (col_size as i64)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -205,8 +202,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -275,12 +272,12 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
}
///
@@ -297,13 +294,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
i128_to_felt(chunk)
i64_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
@@ -353,7 +350,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

View File

@@ -1048,8 +1048,8 @@ mod conv {
&mut region,
&self.inputs,
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1,6 +1,7 @@
use clap::{Parser, Subcommand};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
@@ -10,7 +11,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
@@ -52,6 +53,8 @@ pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
pub const DEFAULT_VERIFIER_DA_ABI: &str = "verifier_da_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
/// Default calldata path
pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default solidity code for data attestation
@@ -78,7 +81,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk seperately
/// Default render vk separately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
@@ -253,33 +256,66 @@ lazy_static! {
};
}
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone, Deserialize, Serialize)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION)]
pub struct Cli {
#[command(subcommand)]
#[allow(missing_docs)]
pub command: Commands,
/// Get the styles for the CLI
pub fn get_styles() -> clap::builder::Styles {
clap::builder::Styles::styled()
.usage(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.header(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.literal(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
)
.invalid(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.error(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.valid(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
)
.placeholder(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
)
}
impl Cli {
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
}
/// Parse an ezkl configuration from a json
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
pub struct Cli {
/// If provided, outputs the completion file for given shell
#[clap(long = "generate", value_parser)]
pub generator: Option<Shell>,
#[command(subcommand)]
#[allow(missing_docs)]
pub command: Option<Commands>,
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -289,8 +325,8 @@ pub enum Commands {
/// Loads model and prints model table
Table {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
@@ -299,30 +335,30 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// Path to output the witness .json file
#[arg(short = 'O', long, default_value = DEFAULT_WITNESS)]
output: PathBuf,
#[arg(short = 'O', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
output: Option<PathBuf>,
/// Path to the verification key file (optional - solely used to generate kzg commits)
#[arg(short = 'V', long)]
#[arg(short = 'V', long, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// Path to the srs file (optional - solely used to generate kzg commits)
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
},
/// Produces the proving hyperparameters, from run-args
GenSettings {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to generate the circuit settings .json file to
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
@@ -332,51 +368,52 @@ pub enum Commands {
#[cfg(not(target_arch = "wasm32"))]
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to load circuit settings .json file AND overwrite (generated using the gen-settings command).
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)]
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
/// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
lookup_safety_margin: i128,
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
lookup_safety_margin: i64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
scales: Option<Vec<crate::Scale>>,
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
#[arg(
long,
value_delimiter = ',',
allow_hyphen_values = true,
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS,
value_hint = clap::ValueHint::Other
)]
scale_rebase_multiplier: Vec<u32>,
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
max_logrows: Option<u32>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
only_range_check_rebase: Option<bool>,
},
/// Generates a dummy SRS
#[command(name = "gen-srs", arg_required_else_help = true)]
GenSrs {
/// The path to output the generated SRS
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: PathBuf,
/// number of logrows to use for srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
logrows: usize,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -384,79 +421,79 @@ pub enum Commands {
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
#[arg(long)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Commitment used
#[arg(long, default_value = None)]
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
},
/// Mock aggregate proofs
MockAggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
},
/// setup aggregation circuit :)
/// Setup aggregation circuit and generate pk and vk
SetupAggregate {
/// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
sample_snarks: Vec<PathBuf>,
/// The path to save the desired verification key file to
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
/// Aggregates proofs
Aggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// The path to load the desired proving key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -465,78 +502,79 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum
value_enum,
value_hint = clap::ValueHint::Other
)]
transcript: TranscriptType,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check_mode: CheckMode,
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
/// whether the accumulated proofs are segments of a larger circuit
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
model: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
},
/// Creates pk and vk
Setup {
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to output the verification key file to
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the proving key file to
#[arg(long, default_value = DEFAULT_PK)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The graph witness (optional - used to override fixed values in the circuit)
#[arg(short = 'W', long)]
#[arg(short = 'W', long, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information
/// derived from the file information in the data .json file.
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long)]
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// where the input data come from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
/// where the output data come from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -544,121 +582,136 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the witness file
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness_path: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to load the desired proving key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_PK)]
pk_path: PathBuf,
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check_mode: CheckMode,
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_CALLDATA, value_hint = clap::ValueHint::FilePath)]
calldata_path: Option<PathBuf>,
/// The path to the verification key address (only used if the vk is rendered as a separate contract)
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// The path to the .json data file, which should
/// contain the necessary calldata and account addresses
/// needed to read from all the on-chain
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -666,104 +719,104 @@ pub enum Commands {
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI)]
abi_path: PathBuf,
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_settings: Vec<PathBuf>,
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
render_vk_seperately: Option<bool>,
},
/// Verifies a proof, returning accept or reject
Verify {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVerifier {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -771,25 +824,25 @@ pub enum Commands {
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
data: PathBuf,
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
sol_code_path: PathBuf,
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: PathBuf,
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution)
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -797,19 +850,32 @@ pub enum Commands {
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// does the verifier use data attestation ?
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,
// is the vk rendered seperately, if so specify an address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
}
impl Commands {
/// Converts the commands to a json string
pub fn as_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
/// Converts a json string to a Commands struct
pub fn from_json(json: &str) -> Self {
serde_json::from_str(json).unwrap()
}
}

File diff suppressed because one or more lines are too long

View File

@@ -1,9 +1,7 @@
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::CalibrationTarget;
use crate::commands::Commands;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::H160Flag;
use crate::commands::*;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
@@ -23,6 +21,8 @@ use crate::pfsys::{
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -66,48 +66,16 @@ use snark_verifier::system::halo2::Config;
use std::error::Error;
use std::fs::File;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
#[cfg(not(target_arch = "wasm32"))]
use std::io::{Cursor, Write};
use std::path::Path;
use std::path::PathBuf;
#[cfg(not(target_arch = "wasm32"))]
use std::process::Command;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::OnceLock;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
use std::str::FromStr;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
#[cfg(not(target_arch = "wasm32"))]
static _SOLC_REQUIREMENT: OnceLock<bool> = OnceLock::new();
#[cfg(not(target_arch = "wasm32"))]
fn check_solc_requirement() {
info!("checking solc installation..");
_SOLC_REQUIREMENT.get_or_init(|| match Command::new("solc").arg("--version").output() {
Ok(output) => {
debug!("solc output: {:#?}", output);
debug!("solc output success: {:#?}", output.status.success());
if !output.status.success() {
log::error!(
"`solc` check failed: {}",
String::from_utf8_lossy(&output.stderr)
);
return false;
}
debug!("solc check passed, proceeding");
true
}
Err(_) => {
log::error!("`solc` check failed: solc not found");
false
}
});
}
use lazy_static::lazy_static;
lazy_static! {
@@ -152,7 +120,11 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
commitment,
} => gen_srs_cmd(srs_path, logrows as u32, commitment),
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT)?),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetSrs {
srs_path,
@@ -160,12 +132,16 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
commitment,
} => get_srs_cmd(srs_path, settings_path, logrows, commitment).await,
Commands::Table { model, args } => table(model, args),
Commands::Table { model, args } => table(model.unwrap_or(DEFAULT_MODEL.into()), args),
Commands::GenSettings {
model,
settings_path,
args,
} => gen_circuit_settings(model, settings_path, args),
} => gen_circuit_settings(
model.unwrap_or(DEFAULT_MODEL.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
args,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::CalibrateSettings {
model,
@@ -178,16 +154,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
max_logrows,
only_range_check_rebase,
} => calibrate(
model,
data,
settings_path,
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse()?),
max_logrows,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::GenWitness {
data,
@@ -195,10 +172,19 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
output,
vk_path,
srs_path,
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
} => gen_witness(
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
data.unwrap_or(DEFAULT_DATA.into()),
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
vk_path,
srs_path,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(
model.unwrap_or(DEFAULT_MODEL.into()),
witness.unwrap_or(DEFAULT_WITNESS.into()),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmVerifier {
vk_path,
@@ -207,28 +193,60 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
sol_code_path,
abi_path,
render_vk_seperately,
} => create_evm_verifier(
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
),
} => {
create_evm_verifier(
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::EncodeEvmCalldata {
proof_path,
calldata_path,
addr_vk,
} => encode_evm_calldata(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
calldata_path.unwrap_or(DEFAULT_CALLDATA.into()),
addr_vk,
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::CreateEvmVK {
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
} => {
create_evm_vk(
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
abi_path.unwrap_or(DEFAULT_VK_ABI.into()),
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmDataAttestation {
settings_path,
sol_code_path,
abi_path,
data,
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
} => {
create_evm_data_attestation(
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_DA_ABI.into()),
data.unwrap_or(DEFAULT_DATA.into()),
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEvmVerifierAggr {
vk_path,
@@ -238,20 +256,27 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
aggregation_settings,
logrows,
render_vk_seperately,
} => create_evm_aggregate_verifier(
vk_path,
srs_path,
sol_code_path,
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
),
} => {
create_evm_aggregate_verifier(
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_AGGREGATED.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
aggregation_settings,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
)
.await
}
Commands::CompileCircuit {
model,
compiled_circuit,
settings_path,
} => compile_circuit(model, compiled_circuit, settings_path),
} => compile_circuit(
model.unwrap_or(DEFAULT_MODEL.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
),
Commands::Setup {
compiled_circuit,
srs_path,
@@ -260,12 +285,12 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
witness,
disable_selector_compression,
} => setup(
compiled_circuit,
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
srs_path,
vk_path,
pk_path,
vk_path.unwrap_or(DEFAULT_VK.into()),
pk_path.unwrap_or(DEFAULT_PK.into()),
witness,
disable_selector_compression,
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
@@ -277,8 +302,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
output_source,
} => {
setup_test_evm_witness(
data,
compiled_circuit,
data.unwrap_or(DEFAULT_DATA.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
test_data,
rpc_url,
input_source,
@@ -291,13 +316,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
addr,
data,
rpc_url,
} => test_update_account_calls(addr, data, rpc_url).await,
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
#[cfg(not(target_arch = "wasm32"))]
Commands::SwapProofCommitments {
proof_path,
witness_path,
} => swap_proof_commitments_cmd(proof_path, witness_path)
.map(|e| serde_json::to_string(&e).unwrap()),
} => swap_proof_commitments_cmd(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
witness_path.unwrap_or(DEFAULT_WITNESS.into()),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
Commands::Prove {
witness,
@@ -308,20 +337,24 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
proof_type,
check_mode,
} => prove(
witness,
compiled_circuit,
pk_path,
Some(proof_path),
witness.unwrap_or(DEFAULT_WITNESS.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
pk_path.unwrap_or(DEFAULT_PK.into()),
Some(proof_path.unwrap_or(DEFAULT_PROOF.into())),
srs_path,
proof_type,
check_mode,
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::MockAggregate {
aggregation_snarks,
logrows,
split_proofs,
} => mock_aggregate(aggregation_snarks, logrows, split_proofs),
} => mock_aggregate(
aggregation_snarks,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
),
Commands::SetupAggregate {
sample_snarks,
vk_path,
@@ -333,13 +366,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
commitment,
} => setup_aggregate(
sample_snarks,
vk_path,
pk_path,
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
logrows,
split_proofs,
disable_selector_compression,
commitment,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -352,15 +385,15 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
split_proofs,
commitment,
} => aggregate(
proof_path,
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
aggregation_snarks,
pk_path,
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
transcript,
logrows,
check_mode,
split_proofs,
commitment,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -369,8 +402,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
srs_path,
reduced_srs,
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
.map(|e| serde_json::to_string(&e).unwrap()),
} => verify(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
proof_path,
vk_path,
@@ -379,12 +418,12 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
commitment,
} => verify_aggr(
proof_path,
vk_path,
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
srs_path,
logrows,
reduced_srs,
commitment,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -396,9 +435,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
private_key,
} => {
deploy_evm(
sol_code_path,
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
rpc_url,
addr_path,
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
optimizer_runs,
private_key,
"Halo2Verifier",
@@ -414,9 +453,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
private_key,
} => {
deploy_evm(
sol_code_path,
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
rpc_url,
addr_path,
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_VK.into()),
optimizer_runs,
private_key,
"Halo2VerifyingKey",
@@ -434,11 +473,11 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
private_key,
} => {
deploy_da_evm(
data,
settings_path,
sol_code_path,
data.unwrap_or(DEFAULT_DATA.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
rpc_url,
addr_path,
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_DA.into()),
optimizer_runs,
private_key,
)
@@ -451,7 +490,16 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
rpc_url,
addr_da,
addr_vk,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
} => {
verify_evm(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
addr_verifier,
rpc_url,
addr_da,
addr_vk,
)
.await
}
}
}
@@ -586,7 +634,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -666,27 +714,24 @@ pub(crate) async fn gen_witness(
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
true,
true,
)?
}
Commitments::IPA => {
@@ -694,22 +739,29 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
true,
true,
)?
}
}
} else {
warn!("SRS for poly commit does not exist (will be ignored)");
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
vk.as_ref(),
None,
true,
true,
)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -887,12 +939,12 @@ impl AccuracyResults {
#[cfg(not(target_arch = "wasm32"))]
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn calibrate(
pub(crate) async fn calibrate(
model_path: PathBuf,
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i128,
lookup_safety_margin: i64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
@@ -910,8 +962,10 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
debug!("num of calibration batches: {}", chunks.len());
let input_shapes = model.graph.input_shapes()?;
let chunks = data.split_into_batches(input_shapes).await?;
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
@@ -1008,6 +1062,7 @@ pub(crate) fn calibrate(
param_scale,
scale_rebase_multiplier,
div_rebasing,
lookup_range: (i64::MIN, i64::MAX),
..settings.run_args.clone()
};
@@ -1043,7 +1098,13 @@ pub(crate) fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut data.clone(), None, None, true)
.forward::<KZGCommitmentScheme<Bn256>>(
&mut data.clone(),
None,
None,
true,
false,
)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
@@ -1058,7 +1119,7 @@ pub(crate) fn calibrate(
match forward_res {
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
// typically errors will be due to the circuit overflowing the i64 limit
Err(e) => {
error!("forward pass failed: {:?}", e);
pb.inc(1);
@@ -1129,7 +1190,6 @@ pub(crate) fn calibrate(
);
num_passed += 1;
} else {
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}
@@ -1233,22 +1293,14 @@ pub(crate) fn calibrate(
);
if matches!(target, CalibrationTarget::Resources { col_overflow: true }) {
let lookup_log_rows = ((best_params.run_args.lookup_range.1
- best_params.run_args.lookup_range.0) as f32)
.log2()
.ceil() as u32
+ 1;
let mut reduction = std::cmp::max(
(best_params
.model_instance_shapes
.iter()
.map(|x| x.iter().product::<usize>())
.sum::<usize>() as f32)
.log2()
.ceil() as u32
+ 1,
lookup_log_rows,
);
let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
let module_log_row = best_params.module_constraint_logrows_with_blinding();
let instance_logrows = best_params.log2_total_instances_with_blinding();
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
reduction = std::cmp::max(reduction, instance_logrows);
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);
info!(
@@ -1294,7 +1346,7 @@ pub(crate) fn mock(
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_verifier(
pub(crate) async fn create_evm_verifier(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
settings_path: PathBuf,
@@ -1302,18 +1354,18 @@ pub(crate) fn create_evm_verifier(
abi_path: PathBuf,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1331,7 +1383,7 @@ pub(crate) fn create_evm_verifier(
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
@@ -1339,25 +1391,25 @@ pub(crate) fn create_evm_verifier(
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_vk(
pub(crate) async fn create_evm_vk(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1372,7 +1424,7 @@ pub(crate) fn create_evm_vk(
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
@@ -1380,7 +1432,7 @@ pub(crate) fn create_evm_vk(
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_data_attestation(
pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
_sol_code_path: PathBuf,
_abi_path: PathBuf,
@@ -1388,7 +1440,6 @@ pub(crate) fn create_evm_data_attestation(
) -> Result<String, Box<dyn Error>> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
check_solc_requirement();
let settings = GraphSettings::load(&settings_path)?;
@@ -1428,7 +1479,7 @@ pub(crate) fn create_evm_data_attestation(
let mut f = File::create(_sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0)?;
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
} else {
@@ -1449,7 +1500,6 @@ pub(crate) async fn deploy_da_evm(
runs: usize,
private_key: Option<String>,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let contract_address = deploy_da_verifier_via_solidity(
settings_path,
data,
@@ -1476,7 +1526,6 @@ pub(crate) async fn deploy_evm(
private_key: Option<String>,
contract_name: &str,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
@@ -1493,6 +1542,32 @@ pub(crate) async fn deploy_evm(
Ok(String::new())
}
/// Encodes the calldata for the EVM verifier (both aggregated and single proof)
pub(crate) fn encode_evm_calldata(
proof_path: PathBuf,
calldata_path: PathBuf,
addr_vk: Option<H160Flag>,
) -> Result<Vec<u8>, Box<dyn Error>> {
let snark = Snark::load::<IPACommitmentScheme<G1Affine>>(&proof_path)?;
let flattened_instances = snark.instances.into_iter().flatten();
let encoded = halo2_solidity_verifier::encode_calldata(
addr_vk
.as_ref()
.map(|x| alloy::primitives::Address::from(*x).0)
.map(|x| x.0),
&snark.proof,
&flattened_instances.collect::<Vec<_>>(),
);
log::debug!("Encoded calldata: {:?}", encoded);
File::create(calldata_path)?.write_all(encoded.as_slice())?;
Ok(encoded)
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn verify_evm(
proof_path: PathBuf,
@@ -1502,7 +1577,6 @@ pub(crate) async fn verify_evm(
addr_vk: Option<H160Flag>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::verify_proof_with_data_attestation;
check_solc_requirement();
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
@@ -1535,7 +1609,7 @@ pub(crate) async fn verify_evm(
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_aggregate_verifier(
pub(crate) async fn create_evm_aggregate_verifier(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
sol_code_path: PathBuf,
@@ -1544,7 +1618,6 @@ pub(crate) fn create_evm_aggregate_verifier(
logrows: u32,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
@@ -1590,7 +1663,7 @@ pub(crate) fn create_evm_aggregate_verifier(
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
@@ -1626,8 +1699,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1669,7 +1743,6 @@ pub(crate) async fn setup_test_evm_witness(
) -> Result<String, Box<dyn Error>> {
use crate::graph::TestOnChainData;
info!("run this command in background to keep the instance running for testing");
let mut data = GraphData::from_path(data_path)?;
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
@@ -1705,7 +1778,6 @@ pub(crate) async fn test_update_account_calls(
) -> Result<String, Box<dyn Error>> {
use crate::eth::update_account_calls;
check_solc_requirement();
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
Ok(String::new())
@@ -1736,7 +1808,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1745,7 +1818,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -2187,8 +2260,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {

View File

@@ -11,8 +11,8 @@ pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
}
}
/// Converts an i128 to a PrimeField element.
pub fn i128_to_felt<F: PrimeField>(x: i128) -> F {
/// Converts an i64 to a PrimeField element.
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
@@ -37,7 +37,7 @@ pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
if x > F::from_u128(i128::MAX as u128) {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -50,18 +50,18 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}
}
/// Converts a PrimeField element to an i128.
pub fn felt_to_i128<F: PrimeField + PartialOrd + Field>(x: F) -> i128 {
if x > F::from_u128(i128::MAX as u128) {
/// Converts a PrimeField element to an i64.
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
-(lower_128 as i128)
-(lower_128 as i64)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
lower_128 as i128
lower_128 as i64
}
}
@@ -79,10 +79,10 @@ mod test {
let res: F = i32_to_felt(2_i32.pow(17));
assert_eq!(res, F::from(131072));
let res: F = i128_to_felt(-15i128);
let res: F = i64_to_felt(-15i64);
assert_eq!(res, -F::from(15));
let res: F = i128_to_felt(2_i128.pow(17));
let res: F = i64_to_felt(2_i64.pow(17));
assert_eq!(res, F::from(131072));
}
@@ -96,10 +96,10 @@ mod test {
}
#[test]
fn felttoi128() {
for x in -(2i128.pow(20))..(2i128.pow(20)) {
let fieldx: F = i128_to_felt::<F>(x);
let xf: i128 = felt_to_i128::<F>(fieldx);
fn felttoi64() {
for x in -(2i64.pow(20))..(2i64.pow(20)) {
let fieldx: F = i64_to_felt::<F>(x);
let xf: i64 = felt_to_i64::<F>(fieldx);
assert_eq!(x, xf);
}
}

View File

@@ -1,13 +1,13 @@
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::postgres::Client;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
@@ -21,8 +21,6 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
@@ -130,7 +128,7 @@ impl FileSourceInner {
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
@@ -152,7 +150,7 @@ impl FileSourceInner {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64,
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
}
}
}
@@ -213,7 +211,9 @@ impl PostgresSource {
}
/// Fetch data from postgres
pub fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
pub async fn fetch(
&self,
) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
@@ -234,30 +234,25 @@ impl PostgresSource {
)
};
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
let mut client = Client::connect(&config, NoTls).unwrap();
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).unwrap() {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).await? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
res
})
.join()
.map_err(|_| "failed to fetch data from postgres")?;
}
Ok(vec![res])
}
/// Fetch data from postgres and format it as a FileSource
pub fn fetch_and_format_as_file(
pub async fn fetch_and_format_as_file(
&self,
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
Ok(self
.fetch()?
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
@@ -285,13 +280,13 @@ impl OnChainSource {
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data};
use crate::eth::{
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
};
use log::debug;
// Set up local anvil instance for reading on-chain data
let (anvil, client) = crate::eth::setup_eth_backend(rpc, None).await?;
let address = client.address();
let (client, client_address) = crate::eth::setup_eth_backend(rpc, None).await?;
let mut scales = scales;
// set scales to 1 where data is a field element
@@ -304,7 +299,8 @@ impl OnChainSource {
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
debug!("Calls to accounts: {:?}", calls_to_accounts);
let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?;
let inputs =
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
debug!("Inputs: {:?}", inputs);
let mut quantized_evm_inputs = vec![];
@@ -333,7 +329,7 @@ impl OnChainSource {
inputs.push(t);
}
let used_rpc = rpc.unwrap_or(&anvil.endpoint()).to_string();
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
// Fill the input_data field of the GraphData struct
Ok((
@@ -510,7 +506,7 @@ impl GraphData {
}
///
pub fn split_into_batches(
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
@@ -535,7 +531,7 @@ impl GraphData {
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file()?,
} => data.fetch_and_format_as_file().await?,
};
for (i, shape) in input_shapes.iter().enumerate() {

View File

@@ -6,10 +6,14 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(not(target_arch = "wasm32"))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
pub mod vars;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
@@ -39,7 +43,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -62,13 +66,13 @@ pub use vars::*;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -175,11 +179,11 @@ pub struct GraphWitness {
/// Any hashes of outputs generated during the forward pass
pub processed_outputs: Option<ModuleForwardResult>,
/// max lookup input
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// max lookup input
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// max range check size
pub max_range_size: i128,
pub max_range_size: i64,
}
impl GraphWitness {
@@ -483,7 +487,22 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
/// Calc the number of rows required for lookup tables
pub fn lookup_log_rows(&self) -> u32 {
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32)
.log2()
.ceil() as u32
}
/// Calc the number of rows required for lookup tables
pub fn lookup_log_rows_with_blinding(&self) -> u32 {
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32
+ RESERVED_BLINDING_ROWS as f32)
.log2()
.ceil() as u32
}
fn model_constraint_logrows_with_blinding(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
@@ -495,14 +514,31 @@ impl GraphSettings {
.ceil() as u32
}
/// calculate the number of rows required for the dynamic lookup and shuffle
pub fn dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
(self.total_dynamic_col_size as f64
+ self.total_shuffle_col_size as f64
+ RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
fn module_constraint_logrows(&self) -> u32 {
/// calculate the number of rows required for the module constraints
pub fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
/// calculate the number of rows required for the module constraints
pub fn module_constraint_logrows_with_blinding(&self) -> u32 {
(self.module_sizes.max_constraints() as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
.log2()
@@ -529,6 +565,14 @@ impl GraphSettings {
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
/// calculate the log2 of the total number of instances
pub fn log2_total_instances_with_blinding(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>() + RESERVED_BLINDING_ROWS;
// max between 1 and the log2 of the sums
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
// buf writer
@@ -965,6 +1009,7 @@ impl GraphCircuit {
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
@@ -972,7 +1017,7 @@ impl GraphCircuit {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file()?;
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
@@ -987,8 +1032,8 @@ impl GraphCircuit {
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (_, client) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
// quantize the supplied data using the provided scale + QuantizeData.sol
let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
@@ -1051,14 +1096,14 @@ impl GraphCircuit {
Ok(data)
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
)
}
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
@@ -1066,7 +1111,7 @@ impl GraphCircuit {
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
max_range_size: i64,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick the range with the largest absolute size safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
@@ -1085,9 +1130,9 @@ impl GraphCircuit {
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: i128,
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
lookup_safety_margin: i64,
) -> Result<(), Box<dyn std::error::Error>> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -1126,7 +1171,7 @@ impl GraphCircuit {
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
@@ -1181,7 +1226,7 @@ impl GraphCircuit {
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
max_range_size: i64,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1240,6 +1285,7 @@ impl GraphCircuit {
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
witness_gen: bool,
check_lookup: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1288,7 +1334,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, witness_gen)?;
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1451,7 +1497,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1461,7 +1508,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),

View File

@@ -65,11 +65,11 @@ pub struct ForwardResult {
/// The outputs of the forward pass.
pub outputs: Vec<Tensor<Fp>>,
/// The maximum value of any input to a lookup operation.
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// The max range check size
pub max_range_size: i128,
pub max_range_size: i64,
}
impl From<DummyPassRes> for ForwardResult {
@@ -117,11 +117,11 @@ pub struct DummyPassRes {
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// min lookup inputs
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// min range check
pub max_range_size: i128,
pub max_range_size: i64,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -538,7 +538,7 @@ impl Model {
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs, false)?;
let res = self.dummy_layout(run_args, &inputs, false, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -582,12 +582,13 @@ impl Model {
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
witness_gen: bool,
check_lookup: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
Ok(res.into())
}
@@ -803,13 +804,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -818,6 +824,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -1194,6 +1201,20 @@ impl Model {
.collect();
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1205,25 +1226,11 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!("output dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1264,8 +1271,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1297,6 +1304,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1309,25 +1317,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}
@@ -1368,6 +1393,7 @@ impl Model {
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
witness_gen: bool,
check_lookup: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1386,7 +1412,8 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
let mut region =
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

View File

@@ -15,7 +15,7 @@ use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instancess
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;
/// Poseidon module type
@@ -314,7 +314,6 @@ impl GraphModules {
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
let res = PolyCommitChip::commit::<Scheme>(
x.to_vec(),
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
srs,
);

Some files were not shown because too many files have changed in this diff Show More