Compare commits

...

70 Commits

Author SHA1 Message Date
dante
50740a22df fix: patch pypi whl version labels (#916) 2025-01-27 20:25:03 -05:00
dante
a2624f6303 fix: strict cvx opt bounds to stop prover non-det (#914) 2025-01-24 08:48:50 -05:00
dante
fc5be4f949 fix: syn-sel should be range-checked when overflow (#913) 2025-01-23 12:26:31 -05:00
dante
d0ba505baa fix: node parsing should not panic (#912) 2025-01-22 08:02:29 -05:00
dante
f35688917d fix: rm macos metal bindings from python (#911) 2025-01-21 00:36:57 -05:00
Artem
7ae541ed35 feat: metal acceleration for MSM solving (#909)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2025-01-20 22:17:24 -05:00
dante
675628cd08 fix!: shuffle argument should include an incrementing index (#904)
BREAKING CHANGE: pk and vk will not be backwards compatible
2025-01-17 09:19:10 -05:00
Artem
4fe7290240 fix: rust ci issue with updating swift pm testing files (#908) 2025-01-14 12:00:55 -05:00
dante
3e027db9b6 fix: apply zizmor suggestions to CI (#906)
---------

Co-authored-by: Jseam <hello.jseam@gmail.com>
2025-01-14 12:00:31 -05:00
Artem
e566acc22a fix: swift pm ci issue with updating testing files (#905) 2025-01-13 18:08:04 -05:00
dante
75ea99e81d fix: eager exec of ok_or error prints (#903) 2025-01-11 13:50:57 -05:00
dante
c5354c382d refactor: range check sanity toggled by CHECKMODE (#902) 2025-01-10 22:58:52 +00:00
dante
bdcba5ca61 feat: add gen-random-data helpers func (#901) 2025-01-09 00:14:27 +00:00
dante
6752a05f19 refactor: pregen mv-lookup blinds (#900) 2025-01-08 17:18:46 +00:00
dante
03aefb85eb chore: version mismatch warnings for artifacts (#899) 2025-01-06 16:01:34 +00:00
dante
e86caca8b6 refactor: batched poly reads (#897) 2025-01-06 15:49:47 +00:00
dante
c839a30ae6 fix: clearer duplication functions (#895) 2024-12-31 07:28:02 -05:00
dante
352812b9ac refactor!: simplified decompose op (#892) 2024-12-30 13:44:03 -05:00
dante
d48d0b0b3e fix: get_slice should not use intermediate Vec (#894) 2024-12-27 23:26:22 -05:00
Jseam
8b223354cc fix: add version string and sed (#893) 2024-12-27 14:24:28 -05:00
dante
caa6ef8e16 fix: const filtering strat is size dependent (#891) 2024-12-27 09:43:59 -05:00
Artem
c4354c10a5 fix: ios bindings update action (#886) 2024-12-16 10:49:13 -05:00
dante
c1ce8c88d0 chore: rm wasm serialization checks (#890) 2024-12-12 22:20:29 -05:00
dante
876a9584a1 chore: optimize wasm bundle for speed over size (#889) 2024-12-12 15:35:17 -05:00
dante
7d7f049cc4 chore: neural bag of words example (#888) 2024-12-12 14:20:21 -05:00
dante
96f3fd94b2 feat: ICICLE MSM and NTT integration (#884) 2024-12-07 00:32:09 +00:00
dante
6263510c56 fix: bump pypi-publish to unstable to use twine updates (#881) 2024-12-06 23:19:29 +00:00
Jseam
f5b8ae3213 fix: revert pypi to 1.11.0 (#880) 2024-12-05 14:46:40 -05:00
dante
b2e4e414f0 chore: update pyo3 and add stub (#879) 2024-12-05 10:35:06 -05:00
Dmitry
0b0199e2b7 fix: typo in lib.rs (#877) 2024-12-03 18:46:46 -05:00
dante
5e169bdd17 chore: update tract to 0.21.8-pre (#878) 2024-12-03 16:52:03 -05:00
dante
64cbcb3f7e chore: explicitly compile div op (#876) 2024-11-28 17:14:53 +09:00
dante
ee17f0ff9a chore: generalize the exp to other bases (#875) 2024-11-26 09:31:12 +09:00
Jseam
ee55e7dc19 fix: upgrade run-on-arch (#874) 2024-11-24 14:30:42 +09:00
Jseam
5df83886c7 fix: typo (#873) 2024-11-22 22:47:25 +09:00
Jseam
061ae89c01 ci: remove armv7 and remove v0.0.0 from auto python docs (#872)
* fix: remove armv7 rpi4 uses aarch64

* fix: remove 0.0.0 as it is causing failures
2024-11-22 19:24:12 +09:00
Ethan Cemer
0fc1c3eecd feat: single call DA verifier (#869) 2024-11-19 07:26:17 +00:00
Artem
85302453d9 fix: ios package update workflow tags & tests assets (#868) 2024-11-16 06:35:39 +00:00
dante
523c77c912 feat: lookupless sqrt and rsqrt (#867) 2024-11-10 15:56:38 +00:00
dante
948e5cd4b9 chore: version proof and witness (#865) 2024-11-08 02:55:35 +00:00
dante
00155e585f feat: bounded lookup log argument (#864) 2024-11-07 12:16:55 +00:00
dante
0876faa12c feat: bounded lookup round half to even (#863) 2024-11-01 00:51:15 -04:00
dante
a3c131dac0 feat: lookupless rounding ops (#862) 2024-10-31 11:29:46 -04:00
sebastiandanconia
fd9c2305ac docs: improve cli friendliness (#861)
* Improve clarity of an info!() message

* Replace references to EZKL_REPO_PATH in `--help' output

Command `--help' messages aren't meant to be unduly verbose; we can
write them for common/simple use cases. We continue to support
EZKL_REPO_PATH for users who need it, for example to support
containerized server use cases.

To be clear, by default, EZKL_REPO_PATH = $HOME/.ezkl
2024-10-30 17:25:47 -04:00
dante
a0060f341d chore: rm lookup recip (#859) 2024-10-29 15:57:38 -04:00
dante
17f1d42739 chore: unify leakyrelu and relu (#858) 2024-10-29 10:43:40 -04:00
dante
ebaee9e2b1 feat: lookupless min/max ops (#854) 2024-10-26 08:00:27 -04:00
dante
d51cba589a feat: dynamic lookup overflow (#853) 2024-10-23 23:12:00 -04:00
Artem
1cb1b6e143 feat: iOS Bindings (#846) 2024-10-23 09:58:55 -04:00
Ethan Cemer
d2b683b527 feat: reusable verifier (#821) 2024-10-22 09:10:24 -04:00
Jseam
a06b09ef1f docs: add batch demo (#849) 2024-10-15 08:36:24 +01:00
dante
e5aa48fbd6 chore: support all padding types (#848) 2024-10-05 10:43:12 -04:00
dante
64fbc8a1c9 refactor: lookup-less sign relu abs (#845) 2024-09-17 11:58:58 -04:00
dante
c9f9d17f16 chore: optimized release builds by default (#844) 2024-09-05 15:52:03 +02:00
Ethan Cemer
b49b0487c4 fix: remove console import from index.ts in in-browser-evm-verifier package (#841) 2024-08-30 18:47:59 +01:00
dante
61b7a8e9b5 chore: perf updates (#838) 2024-08-27 09:45:40 -04:00
dante
5dbc7d5176 chore: cache lookup tables (#835) 2024-08-19 00:24:53 -04:00
dante
ada45a3197 chore: swap h2 collections hash for rustc hasher (#832) 2024-08-04 17:28:36 -04:00
dante
616b421967 fix: bump compiler to latest to accomodate latest serde diagnostic (#830) 2024-07-25 07:56:21 -04:00
dante
f64f0ecd23 fix: instance order when using processed params (#829) 2024-07-24 07:58:46 -04:00
dante
5be12b7a54 fix: num groups for conv operations should be specified at load time (#828) 2024-07-18 09:58:41 -04:00
dante
2fd877c716 chore: small worm example (#568) 2024-07-15 09:20:37 -04:00
dante
8197340985 chore: const filtering optimizations (#825) 2024-07-12 12:37:02 +01:00
dante
6855ea1947 feat: parallel polynomial reads in halo2 (#826) 2024-07-12 00:33:14 +01:00
dante
2ca57bde2c chore: bump tract (#823) 2024-06-29 01:53:18 +01:00
Ethan Cemer
390de88194 feat: create_evm_vk python bindings (#818) 2024-06-23 22:21:59 -04:00
dante
cd91f0af26 chore: allow for float lookup safety margins (#817) 2024-06-19 09:12:01 -04:00
dante
4771192823 fix: more verbose io / rw errors (#815) 2024-06-14 12:31:50 -04:00
dante
a863ccc868 chore: update cmd feature flag (#814) 2024-06-11 16:07:49 -04:00
Ethan Cemer
8e6ccc863d feat: all file source kzg commit DA (#812) 2024-06-11 09:32:54 -04:00
156 changed files with 16837 additions and 7300 deletions

View File

@@ -1,4 +0,0 @@
[target.wasm32-unknown-unknown]
runner = 'wasm-bindgen-test-runner'
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
"link-arg=--max-memory=4294967296"]

17
.cargo/config.toml Normal file
View File

@@ -0,0 +1,17 @@
[target.wasm32-unknown-unknown]
runner = 'wasm-bindgen-test-runner'
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
"link-arg=--max-memory=4294967296"]
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

View File

@@ -6,22 +6,15 @@ on:
description: "Test scenario tags"
jobs:
bench_elgamal:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Bench elgamal
run: cargo bench --verbose --bench elgamal
bench_poseidon:
permissions:
contents: read
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -31,10 +24,14 @@ jobs:
run: cargo bench --verbose --bench poseidon
bench_einsum_accum_matmul:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -44,10 +41,14 @@ jobs:
run: cargo bench --verbose --bench accum_einsum_matmul
bench_accum_matmul_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -57,10 +58,14 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu
bench_accum_matmul_relu_overflow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -70,10 +75,14 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu_overflow
bench_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -83,10 +92,14 @@ jobs:
run: cargo bench --verbose --bench relu
bench_accum_dot:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -96,10 +109,14 @@ jobs:
run: cargo bench --verbose --bench accum_dot
bench_accum_conv:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -109,10 +126,14 @@ jobs:
run: cargo bench --verbose --bench accum_conv
bench_accum_sumpool:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -122,10 +143,14 @@ jobs:
run: cargo bench --verbose --bench accum_sumpool
bench_pairwise_add:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -135,10 +160,14 @@ jobs:
run: cargo bench --verbose --bench pairwise_add
bench_accum_sum:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -148,10 +177,14 @@ jobs:
run: cargo bench --verbose --bench accum_sum
bench_pairwise_pow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27

View File

@@ -15,22 +15,30 @@ defaults:
working-directory: .
jobs:
publish-wasm-bindings:
permissions:
contents: read
packages: write
name: publish-wasm-bindings
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -48,7 +56,7 @@ jobs:
run: |
echo '{
"name": "@ezkljs/engine",
"version": "${{ github.ref_name }}",
"version": "${RELEASE_TAG}",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
@@ -181,21 +189,26 @@ jobs:
in-browser-evm-ver-publish:
permissions:
contents: read
packages: write
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)

View File

@@ -6,12 +6,16 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
permissions:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -18,12 +18,17 @@ defaults:
jobs:
linux:
permissions:
contents: read
packages: write
runs-on: GPU
strategy:
matrix:
target: [x86_64]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
@@ -31,9 +36,12 @@ jobs:
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- uses: actions-rs/toolchain@v1
with:
@@ -98,14 +106,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./

View File

@@ -16,6 +16,8 @@ defaults:
jobs:
macos:
permissions:
contents: read
runs-on: macos-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -23,11 +25,21 @@ jobs:
target: [x86_64, universal2-apple-darwin]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -40,11 +52,18 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
@@ -62,6 +81,8 @@ jobs:
path: dist
windows:
permissions:
contents: read
runs-on: windows-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -69,11 +90,21 @@ jobs:
target: [x64, x86]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -86,7 +117,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
@@ -107,6 +138,8 @@ jobs:
path: dist
linux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -114,11 +147,21 @@ jobs:
target: [x86_64]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -129,7 +172,6 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -168,58 +210,9 @@ jobs:
name: wheels
path: dist
# TODO: There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# linux-cross:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# target: [aarch64, armv7]
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
# - name: Install cross-compilation tools for armv7
# if: matrix.target == 'armv7'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
# - name: Build wheels
# uses: PyO3/maturin-action@v1
# with:
# target: ${{ matrix.target }}
# manylinux: auto
# args: --release --out dist --features python-bindings
# - uses: uraimo/run-on-arch-action@v2.5.0
# name: Install built wheel
# with:
# arch: ${{ matrix.target }}
# distro: ubuntu20.04
# githubToken: ${{ github.token }}
# install: |
# apt-get update
# apt-get install -y --no-install-recommends python3 python3-pip
# pip3 install -U pip
# run: |
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
# python3 -c "import ezkl"
# - name: Upload wheels
# uses: actions/upload-artifact@v3
# with:
# name: wheels
# path: dist
musllinux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -228,11 +221,21 @@ jobs:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -276,6 +279,8 @@ jobs:
path: dist
musllinux-cross:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -283,14 +288,22 @@ jobs:
platform:
- target: aarch64-unknown-linux-musl
arch: aarch64
- target: armv7-unknown-linux-musleabihf
arch: armv7
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -308,7 +321,7 @@ jobs:
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@v2.5.0
- uses: uraimo/run-on-arch-action@v2.8.1
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}
@@ -334,8 +347,6 @@ jobs:
permissions:
id-token: write
if: "startsWith(github.ref, 'refs/tags/')"
# TODO: Uncomment if linux-cross is working
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@v3
@@ -343,35 +354,34 @@ jobs:
name: wheels
- name: List Files
run: ls -R
# Both publish steps will fail if there is no trusted publisher setup
# On failure the publish step will then simply continue to the next one
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@unstable/v1
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
permissions:
contents: read
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}
commit_ref: ${{ github.ref_name }}

View File

@@ -10,6 +10,9 @@ on:
- "*"
jobs:
create-release:
permissions:
contents: read
packages: write
name: create-release
runs-on: ubuntu-22.04
if: startsWith(github.ref, 'refs/tags/')
@@ -33,6 +36,9 @@ jobs:
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
permissions:
contents: read
packages: write
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -45,11 +51,14 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -91,6 +100,10 @@ jobs:
asset_content_type: application/octet-stream
build-release:
permissions:
contents: read
packages: write
issues: write
name: build-release
needs: ["create-release"]
runs-on: ${{ matrix.os }}
@@ -106,32 +119,34 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-02-06
rust: nightly-2024-07-18
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -181,9 +196,18 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"

View File

@@ -19,38 +19,66 @@ env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
fr-age-test:
permissions:
contents: read
runs-on: large-self-hosted
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: fr age Mock
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
build:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build
run: cargo build --verbose
docs:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Docs
run: cargo doc --verbose
library-tests:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -65,48 +93,54 @@ jobs:
- name: Library tests (original lookup)
run: cargo nextest run --lib --verbose --no-default-features --features ezkl
ultra-overflow-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run --release lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# ultra-overflow-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@v2
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
# - name: lookup overflow
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Matmul overflow
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv overflow
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
# - name: Conv + relu overflow
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
ultra-overflow-tests_og-lookup:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -134,12 +168,16 @@ jobs:
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
ultra-overflow-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -167,12 +205,16 @@ jobs:
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
model-serialization:
permissions:
contents: read
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -183,63 +225,60 @@ jobs:
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
wasm32-tests:
permissions:
contents: read
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- uses: nanasess/setup-chromedriver@v2
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Circuit Render
run: cargo nextest run --release --verbose tests::tutorial_
mock-proving-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and bounded lookup log
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
@@ -258,6 +297,8 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
@@ -280,13 +321,17 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
prove-and-verify-evm-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -294,6 +339,8 @@ jobs:
crate: cargo-nextest
locked: true
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -304,7 +351,7 @@ jobs:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -323,12 +370,12 @@ jobs:
cd in-browser-evm-verifier
pnpm build:commonjs
cd ..
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -345,6 +392,8 @@ jobs:
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)
@@ -356,23 +405,67 @@ jobs:
- name: KZG prove and verify tests (EVM + hashed outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
# prove-and-verify-tests-metal:
# permissions:
# contents: read
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@v0.4.0
# with:
# # Pin to version 0.12.1
# version: 'v0.12.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18
# - uses: actions/checkout@v3
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@v2
# with:
# version: 8
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
prove-and-verify-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -429,49 +522,55 @@ jobs:
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
prove-and-verify-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
# prove-and-verify-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
# - uses: actions/checkout@v3
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (kzg outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public inputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# - name: KZG prove and verify tests (fixed params)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# - name: KZG prove and verify tests (hashed outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
prove-and-verify-mock-aggr-tests:
permissions:
contents: read
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -481,32 +580,38 @@ jobs:
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
# prove-and-verify-aggr-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG tests
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -517,34 +622,42 @@ jobs:
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
examples:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -555,42 +668,50 @@ jobs:
run: cargo nextest run --release tests_examples
python-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -601,8 +722,6 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
@@ -613,6 +732,8 @@ jobs:
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
python-integration-tests:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
@@ -634,28 +755,36 @@ jobs:
- 5432:5432
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: Neural bow
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
@@ -671,7 +800,87 @@ jobs:
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Reusable verifier tutorial
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
ios-integration-tests:
permissions:
contents: read
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
permissions:
contents: read
runs-on: macos-latest
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift- repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

33
.github/workflows/static-analysis.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Static Analysis
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
analyze:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
# Run Zizmor static analysis
- name: Install Zizmor
run: cargo install --locked zizmor
- name: Run Zizmor Analysis
run: zizmor .

134
.github/workflows/swift-pm.yml vendored Normal file
View File

@@ -0,0 +1,134 @@
name: Build and Publish EZKL iOS SPM package
on:
push:
tags:
# Only support SemVer versioning tags
- 'v[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+'
jobs:
build-and-update:
permissions:
contents: read
packages: write
runs-on: macos-latest
env:
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
RELEASE_TAG: ${{ github.ref_name }}
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
run: |
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
TAG="${RELEASE_TAG}"
echo "Original TAG: $TAG"
# Remove leading 'v' if present to match the Swift Package Manager version format.
NEW_TAG=${TAG#v}
echo "Stripped TAG: $NEW_TAG"
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Check for changes
id: check_changes
run: |
cd ezkl-swift-package
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
echo "no_changes=true" >> $GITHUB_OUTPUT
else
echo "no_changes=false" >> $GITHUB_OUTPUT
fi
- name: Set up Xcode environment
if: steps.check_changes.outputs.no_changes == 'false'
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Setup Git
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
env:
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
- name: Commit and Push Changes
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
git add Sources/EzklCoreBindings Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
if ! git push origin; then
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
exit 1
fi
- name: Tag the latest commit
run: |
cd ezkl-swift-package
source $GITHUB_ENV
# Tag the latest commit on the current branch
if git rev-parse "$TAG" >/dev/null 2>&1; then
echo "Tag $TAG already exists locally. Skipping tag creation."
else
git tag "$TAG"
fi
if ! git push origin "$TAG"; then
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
exit 1
fi

View File

@@ -12,6 +12,8 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2

7
.gitignore vendored
View File

@@ -27,7 +27,6 @@ __pycache__/
*.pyc
*.pyo
*.py[cod]
bin/
build/
develop-eggs/
dist/
@@ -46,7 +45,7 @@ var/
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/assets/pk.key
!tests/assets/vk.key
docs/python/build
!tests/wasm/vk_aggr.key
!tests/assets/vk_aggr.key

1657
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ cargo-features = ["profile-rustflags"]
name = "ezkl"
version = "0.0.0"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -11,91 +12,108 @@ edition = "2021"
# Name to be imported within python
# Example: import ezkl
name = "ezkl"
crate-type = ["cdylib", "rlib"]
crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.5.3", features = ["derive"] }
clap_complete = "4.5.2"
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
"raw_value",
], optional = true }
log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
"circuit-params",
] }
rand = { version = "0.8", default-features = false }
itertools = { version = "0.10.3", default-features = false }
clap = { version = "4.5.3", features = ["derive"], optional = true }
serde = { version = "1.0.126", features = ["derive"] }
clap_complete = { version = "4.5.2", optional = true }
log = { version = "0.4.17", default-features = false }
thiserror = { version = "1.0.38", default-features = false }
hex = { version = "0.4.3", default-features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
semver = "1.0.22"
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
serde_json = { version = "1.0.97", features = ["float_roundtrip", "raw_value"] }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
ethabi = "18"
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
"provider-http",
"signers",
"contract",
"rpc-types-eth",
"signer-wallet",
"node-bindings",
], optional = true }
foundry-compilers = { version = "0.4.1", features = [
"svm-solc",
], optional = true }
ethabi = { version = "18", optional = true }
indicatif = { version = "0.17.5", features = ["rayon"], optional = true }
gag = { version = "1.0.0", default-features = false, optional = true }
instant = { version = "0.1" }
reqwest = { version = "0.12.4", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.35", default_features = false, features = [
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.21.2", features = [
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
], default-features = false, optional = true }
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
uniffi_bindgen = { version = "=0.28.0", optional = true }
camino = { version = "^1.1", optional = true }
uuid = { version = "1.10.0", features = ["v4"], optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true }
env_logger = { version = "0.10.0", default_features = false, optional = true }
chrono = "0.4.31"
sha256 = "1.4.0"
colored = { version = "2.0.0", default-features = false, optional = true }
env_logger = { version = "0.10.0", default-features = false, optional = true }
chrono = { version = "0.4.31", optional = true }
sha256 = { version = "1.4.0", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
serde_json = { version = "1.0.97", default-features = false, features = [
"float_roundtrip",
"raw_value",
] }
getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
@@ -108,8 +126,14 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
[build-dependencies]
uniffi = { version = "0.28", features = ["build"], optional = true }
[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -122,6 +146,10 @@ shellexpand = "3.1.0"
runner = 'wasm-bindgen-test-runner'
[[bench]]
name = "zero_finder"
harness = false
[[bench]]
name = "accum_dot"
harness = false
@@ -160,16 +188,20 @@ harness = false
[[bench]]
name = "relu"
name = "sigmoid"
harness = false
[[bench]]
name = "accum_matmul_relu"
name = "relu_lookupless"
harness = false
[[bench]]
name = "accum_matmul_sigmoid"
harness = false
[[bench]]
name = "accum_matmul_relu_overflow"
name = "accum_matmul_sigmoid_overflow"
harness = false
[[bin]]
@@ -178,40 +210,94 @@ test = false
bench = false
required-features = ["ezkl"]
[[bin]]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
[[bin]]
name = "py_stub_gen"
required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "no-banner"]
default = [
"ezkl",
"mv-lookup",
"precompute-coset",
"no-banner",
"parallel-poly-read",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
"serde",
"serde_json",
"log",
"colored",
"env_logger",
"dep:colored",
"dep:env_logger",
"tabled/color",
"serde_json/std",
"colored_json",
"dep:alloy",
"dep:foundry-compilers",
"dep:ethabi",
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:openssl",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:regex",
"dep:tokio",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:portable-atomic",
"dep:clap_complete",
"dep:halo2_solidity_verifier",
"dep:semver",
"dep:clap",
"dep:tosubcommand",
]
parallel-poly-read = [
"halo2_proofs/circuit-params",
"halo2_proofs/parallel-poly-read",
]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
[package.metadata.wasm-pack.profile.release]
wasm-opt = [
"-O4",
"--flexible-inline-max-function-size",
"4294967295",
]

View File

@@ -0,0 +1,147 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "accountCall",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"type": "address"
},
{
"internalType": "bytes",
"name": "callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -1,4 +1,23 @@
[
{
"inputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"name": "check_is_valid_field_element",
"outputs": [
{
"internalType": "uint256[]",
"name": "output",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
@@ -17,12 +36,41 @@
"type": "uint256[]"
}
],
"name": "quantize_data",
"name": "quantize_data_multi",
"outputs": [
{
"internalType": "int64[]",
"internalType": "int256[]",
"name": "quantized_data",
"type": "int64[]"
"type": "int256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "scales",
"type": "uint256[]"
}
],
"name": "quantize_data_single",
"outputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"stateMutability": "pure",

View File

@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
}),
)
.unwrap();

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
K,
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -75,14 +83,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},

View File

@@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
k,
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();
MyConfig { base_config }
@@ -76,14 +84,18 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,

View File

@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();

View File

@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();

150
benches/relu_lookupless.rs Normal file
View File

@@ -0,0 +1,150 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut LEN: usize = 4;
const K: usize = 16;
#[derive(Clone)]
struct NLCircuit {
pub input: ValTensor<Fr>,
}
impl Circuit<Fr> for NLCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unsafe {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let mut config = Config::default();
config
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
.unwrap();
config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
config
}
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_range_checks(&mut layouter).unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runrelu(c: &mut Criterion) {
let mut group = c.benchmark_group("relu");
let mut rng = rand::thread_rng();
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 8].iter() {
unsafe {
LEN = len;
};
let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runrelu
}
criterion_main!(benches);

View File

@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -41,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let nl = LookupOp::ReLU;
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
let mut config = Config::default();
@@ -62,9 +63,13 @@ impl Circuit<Fr> for NLCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
},
@@ -84,7 +89,7 @@ fn runrelu(c: &mut Criterion) {
};
let input: Tensor<Value<Fr>> =
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),

116
benches/zero_finder.rs Normal file
View File

@@ -0,0 +1,116 @@
use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2curves::{bn256::Fr as F, ff::Field};
use maybe_rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
slice::ParallelSlice,
};
use rand::Rng;
// Assuming these are your types
#[derive(Clone)]
enum ValType {
Constant(F),
AssignedConstant(usize, F),
Other,
}
// Helper to generate test data
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value
}
})
.collect()
}
fn bench_zero_finding(c: &mut Criterion) {
let sizes = [
1_000, // 1K
10_000, // 10K
100_000, // 100K
256 * 256 * 2, // Our specific case
1_000_000, // 1M
10_000_000, // 10M
];
let zero_probability = 0.1; // 10% zeros
let mut group = c.benchmark_group("zero_finding");
group.sample_size(10); // Adjust based on your needs
for &size in &sizes {
let data = generate_test_data(size, zero_probability);
// Benchmark sequential version
group.bench_function(format!("sequential_{}", size), |b| {
b.iter(|| {
let result = data
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark parallel version
group.bench_function(format!("parallel_{}", size), |b| {
b.iter(|| {
let result = data
.par_iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark chunked parallel version
group.bench_function(format!("chunked_parallel_{}", size), |b| {
b.iter(|| {
let num_cores = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (size / num_cores).max(100);
let result = data
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>();
black_box(result)
})
});
}
group.finish();
}
criterion_group!(benches, bench_zero_finding);
criterion_main!(benches);

7
build.rs Normal file
View File

@@ -0,0 +1,7 @@
fn main() {
if cfg!(feature = "ios-bindings-test") {
println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
}
println!("cargo::rerun-if-changed=build.rs");
}

View File

@@ -93,9 +93,6 @@ contract LoadInstances {
}
}
// Contract that checks that the COMMITMENT_KZG bytes is equal to the first part of the proof.
pragma solidity ^0.8.0;
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
@@ -163,6 +160,253 @@ contract SwapProofCommitments {
}
return equal; // Return true if the commitment comparison passed
} /// end checkKzgCommits
}
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
bytes callData;
uint256 decimals;
}
AccountCall public accountCall;
uint[] scales;
address public admin;
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_LEN = 0;
uint256 constant OUTPUT_LEN = 0;
uint8 public instanceOffset;
/**
* @dev Initialize the contract with account calls the EZKL model will read from.
* @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
*/
constructor(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) internal {
AccountCall memory _accountCall = accountCall;
_accountCall.contractAddress = _contractAddresses;
_accountCall.callData = _callData;
_accountCall.decimals = 10 ** _decimals;
accountCall = _accountCall;
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
if (prod1 == 0) {
return prod0 / denominator;
}
require(denominator > prod1, "Math: mulDiv overflow");
uint256 remainder;
assembly {
remainder := mulmod(x, y, denominator)
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
return result;
}
}
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param x - One of the elements of the data returned from the account calls
* @param _decimals - Number of base 10 decimals to scale the data by.
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
*
*/
function quantizeData(
int x,
uint256 _decimals,
uint256 _scale
) internal pure returns (int256 quantized_data) {
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), _scale, _decimals);
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
output += 1;
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
* @param target - The address of the account to make calls to.
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
*/
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
require(
target.code.length > 0,
"Address: call to non-contract"
);
}
return returndata;
} else {
revert("Address: low-level call failed");
}
}
/**
* @dev Convert the fixed point quantized data into a field element.
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
}
/**
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_LEN + OUTPUT_LEN,
"Invalid public inputs length"
);
AccountCall memory _accountCall = accountCall;
uint[] memory _scales = scales;
bytes memory returnData = staticCall(
_accountCall.contractAddress,
_accountCall.callData
);
int256[] memory x = abi.decode(returnData, (int256[]));
uint _offset;
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
uint field_element = toFieldElement(output);
for (uint i = 0; i < x.length; i++) {
if (field_element != instances[i + instanceOffset]) {
_offset += 1;
} else {
break;
}
}
uint length = x.length - _offset;
for (uint i = 1; i < length; i++) {
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
field_element = toFieldElement(output);
require(
field_element == instances[i + instanceOffset + _offset],
"Public input does not match"
);
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}
@@ -176,11 +420,11 @@ contract SwapProofCommitments {
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestation is LoadInstances, SwapProofCommitments {
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -2,8 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
@@ -42,8 +41,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -66,8 +65,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -90,8 +89,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -147,6 +146,8 @@ where
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::configure(
@@ -157,15 +158,11 @@ where
);
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();
layer_config
@@ -196,15 +193,21 @@ where
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
};
let x = config
.layer_config
@@ -221,7 +224,14 @@ where
let x = config
.layer_config
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
let mut x = config
@@ -281,7 +291,7 @@ where
}
pub fn runconv() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
const KERNEL_HEIGHT: usize = 5;
@@ -315,7 +325,11 @@ pub fn runconv() {
.test_set_length(10_000)
.finalize();
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
let mut train_data = Tensor::from(
trn_img
.iter()
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
);
train_data.reshape(&[50_000, 28, 28]).unwrap();
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
@@ -343,8 +357,8 @@ pub fn runconv() {
.map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}),
);
@@ -355,7 +369,8 @@ pub fn runconv() {
let l0_kernels = l0_kernels.try_into().unwrap();
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
let mut l0_bias =
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let l0_bias = l0_bias.try_into().unwrap();
@@ -363,8 +378,8 @@ pub fn runconv() {
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
@@ -374,8 +389,8 @@ pub fn runconv() {
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
}));
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
@@ -401,13 +416,13 @@ pub fn runconv() {
l2_params: [l2_weights, l2_biases],
};
let public_input: Tensor<i32> = vec![
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
let public_input: Tensor<IntegerRep> = vec![
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
]
.into_iter()
.into();
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
println!("MOCK PROVING");
let now = Instant::now();

View File

@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils::i32_to_felt;
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
let output = VarTensor::new_advice(cs, K, 1, LEN);
// tells the config layer to add an affine op to the circuit gate
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::<F>::configure(
cs,
&[input.clone(), params.clone()],
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
CheckMode::SAFE,
);
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
)
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.unwrap();
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
@@ -104,11 +103,16 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let x = config
.layer_config
.layout(
@@ -141,7 +145,14 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
println!("x shape: {:?}", x.dims());
let mut x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap()
.unwrap();
println!("3");
@@ -177,7 +188,14 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
println!("x shape: {:?}", x.dims());
let x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap();
println!("6");
println!("offset: {}", region.row());
@@ -212,36 +230,36 @@ impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
}
pub fn runmlp() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
// parameters
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
&[4, 4],
)
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
// input data, with 1 padding to allow for bias
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
.unwrap()
.into();
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
.map(integer_rep_to_felt);
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
let circuit = MyCircuit::<4, -8192, 8192> {
@@ -251,12 +269,12 @@ pub fn runmlp() {
_marker: PhantomData,
};
let public_input: Vec<i32> = unsafe {
let public_input: Vec<IntegerRep> = unsafe {
vec![
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
(2849f32 / 128f32).to_int_unchecked::<i32>(),
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
]
};
@@ -265,7 +283,10 @@ pub fn runmlp() {
let prover = MockProver::run(
K as u32,
&circuit,
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
vec![public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect()],
)
.unwrap();
prover.assert_satisfied();

View File

@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -648,10 +648,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,279 +1,284 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@@ -1,456 +1,459 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
"nbformat": 4,
"nbformat_minor": 0
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,766 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
"\n",
"1 - NBoW\n",
"\n",
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Preparing Data\n",
"\n",
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
"\n",
"The steps to take are:\n",
"\n",
" 1. importing modules\n",
" 2. loading data\n",
" 3. tokenizing data\n",
" 4. creating data splits\n",
" 5. creating a vocabulary\n",
" 6. numericalizing data\n",
" 7. creating the data loaders\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install torchtex"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm\n",
"\n",
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
"\n",
"seed = 1234\n",
"\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.features\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're design to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
"\n",
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_example(example, tokenizer, max_length):\n",
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
" return {\"tokens\": tokens}\n",
"\n",
"\n",
"max_length = 256\n",
"\n",
"train_data = train_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"test_data = test_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"\n",
"\n",
"# create validation data \n",
"# Why have both a validation set and a test set? Your test set respresents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leak information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
"\n",
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data[\"train\"]\n",
"valid_data = train_valid_data[\"test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we have to build a vocabulary. This is look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
"\n",
"We do this as machine learning models cannot operate on strings, only numerical vaslues. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
" train_data[\"tokens\"],\n",
" min_freq=min_freq,\n",
" specials=special_tokens,\n",
")\n",
"\n",
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
"\n",
"unk_index = vocab[\"<unk>\"]\n",
"pad_index = vocab[\"<pad>\"]\n",
"\n",
"\n",
"vocab.set_default_index(unk_index)\n",
"\n",
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which containes the numericalized tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def numericalize_example(example, vocab):\n",
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
" return {\"ids\": ids}\n",
"\n",
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"\n",
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
"\n",
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
"\n",
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
"\n",
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
"\n",
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
"\n",
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_collate_fn(pad_index):\n",
" def collate_fn(batch):\n",
" batch_ids = [i[\"ids\"] for i in batch]\n",
" batch_ids = nn.utils.rnn.pad_sequence(\n",
" batch_ids, padding_value=pad_index, batch_first=True\n",
" )\n",
" batch_label = [i[\"label\"] for i in batch]\n",
" batch_label = torch.stack(batch_label)\n",
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
" return batch\n",
"\n",
" return collate_fn\n",
"\n",
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
" collate_fn = get_collate_fn(pad_index)\n",
" data_loader = torch.utils.data.DataLoader(\n",
" dataset=dataset,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_fn,\n",
" shuffle=shuffle,\n",
" )\n",
" return data_loader\n",
"\n",
"\n",
"batch_size = 512\n",
"\n",
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class NBoW(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
" def forward(self, ids):\n",
" # ids = [batch size, seq len]\n",
" embedded = self.embedding(ids)\n",
" # embedded = [batch size, seq len, embedding dim]\n",
" pooled = embedded.mean(dim=1)\n",
" # pooled = [batch size, embedding dim]\n",
" prediction = self.fc(pooled)\n",
" # prediction = [batch size, output dim]\n",
" return prediction\n",
"\n",
"\n",
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique(\"label\"))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
"\n",
"vectors = torchtext.vocab.GloVe()\n",
"\n",
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(data_loader, model, criterion, optimizer, device):\n",
" model.train()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(data_loader, model, criterion, device):\n",
" model.eval()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" with torch.no_grad():\n",
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
" batch_size, _ = prediction.shape\n",
" predicted_classes = prediction.argmax(dim=-1)\n",
" correct_predictions = predicted_classes.eq(label).sum()\n",
" accuracy = correct_predictions / batch_size\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float(\"inf\")\n",
"\n",
"metrics = collections.defaultdict(list)\n",
"\n",
"for epoch in range(n_epochs):\n",
" train_loss, train_acc = train(\n",
" train_data_loader, model, criterion, optimizer, device\n",
" )\n",
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
" metrics[\"train_losses\"].append(train_loss)\n",
" metrics[\"train_accs\"].append(train_acc)\n",
" metrics[\"valid_losses\"].append(valid_loss)\n",
" metrics[\"valid_accs\"].append(valid_acc)\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), \"nbow.pt\")\n",
" print(f\"epoch: {epoch}\")\n",
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
"\n",
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
"\n",
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" prediction = model(tensor).squeeze(dim=0)\n",
" probability = torch.softmax(prediction, dim=-1)\n",
" predicted_class = prediction.argmax(dim=-1).item()\n",
" predicted_probability = probability[predicted_class].item()\n",
" return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def text_to_tensor(text, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" return tensor\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we do onnx stuff to get the data ready for the zk-circuit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json\n",
"\n",
"text = \"This film is terrible!\"\n",
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"model_path = \"network.onnx\"\n",
"data_path = \"input.json\"\n",
"\n",
" # Export the model\n",
"torch.onnx.export(model, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data_json = dict(input_data = [data_array])\n",
"\n",
"print(data_json)\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data_json, open(data_path, 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.logrows = 23\n",
"run_args.scale_rebase_multiplier = 10\n",
"# inputs should be auditable by all\n",
"run_args.input_visibility = \"public\"\n",
"# same with outputs\n",
"run_args.output_visibility = \"public\"\n",
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
"run_args.param_visibility = \"fixed\"\n",
"\n",
"\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(py_run_args=run_args)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file\n",
"res = await ezkl.gen_witness()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.mock()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup()\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"res = ezkl.prove(proof_path=\"proof.json\")\n",
"\n",
"print(res)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also verify it on chain by creating an onchain verifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
" !solc-select install 0.8.20\n",
" !solc-select use 0.8.20\n",
" !solc --version\n",
" import os\n",
"\n",
"# rely on local installation if the notebook is not in colab\n",
"except:\n",
" import os\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.create_evm_verifier()\n",
"assert res == True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
"\n",
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
"\n",
"Create a new file within remix and copy the verifier code over.\n",
"\n",
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
"\n",
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -232,7 +232,7 @@
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
@@ -404,7 +404,7 @@
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"
"run_args.logrows = 15\n"
]
},
{
@@ -466,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -0,0 +1,339 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reusable Verifiers \n",
"\n",
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
"\n",
"- `1l_mlp sigmoid`\n",
"- `1l_mlp relu`\n",
"- `1l_conv sigmoid`\n",
"- `1l_conv relu`\n",
"\n",
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
"\n",
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we deploy a unique and much smaller verifying key artifact (VKA) contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well circuit specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.onnx\n",
"\n",
"# Define the models\n",
"class MLP_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Sigmoid, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class MLP_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Relu, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"class Conv_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Sigmoid, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class Conv_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Relu, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"# Instantiate the models\n",
"mlp_sigmoid = MLP_Sigmoid()\n",
"mlp_relu = MLP_Relu()\n",
"conv_sigmoid = Conv_Sigmoid()\n",
"conv_relu = Conv_Relu()\n",
"\n",
"# Dummy input tensor for mlp\n",
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
"input_mlp_path = 'mlp_input.json'\n",
"\n",
"# Dummy input tensor for conv\n",
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
"input_conv_path = 'conv_input.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import ezkl\n",
"\n",
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
" # Create a new directory for the model if it doesn't exist\n",
" if not os.path.exists(name):\n",
" os.mkdir(name)\n",
" # Store the paths in each of their respective directories\n",
" model_path = os.path.join(name, \"network.onnx\")\n",
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
" pk_path = os.path.join(name, \"test.pk\")\n",
" vk_path = os.path.join(name, \"test.vk\")\n",
" settings_path = os.path.join(name, \"settings.json\")\n",
"\n",
" witness_path = os.path.join(name, \"witness.json\")\n",
" sol_code_path = os.path.join(name, 'test.sol')\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" abi_path = os.path.join(name, 'test.abi')\n",
" proof_path = os.path.join(name, \"proof.json\")\n",
"\n",
" # Flips the neural net into inference mode\n",
" model.eval()\n",
"\n",
" # Export the model\n",
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
" do_constant_folding=True, input_names=['input'],\n",
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
" data = dict(input_data=[data_array])\n",
" json.dump(data, open(input_path, 'w'))\n",
"\n",
" py_run_args = ezkl.PyRunArgs()\n",
" py_run_args.input_visibility = \"private\"\n",
" py_run_args.output_visibility = \"public\"\n",
" py_run_args.param_visibility = \"fixed\" # private by default\n",
"\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
" assert res == True\n",
"\n",
" await ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
"\n",
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
" assert res == True\n",
"\n",
" res = await ezkl.get_srs(settings_path)\n",
" assert res == True\n",
"\n",
" # now generate the witness file\n",
" res = await ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
" assert os.path.isfile(witness_path) == True\n",
"\n",
" # SETUP \n",
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
" assert res == True\n",
"\n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, sol_key_code_path, abi_path)\n",
" assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_anvil()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import filecmp\n",
"\n",
"def compare_files(file1, file2):\n",
" return filecmp.cmp(file1, file2, shallow=False)\n",
"\n",
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
"\n",
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
"\n",
"\n",
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we deploy separate verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(addr_path_verifier, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
" addr_vk = file.read().rstrip()\n",
" \n",
" proof_path = os.path.join(name, \"proof.json\")\n",
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" addr_vk = addr_vk\n",
" )\n",
" assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "171702d3",
"metadata": {},
"outputs": [],
@@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
@@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
@@ -399,9 +399,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"model = convert(sk_model, \"torch\").model\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,33 +84,6 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -119,14 +92,14 @@
"outputs": [],
"source": [
"\n",
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"data = dict(input_data=[d])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -441,9 +413,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -0,0 +1,763 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# univ3-da-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source. For this setup we make a single call to a view function that returns an array of UniV3 historical TWAP price data that we will attest to on-chain. \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--fork-url\", \"https://arb1.arbitrum.io/rpc\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"import requests\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
" return len(num_str) - 1 - num_str.index('.')\n",
" else:\n",
" return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def set_next_block_timestamp(anvil_url, timestamp):\n",
" # Send the JSON-RPC request to Anvil\n",
" payload = {\n",
" \"jsonrpc\": \"2.0\",\n",
" \"id\": 1,\n",
" \"method\": \"evm_setNextBlockTimestamp\",\n",
" \"params\": [timestamp]\n",
" }\n",
" response = requests.post(anvil_url, json=payload)\n",
" if response.status_code == 200:\n",
" print(f\"Next block timestamp set to: {timestamp}\")\n",
" else:\n",
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
"\n",
"def on_chain_data(tensor):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = tensor.view(-1).tolist()\n",
"\n",
" # Step 1: Prepare the calldata\n",
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: MIT\n",
" pragma solidity ^0.8.20;\n",
"\n",
" /// @title Pool state that is not stored\n",
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
" /// blockchain. The functions here may have variable gas costs.\n",
" interface IUniswapV3PoolDerivedState {\n",
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
" /// you must call it with secondsAgos = [3600, 0].\n",
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
" /// timestamp\n",
" function observe(\n",
" uint32[] calldata secondsAgos\n",
" )\n",
" external\n",
" view\n",
" returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" }\n",
"\n",
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
" /// @notice Provides functions to integrate with V3 pool oracle\n",
" contract UniTickAttestor {\n",
" /**\n",
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
" * @param pool Address of the pool that we want to observe\n",
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
" */\n",
" function consult(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public view returns (int256[] memory tickCumulatives) {\n",
" tickCumulatives = new int256[](secondsAgo.length);\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" for (uint256 i = 0; i < secondsAgo.length; i++) {\n",
" tickCumulatives[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"UniTickAttestor.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = UniTickAttestor.constructor().transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" call = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).build_transaction()\n",
" result = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).call()\n",
" \n",
" print(f'result: {result}')\n",
" calldata = call['data'][2:]\n",
"\n",
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
" print(f'time_stamp: {time_stamp}')\n",
"\n",
" # Set the next block timestamp using the fetched time_stamp\n",
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
"\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': 0,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" 'len': len(data),\n",
" }\n",
"\n",
" print(f'call_to_account: {call_to_account}')\n",
"\n",
" return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we need to regenerate the witness, prove and then verify all within the same cell. This is because we want to reduce the amount of latency between reading on-chain state and verifying it on-chain. This is because the attest input values read from the oracle are time sensitive (their values are derived from computing on block.timestamp) and can change between the time of reading and the time of verifying.\n",
"\n",
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"print(f'time_stamp: {time_stamp}')\n",
"\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -9,7 +9,9 @@ class MyModel(nn.Module):
super(MyModel, self).__init__()
def forward(self, w, x, y, z):
return [((x & y)) == (x & (y | (z ^ w)))]
a = (x & y)
b = (y & (z ^ w))
return [a & b]
circuit = MyModel()

View File

@@ -1 +1 @@
{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}

View File

@@ -1,21 +1,17 @@
pytorch1.12.1:«
+
pytorch2.2.2:„
*
input1
input2
onnx::Equal_4And_0"And
input2
/And_output_0/And"And
)
input3
input
onnx::Or_5Xor_1"Xor
input3
input
/Xor_output_0/Xor"Xor
input2
onnx::Or_5 onnx::And_6Or_2"Or
0
input1
onnx::And_6
onnx::Equal_7And_3"And
6
5
input2
/Xor_output_0/And_1_output_0/And_1"And
5
/And_output_0
/And_1_output_0output/And_2"And

42
examples/onnx/exp/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.exp(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.5801457762718201, 0.6019012331962585, 0.8695418238639832, 0.17170941829681396, 0.500616729259491, 0.353726327419281, 0.6726185083389282, 0.5936906337738037]]}

View File

@@ -0,0 +1,14 @@
pytorch2.2.2:o

inputoutput/Exp"Exp
main_graphZ!
input


batch_size
b"
output


batch_size
B

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,41 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = 10**x
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.9837989807128906, 0.026381194591522217, 0.3403851389884949, 0.14531707763671875, 0.24652725458145142, 0.7945117354393005, 0.4076554775238037, 0.23064672946929932]]}

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

42
examples/onnx/log/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.log(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 3)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}

View File

@@ -0,0 +1,14 @@
pytorch2.2.2:o

inputoutput/Log"Log
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -21,9 +21,9 @@ def main():
torch_model = Circuit()
# Input to the model
shape = [3, 2, 3]
w = 0.1*torch.rand(1, *shape, requires_grad=True)
x = 0.1*torch.rand(1, *shape, requires_grad=True)
y = 0.1*torch.rand(1, *shape, requires_grad=True)
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
torch_out = torch_model(w, x, y)
# Export the model
torch.onnx.export(torch_model, # model being run

View File

@@ -1 +1,148 @@
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
{
"input_shapes": [
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
]
],
"input_data": [
[
0.5,
1.5,
-0.04514765739440918,
0.5936200618743896,
0.9271858930587769,
0.6688600778579712,
-0.20331168174743652,
-0.7016235589981079,
0.025863051414489746,
-0.19426143169403076,
0.9827852249145508,
0.4897397756576538,
-1.5,
-0.5,
0.9278832674026489,
0.5943725109100342,
-0.573331356048584,
0.3675816059112549
],
[
0.7803324460983276,
-0.9616303443908691,
0.6070173978805542,
-0.028337717056274414,
-0.5080242156982422,
-0.9280107021331787,
0.6150380373001099,
0.3865993022918701,
-0.43668973445892334,
0.17152702808380127,
0.5144252777099609,
-0.28881049156188965,
0.8932310342788696,
0.059034109115600586,
0.6865451335906982,
0.009820222854614258,
0.23011493682861328,
-0.9492779970169067
],
[
-0.21352827548980713,
-0.16015326976776123,
-0.38964390754699707,
0.13464701175689697,
-0.8814496994018555,
0.5037975311279297,
-0.804405927658081,
0.9858957529067993,
0.19567716121673584,
0.9777265787124634,
0.6151977777481079,
0.568595290184021,
0.10584986209869385,
-0.8975653648376465,
0.6235959529876709,
-0.547879695892334,
0.9289869070053101,
0.7567293643951416
]
],
"output_data": [
[
1.0,
0.0,
-0.0,
1.0,
1.0,
1.0,
-0.0,
-1.0,
0.0,
-0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
-1.0,
0.0
],
[
0.0,
-1.0,
0.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
-1.0
],
[
-0.0,
-0.0,
-0.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0
]
]
}

View File

@@ -1,10 +1,11 @@
pytorch2.0.1:â
pytorch2.2.2:ă

woutput_w/Round"Round

xoutput_x/Floor"Floor

youtput_y/Ceil"Ceil torch_jitZ%
youtput_y/Ceil"Ceil
main_graphZ%
w



View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
# reciprocal sqrt
m = 1 / torch.sqrt(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}

View File

@@ -0,0 +1,17 @@
pytorch2.2.2:Ź
$
input/Sqrt_output_0/Sqrt"Sqrt
1
/Sqrt_output_0output /Reciprocal"
Reciprocal
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -1,75 +1,52 @@
import torch
import torch.nn as nn
import sys
from torch import nn
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
import numpy as np
import tf2onnx
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
# gather_nd in tf then export to onnx
x = in1 = Input((4, 1), dtype=tf.int32)
w = in2 = Input((4, ), dtype=tf.int32)
class MyLayer(Layer):
def call(self, x, w):
shape = tf.constant([8])
return tf.scatter_nd(x, w, shape)
x = MyLayer()(x, w)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 4, 1]
index_shape = [1, 4]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = np.random.randint(0, 4, shape)
# w = random int tensor
w = np.random.randint(0, 4, index_shape)
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d1],
input_data=[d, d1],
)
# Serialize data into file:

View File

@@ -1 +1,16 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
{
"input_data": [
[
0,
1,
2,
3
],
[
1,
0,
2,
1
]
]
}

View File

@@ -0,0 +1 @@
network.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,47 @@
## The worm
This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, which is a VAE / latent-space representation of the C. elegans connectome.
The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).
In effect this is a generative model for a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given previous connectome state.
Using ezkl you can create a zk circuit equivalent to the wormvae model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm-state that can be verified on chain.
To do so you'll first want to fetch the files using git-lfs (as the onnx file is too large to be stored in git).
```bash
git lfs fetch --all
```
You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
> Note: the model is large and thus we recommend a machine with at least 512GB of RAM to run the above commands. If you're ever compute constrained you can always use the lilith service to generate the zk circuit. Message us on discord or telegram for more details :)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
size 82095882

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

851
ezkl.pyi Normal file
View File

@@ -0,0 +1,851 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401
import os
import pathlib
import typing
from enum import Enum, auto
class PyG1:
r"""
pyclass containing the struct used for G1, this is mostly a helper class
"""
...
class PyG1Affine:
r"""
pyclass containing the struct used for G1
"""
...
class PyRunArgs:
r"""
Python class containing the struct used for run_args
Returns
-------
PyRunArgs
"""
...
class PyCommitments(Enum):
r"""
pyclass representing an enum, denoting the type of commitment
"""
KZG = auto()
IPA = auto()
class PyInputType(Enum):
Bool = auto()
F16 = auto()
F32 = auto()
F64 = auto()
Int = auto()
TDim = auto()
class PyTestDataSource(Enum):
r"""
pyclass representing an enum
"""
File = auto()
OnChain = auto()
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
r"""
Creates an aggregated proof
Arguments
---------
aggregation_snarks: list[str]
List of paths to the various proofs
proof_path: str
Path to output the aggregated proof
vk_path: str
Path to the VK file
transcript:
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
logrows:
Logrows used for aggregation circuit
check_mode: str
Run sanity checks during calculations. Accepts `safe` or `unsafe`
split-proofs: bool
Whether the accumulated proofs are segments of a larger circuit
srs_path: str
Path to the SRS used
commitment: str
Accepts "kzg" or "ipa"
Returns
-------
bool
"""
...
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
r"""
Converts a buffer to vector of field elements
Arguments
-------
buffer: list[int]
List of integers representing a buffer
Returns
-------
list[str]
List of field elements represented as strings
"""
...
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
r"""
Calibrates the circuit settings
Arguments
---------
data: str
Path to the calibration data
model: str
Path to the onnx file
settings: str
Path to the settings file
lookup_safety_margin: int
the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
scales: list[int]
Optional scales to specifically try for calibration
scale_rebase_multiplier: list[int]
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
max_logrows: int
Optional max logrows to use for calibration
Returns
-------
bool
"""
...
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Compiles the circuit for use in other steps
Arguments
---------
model: str
Path to the onnx model file
compiled_circuit: str
Path to output the compiled circuit
settings_path: str
Path to the settings files
Returns
-------
bool
"""
...
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
Arguments
---------
input_data: str
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
Returns
-------
bool
"""
...
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM compatible verifier, you will need solc installed in your environment to run this
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
Arguments
---------
aggregation_settings: str
path to the settings file
vk_path: str
The path to load the desired verification key file
sol_code_path: str
The path to the Solidity code
abi_path: str
The path to output the Solidity verifier ABI
logrows: int
Number of logrows used during aggregated setup
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifying key.
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity da verifier
"""
...
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity verifier
"""
...
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
r"""
Creates encoded evm calldata from a proof file
Arguments
---------
proof: str
Path to the proof file
calldata: str
Path to the calldata file to save
addr_vk: str
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
Returns
-------
vec[u8]
The encoded calldata
"""
...
def felt_to_big_endian(felt:str) -> str:
r"""
Converts a field element hex string to big endian
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
str
field element represented as a string
"""
...
def felt_to_float(felt:str,scale:int) -> float:
r"""
Converts a field element hex string to a floating point number
Arguments
-------
felt: str
The field element represented as a string
scale: float
The scaling factor used to convert the field element into a floating point representation
Returns
-------
float
"""
...
def felt_to_int(felt:str) -> int:
r"""
Converts a field element hex string to an integer
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
int
"""
...
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
r"""
Converts a floating point element to a field element hex string
Arguments
-------
input: float
The field element represented as a string
scale: float
The scaling factor used to quantize the float into a field element
input_type: PyInputType
The type of the input
Returns
-------
str
The field element represented as a string
"""
...
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
r"""
Generates the circuit settings
Arguments
---------
model: str
Path to the onnx file
output: str
Path to create the settings file
py_run_args: PyRunArgs
PyRunArgs object to initialize the settings
Returns
-------
bool
"""
...
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
r"""
Generates the Structured Reference String (SRS), use this only for testing purposes
Arguments
---------
srs_path: str
Path to the create the SRS file
logrows: int
The number of logrows for the SRS file
"""
...
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for an aggregate circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for a model circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
circuit_settings_path: str
Path to the witness file
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the forward pass operation to generate a witness
Arguments
---------
data: str
Path to the data file
model: str
Path to the compiled model file
output: str
Path to create the witness file
vk_path: str
Path to the verification key
srs_path: str
Path to the SRS file
Returns
-------
dict
Python object containing the witness values
"""
...
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
r"""
Gets a public srs
Arguments
---------
settings_path: str
Path to the settings file
logrows: int
The number of logrows for the SRS file
srs_path: str
Path to the create the SRS file
commitment: str
Specify the commitment used ("kzg", "ipa")
Returns
-------
bool
"""
...
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate an ipa commitment.
Arguments
-------
message: list[str]
List of field elements represnted as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structure Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate a kzg commitment.
Arguments
-------
message: list[str]
List of field elements represnted as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structure Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
r"""
Mocks the prover
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
Returns
-------
bool
"""
...
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
r"""
Mocks the aggregate prover
Arguments
---------
aggregation_snarks: list[str]
List of paths to the relevant proof files
logrows: int
Number of logrows to use for the aggregation circuit
split_proofs: bool
Indicates whether the accumulated are segments of a larger proof
Returns
-------
bool
"""
...
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
r"""
Generate a poseidon hash.
Arguments
-------
message: list[str]
List of field elements represented as strings
Returns
-------
list[str]
List of field elements represented as strings
"""
...
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the prover on a set of inputs
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
pk_path: str
Path to the proving key file
proof_path: str
Path to create the proof file
proof_type: str
Accepts `single`, `for-aggr`
srs_path: str
Path to the SRS file
Returns
-------
bool
"""
...
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
r"""
Runs the setup process
Arguments
---------
model: str
Path to the compiled model file
vk_path: str
Path to create the verification key file
pk_path: str
Path to create the proving key file
srs_path: str
Path to the SRS file
witness_path: str
Path to the witness file
disable_selector_compression: bool
Whether to compress the selectors or not
Returns
-------
bool
"""
...
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
r"""
Runs the setup process for an aggregate setup
Arguments
---------
sample_snarks: list[str]
List of paths to the various proofs
vk_path: str
Path to create the aggregated VK
pk_path: str
Path to create the aggregated PK
logrows: int
Number of logrows to use
split_proofs: bool
Whether the accumulated are segments of a larger proof
srs_path: str
Path to the SRS file
disable_selector_compression: bool
Whether to compress selectors
commitment: str
Accepts `kzg`, `ipa`
Returns
-------
bool
"""
...
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Setup test evm witness
Arguments
---------
data_path: str
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
compiled_circuit_path: str
The path to the compiled model file (generated using the compile-circuit command)
test_data: str
For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
input_sources: str
Where the input data comes from
output_source: str
Where the output data comes from
rpc_url: str
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
Returns
-------
bool
"""
...
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
r"""
Swap the commitments in a proof
Arguments
-------
proof_path: str
Path to the proof file
witness_path: str
Path to the witness file
"""
...
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
r"""
Displays the table as a string in python
Arguments
---------
model: str
Path to the onnx file
Returns
---------
str
Table of the nodes in the onnx file
"""
...
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
r"""
Verifies a given proof
Arguments
---------
proof_path: str
Path to create the proof file
settings_path: str
Path to the settings file
vk_path: str
Path to the verification key file
srs_path: str
Path to the SRS file
non_reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
Returns
-------
bool
"""
...
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
r"""
Verifies and aggregate proof
Arguments
---------
proof_path: str
The path to the proof file
vk_path: str
The path to the verification key file
logrows: int
logrows used for aggregation circuit
commitment: str
Accepts "kzg" or "ipa"
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
r"""
verifies an evm compatible proof, you will need solc installed in your environment to run this
Arguments
---------
addr_verifier: str
The verifier contract's address as a hex string
proof_path: str
The path to the proof file (generated using the prove command)
rpc_url: str
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
addr_da: str
does the verifier use data attestation ?
addr_vk: str
The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
Returns
-------
bool
"""
...

View File

@@ -9,7 +9,6 @@ import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
import { error } from 'console'
async function deployContract(
vm: VM,
@@ -66,7 +65,7 @@ async function verify(
vkAddress = new Uint8Array(uint8Array.buffer);
// convert uitn8array of length
error('vkAddress', vkAddress)
console.error('vkAddress', vkAddress)
}
const data = encodeVerifierCalldata(proof, vkAddress)

View File

@@ -12,6 +12,7 @@ asyncio_mode = "auto"
[project]
name = "ezkl"
version = "0.0.0"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-02-06"
channel = "nightly-2024-07-18"
components = ["rustfmt", "clippy"]

View File

@@ -1,25 +1,28 @@
// ignore file if compiling for wasm
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::commands::Cli;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(feature = "icicle")]
use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
let args = Cli::parse();
@@ -56,7 +59,7 @@ pub async fn main() {
}
}
#[cfg(target_arch = "wasm32")]
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub fn main() {}
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]

269
src/bin/ios_gen_bindings.rs Normal file
View File

@@ -0,0 +1,269 @@
use camino::Utf8Path;
use std::fs;
use std::fs::remove_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
use uniffi_bindgen::bindings::SwiftBindingGenerator;
use uniffi_bindgen::library_mode::generate_bindings;
use uuid::Uuid;
fn main() {
let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set");
let mode = determine_build_mode();
build_bindings(&library_name, mode);
}
/// Determines the build mode based on the CONFIGURATION environment variable.
/// Defaults to "release" if not set or unrecognized.
/// "release" mode takes longer to build but produces optimized code, which has smaller size and is faster.
fn determine_build_mode() -> &'static str {
match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) {
Ok(ref config) if config == "debug" => "debug",
_ => "release",
}
}
/// Builds the Swift bindings and XCFramework for the specified library and build mode.
fn build_bindings(library_name: &str, mode: &str) {
// Get the root directory of this Cargo project
let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR")
.map(PathBuf::from)
.unwrap_or_else(|| std::env::current_dir().unwrap());
// Define the build directory inside the manifest directory
let build_dir = manifest_dir.join("build");
// Create a temporary directory to store the bindings and combined library
let tmp_dir = mktemp_local(&build_dir);
// Define directories for Swift bindings and output bindings
let swift_bindings_dir = tmp_dir.join("SwiftBindings");
let bindings_out = create_bindings_out_dir(&tmp_dir);
let framework_out = bindings_out.join("EzklCore.xcframework");
// Define target architectures for building
// We currently only support iOS devices and simulators running on ARM Macs
// This is due to limiting the library size to under 100MB for GitHub Commit Size Limit
// To support older Macs (Intel), follow the instructions in the comments below
#[allow(clippy::useless_vec)]
let target_archs = vec![
vec!["aarch64-apple-ios"], // iOS device
vec!["aarch64-apple-ios-sim"], // iOS simulator ARM Mac
// vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older Macs (Intel)
];
// Build the library for each architecture and combine them
let out_lib_paths: Vec<PathBuf> = target_archs
.iter()
.map(|archs| build_combined_archs(library_name, archs, &build_dir, mode))
.collect();
// Generate the path to the built dynamic library (.dylib)
let out_dylib_path = build_dir.join(format!(
"{}/{}/lib{}.dylib",
target_archs[0][0], mode, library_name
));
// Generate Swift bindings using uniffi_bindgen
generate_ios_bindings(&out_dylib_path, &swift_bindings_dir)
.expect("Failed to generate iOS bindings");
// Move the generated Swift file to the bindings output directory
fs::rename(
swift_bindings_dir.join(format!("{}.swift", library_name)),
bindings_out.join("EzklCore.swift"),
)
.expect("Failed to copy swift bindings file");
// Rename the `ios_ezklFFI.modulemap` file to `module.modulemap`
fs::rename(
swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)),
swift_bindings_dir.join("module.modulemap"),
)
.expect("Failed to rename modulemap file");
// Create the XCFramework from the combined libraries and Swift bindings
create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out);
// Define the destination directory for the bindings
let bindings_dest = build_dir.join("EzklCoreBindings");
if bindings_dest.exists() {
fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory");
}
// Move the bindings output to the destination directory
fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place");
// Clean up temporary directories
cleanup_temp_dirs(&build_dir);
}
/// Creates the output directory for the bindings.
/// Returns the path to the bindings output directory.
fn create_bindings_out_dir(base_dir: &Path) -> PathBuf {
let bindings_out = base_dir.join("EzklCoreBindings");
fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory");
bindings_out
}
/// Builds the library for each architecture and combines them into a single library using lipo.
/// Returns the path to the combined library.
fn build_combined_archs(
library_name: &str,
archs: &[&str],
build_dir: &Path,
mode: &str,
) -> PathBuf {
// Build the library for each architecture
let out_lib_paths: Vec<PathBuf> = archs
.iter()
.map(|&arch| {
build_for_arch(arch, build_dir, mode);
build_dir
.join(arch)
.join(mode)
.join(format!("lib{}.a", library_name))
})
.collect();
// Create a unique temporary directory for the combined library
let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name));
// Combine the libraries using lipo
let mut lipo_cmd = Command::new("lipo");
lipo_cmd
.arg("-create")
.arg("-output")
.arg(lib_out.to_str().unwrap());
for lib_path in &out_lib_paths {
lipo_cmd.arg(lib_path.to_str().unwrap());
}
let status = lipo_cmd.status().expect("Failed to run lipo command");
if !status.success() {
panic!("lipo command failed with status: {}", status);
}
lib_out
}
/// Builds the library for a specific architecture.
fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) {
// Ensure the target architecture is installed
install_arch(arch);
// Run cargo build for the specified architecture and mode
let mut build_cmd = Command::new("cargo");
build_cmd
.arg("build")
.arg("--no-default-features")
.arg("--features")
.arg("ios-bindings");
if mode == "release" {
build_cmd.arg("--release");
}
build_cmd
.arg("--lib")
.env("CARGO_BUILD_TARGET_DIR", build_dir)
.env("CARGO_BUILD_TARGET", arch);
let status = build_cmd.status().expect("Failed to run cargo build");
if !status.success() {
panic!("cargo build failed for architecture: {}", arch);
}
}
/// Installs the specified target architecture using rustup.
fn install_arch(arch: &str) {
let status = Command::new("rustup")
.arg("target")
.arg("add")
.arg(arch)
.status()
.expect("Failed to run rustup command");
if !status.success() {
panic!("Failed to install target architecture: {}", arch);
}
}
/// Generates Swift bindings for the iOS library using uniffi_bindgen.
fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> {
// Remove existing binding directory if it exists
if binding_dir.exists() {
remove_dir_all(binding_dir)?;
}
// Generate the Swift bindings using uniffi_bindgen
generate_bindings(
Utf8Path::from_path(dylib_path).ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path")
})?,
None,
&SwiftBindingGenerator,
None,
Utf8Path::from_path(binding_dir).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid Swift bindings directory",
)
})?,
true,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
Ok(())
}
/// Creates an XCFramework from the combined libraries and Swift bindings.
fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) {
let mut xcbuild_cmd = Command::new("xcodebuild");
xcbuild_cmd.arg("-create-xcframework");
// Add each library and its corresponding headers to the xcodebuild command
for lib_path in lib_paths {
println!("Including library: {:?}", lib_path);
xcbuild_cmd.arg("-library");
xcbuild_cmd.arg(lib_path.to_str().unwrap());
xcbuild_cmd.arg("-headers");
xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap());
}
xcbuild_cmd.arg("-output");
xcbuild_cmd.arg(framework_out.to_str().unwrap());
let status = xcbuild_cmd.status().expect("Failed to run xcodebuild");
if !status.success() {
panic!("xcodebuild failed with status: {}", status);
}
}
/// Creates a temporary directory inside the build path with a unique UUID.
/// This ensures unique build artifacts for concurrent builds.
fn mktemp_local(build_path: &Path) -> PathBuf {
let dir = tmp_local(build_path).join(Uuid::new_v4().to_string());
fs::create_dir(&dir).expect("Failed to create temporary directory");
dir
}
/// Gets the path to the local temporary directory inside the build path.
fn tmp_local(build_path: &Path) -> PathBuf {
let tmp_path = build_path.join("tmp");
if let Ok(metadata) = fs::metadata(&tmp_path) {
if !metadata.is_dir() {
panic!("Expected 'tmp' to be a directory");
}
} else {
fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory");
}
tmp_path
}
/// Cleans up temporary directories inside the build path.
fn cleanup_temp_dirs(build_dir: &Path) {
let tmp_dir = build_dir.join("tmp");
if tmp_dir.exists() {
fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories");
}
}

9
src/bin/py_stub_gen.rs Normal file
View File

@@ -0,0 +1,9 @@
use pyo3_stub_gen::Result;
fn main() -> Result<()> {
// `stub_info` is a function defined by `define_stub_info_gatherer!` macro.
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = ezkl::bindings::python::stub_info()?;
stub.generate()?;
Ok(())
}

12
src/bindings/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
/// Python bindings
#[cfg(feature = "python-bindings")]
pub mod python;
/// Universal bindings for all platforms
#[cfg(any(
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown")
))]
pub mod universal;
/// wasm prover and verifier
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;

View File

@@ -4,9 +4,10 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_i64, i64_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
@@ -26,7 +27,12 @@ use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
@@ -35,6 +41,7 @@ type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyTestDataSource {
/// The data is loaded from a file
File,
@@ -54,6 +61,7 @@ impl From<PyTestDataSource> for TestDataSource {
/// pyclass containing the struct used for G1, this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
@@ -100,6 +108,7 @@ impl pyo3::ToPyObject for PyG1 {
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
pub struct PyG1Affine {
#[pyo3(get, set)]
///
@@ -145,6 +154,7 @@ impl pyo3::ToPyObject for PyG1Affine {
///
#[pyclass]
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
@@ -180,9 +190,6 @@ struct PyRunArgs {
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
@@ -191,6 +198,15 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
/// int: The base used for decomposition
#[pyo3(get, set)]
pub decomp_base: usize,
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
}
/// default instantiation of PyRunArgs
@@ -206,6 +222,7 @@ impl PyRunArgs {
impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
@@ -217,10 +234,11 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
}
}
}
@@ -228,6 +246,7 @@ impl From<PyRunArgs> for RunArgs {
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
@@ -239,16 +258,18 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
}
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
@@ -296,6 +317,65 @@ impl FromStr for PyCommitments {
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyInputType {
///
Bool,
///
F16,
///
F32,
///
F64,
///
Int,
///
TDim,
}
impl From<InputType> for PyInputType {
fn from(input_type: InputType) -> Self {
match input_type {
InputType::Bool => PyInputType::Bool,
InputType::F16 => PyInputType::F16,
InputType::F32 => PyInputType::F32,
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
}
}
}
impl From<PyInputType> for InputType {
fn from(py_input_type: PyInputType) -> Self {
match py_input_type {
PyInputType::Bool => InputType::Bool,
PyInputType::F16 => InputType::F16,
PyInputType::F32 => InputType::F32,
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
}
}
}
impl FromStr for PyInputType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"bool" => Ok(PyInputType::Bool),
"f16" => Ok(PyInputType::F16),
"f32" => Ok(PyInputType::F32),
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
_ => Err("Invalid value for InputType".to_string()),
}
}
}
/// Converts a field element hex string to big endian
///
/// Arguments
@@ -312,6 +392,7 @@ impl FromStr for PyCommitments {
#[pyfunction(signature = (
felt,
))]
#[gen_stub_pyfunction]
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
@@ -331,9 +412,10 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
#[pyfunction(signature = (
felt,
))]
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
#[gen_stub_pyfunction]
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i64(felt);
let int_rep = felt_to_integer_rep(felt);
Ok(int_rep)
}
@@ -355,9 +437,10 @@ fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
felt,
scale
))]
#[gen_stub_pyfunction]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i64(felt);
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
@@ -373,6 +456,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
/// scale: float
/// The scaling factor used to quantize the float into a field element
///
/// input_type: PyInputType
/// The type of the input
///
/// Returns
/// -------
/// str
@@ -380,12 +466,15 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
///
#[pyfunction(signature = (
input,
scale
scale,
input_type=PyInputType::F64
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
#[gen_stub_pyfunction]
fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult<PyFelt> {
InputType::roundtrip(&input_type.into(), &mut input);
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = i64_to_felt(int_rep);
let felt = integer_rep_to_felt(int_rep);
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}
@@ -404,6 +493,7 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
#[pyfunction(signature = (
buffer
))]
#[gen_stub_pyfunction]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -476,6 +566,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
#[pyfunction(signature = (
message,
))]
#[gen_stub_pyfunction]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -521,6 +612,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -579,6 +671,7 @@ fn kzg_commit(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -625,6 +718,7 @@ fn ipa_commit(
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
#[gen_stub_pyfunction]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -654,6 +748,7 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -691,6 +786,7 @@ fn gen_vk_from_pk_single(
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -720,6 +816,7 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
))]
#[gen_stub_pyfunction]
fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?;
@@ -745,6 +842,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
srs_path,
logrows,
))]
#[gen_stub_pyfunction]
fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
let params = ezkl_gen_srs::<KZGCommitmentScheme<Bn256>>(logrows as u32);
save_params::<KZGCommitmentScheme<Bn256>>(&srs_path, &params)?;
@@ -777,6 +875,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
srs_path=None,
commitment=None,
))]
#[gen_stub_pyfunction]
fn get_srs(
py: Python,
settings_path: Option<PathBuf>,
@@ -789,7 +888,7 @@ fn get_srs(
None => None,
};
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
.await
.map_err(|e| {
@@ -823,6 +922,7 @@ fn get_srs(
output=PathBuf::from(DEFAULT_SETTINGS),
py_run_args = None,
))]
#[gen_stub_pyfunction]
fn gen_settings(
model: PathBuf,
output: PathBuf,
@@ -838,6 +938,45 @@ fn gen_settings(
Ok(true)
}
/// Generates random data for the model
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the data file
///
/// seed: int
/// Random seed to use for generated data
///
/// variables
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn gen_random_data(
model: PathBuf,
output: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
) -> Result<bool, PyErr> {
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
let err_str = format!("Failed to generate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Calibrates the circuit settings
///
/// Arguments
@@ -863,8 +1002,6 @@ fn gen_settings(
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
@@ -879,21 +1016,20 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: PathBuf,
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i64,
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::calibrate(
model,
data,
@@ -902,7 +1038,6 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.await
@@ -946,6 +1081,7 @@ fn calibrate_settings(
vk_path=None,
srs_path=None,
))]
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: PathBuf,
@@ -954,7 +1090,7 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
.await
.map_err(|e| {
@@ -983,6 +1119,7 @@ fn gen_witness(
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
))]
#[gen_stub_pyfunction]
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
crate::execute::mock(model, witness).map_err(|e| {
let err_str = format!("Failed to run mock: {}", e);
@@ -1013,6 +1150,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
))]
#[gen_stub_pyfunction]
fn mock_aggregate(
aggregation_snarks: Vec<PathBuf>,
logrows: u32,
@@ -1060,6 +1198,7 @@ fn mock_aggregate(
witness_path = None,
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup(
model: PathBuf,
vk_path: PathBuf,
@@ -1118,6 +1257,7 @@ fn setup(
proof_type=ProofType::default(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn prove(
witness: PathBuf,
model: PathBuf,
@@ -1173,6 +1313,7 @@ fn prove(
srs_path=None,
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
#[gen_stub_pyfunction]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
@@ -1232,6 +1373,7 @@ fn verify(
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,
@@ -1282,6 +1424,7 @@ fn setup_aggregate(
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
))]
#[gen_stub_pyfunction]
fn compile_circuit(
model: PathBuf,
compiled_circuit: PathBuf,
@@ -1341,6 +1484,7 @@ fn compile_circuit(
srs_path=None,
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn aggregate(
aggregation_snarks: Vec<PathBuf>,
proof_path: PathBuf,
@@ -1406,6 +1550,7 @@ fn aggregate(
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn verify_aggr(
proof_path: PathBuf,
vk_path: PathBuf,
@@ -1453,6 +1598,7 @@ fn verify_aggr(
calldata=PathBuf::from(DEFAULT_CALLDATA),
addr_vk=None,
))]
#[gen_stub_pyfunction]
fn encode_evm_calldata<'a>(
proof: PathBuf,
calldata: PathBuf,
@@ -1490,8 +1636,8 @@ fn encode_evm_calldata<'a>(
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// reusable: bool
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
///
/// Returns
/// -------
@@ -1503,8 +1649,9 @@ fn encode_evm_calldata<'a>(
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier(
py: Python,
vk_path: PathBuf,
@@ -1512,16 +1659,16 @@ fn create_evm_verifier(
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_verifier(
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
reusable,
)
.await
.map_err(|e| {
@@ -1533,6 +1680,58 @@ fn create_evm_verifier(
})
}
/// Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
/// This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to the create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
#[gen_stub_pyfunction]
fn create_evm_vka(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1560,6 +1759,7 @@ fn create_evm_verifier(
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: PathBuf,
@@ -1568,7 +1768,7 @@ fn create_evm_data_attestation(
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_data_attestation(
settings_path,
sol_code_path,
@@ -1620,6 +1820,7 @@ fn create_evm_data_attestation(
output_source,
rpc_url=None,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: PathBuf,
@@ -1629,7 +1830,7 @@ fn setup_test_evm_witness(
output_source: PyTestDataSource,
rpc_url: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_witness(
data_path,
compiled_circuit_path,
@@ -1653,60 +1854,28 @@ fn setup_test_evm_witness(
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
#[gen_stub_pyfunction]
fn deploy_evm(
py: Python,
addr_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2Verifier",
)
.await
.map_err(|e| {
let err_str = format!("Failed to run deploy_evm: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// deploys the solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
fn deploy_vk_evm(
py: Python,
addr_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2VerifyingKey",
contract_type,
)
.await
.map_err(|e| {
@@ -1728,6 +1897,7 @@ fn deploy_vk_evm(
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
#[gen_stub_pyfunction]
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
@@ -1738,7 +1908,7 @@ fn deploy_da_evm(
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::deploy_da_evm(
input_data,
settings_path,
@@ -1762,7 +1932,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
/// The verifier contract's address as a hex string
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1774,7 +1944,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
///
/// The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
@@ -1786,6 +1956,7 @@ fn deploy_da_evm(
addr_da = None,
addr_vk = None,
))]
#[gen_stub_pyfunction]
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
@@ -1808,7 +1979,7 @@ fn verify_evm<'a>(
None
};
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
.await
.map_err(|e| {
@@ -1842,8 +2013,8 @@ fn verify_evm<'a>(
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// reusable: bool
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
///
/// Returns
/// -------
@@ -1856,8 +2027,9 @@ fn verify_evm<'a>(
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
srs_path=None,
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
py: Python,
aggregation_settings: Vec<PathBuf>,
@@ -1866,9 +2038,9 @@ fn create_evm_verifier_aggr(
abi_path: PathBuf,
logrows: u32,
srs_path: Option<PathBuf>,
render_vk_seperately: bool,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_aggregate_verifier(
vk_path,
srs_path,
@@ -1876,7 +2048,7 @@ fn create_evm_verifier_aggr(
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
reusable,
)
.await
.map_err(|e| {
@@ -1888,15 +2060,19 @@ fn create_evm_verifier_aggr(
})
}
// Define a function to gather stub information.
define_stub_info_gatherer!(stub_info);
// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init();
m.add_class::<PyRunArgs>()?;
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_class::<PyCommitments>()?;
m.add_class::<PyInputType>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
@@ -1918,6 +2094,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
@@ -1925,8 +2102,8 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
@@ -1935,3 +2112,48 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
Ok(())
}
impl pyo3_stub_gen::PyStubType for CalibrationTarget {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ProofType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for TranscriptType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for CheckMode {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ContractType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}

580
src/bindings/universal.rs Normal file
View File

@@ -0,0 +1,580 @@
use halo2_proofs::{
plonk::*,
poly::{
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::{ProverIPA, VerifierIPA},
strategy::SingleStrategy as IPASingleStrategy,
},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use std::fmt::Display;
use std::io::BufReader;
use std::str::FromStr;
use crate::{
circuit::region::RegionSettings,
graph::GraphSettings,
pfsys::{
create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
},
tensor::TensorType,
CheckMode, Commitments, EZKLError as InnerEZKLError,
};
use crate::graph::{GraphCircuit, GraphWitness};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
};
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
/// Wrapper around the Error Message
#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))]
#[derive(Debug)]
pub enum EZKLError {
/// Some Comment
InternalError(String),
}
impl Display for EZKLError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EZKLError::InternalError(e) => write!(f, "Internal error: {}", e),
}
}
}
impl From<InnerEZKLError> for EZKLError {
fn from(e: InnerEZKLError) -> Self {
EZKLError::InternalError(e.to_string())
}
}
/// Encode verifier calldata from proof and ethereum vk_address
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn encode_verifier_calldata(
// TODO - shuold it be pub(crate) or pub or pub(super)?
proof: Vec<u8>,
vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, EZKLError> {
let snark: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address {
let array: [u8; 20] =
serde_json::from_slice(&vk_address[..]).map_err(InnerEZKLError::from)?;
Some(array)
} else {
None
};
let flattened_instances = snark.instances.into_iter().flatten();
let encoded = encode_calldata(
vk_address,
&snark.proof,
&flattened_instances.collect::<Vec<_>>(),
);
Ok(encoded)
}
/// Generate witness from compiled circuit and input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
})?;
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
let mut input = circuit
.load_graph_input(&input)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
None,
None,
RegionSettings::all_true(
circuit.settings().run_args.decomp_base,
circuit.settings().run_args.decomp_legs,
),
)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
serde_json::to_vec(&witness)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
}
/// Generate verifying key from compiled circuit, and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_vk(
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
compress_selectors: bool,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(
&mut serialized_vk,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
Ok(serialized_vk)
}
/// Generate proving key from vk, compiled circuit and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn gen_pk(
vk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
let pk = create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, &params)
.map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?;
let mut serialized_pk = Vec::new();
pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?;
Ok(serialized_pk)
}
/// Verify proof with vk, proof json, circuit settings json and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn verify(
proof: Vec<u8>,
vk: Vec<u8>,
settings: Vec<u8>,
srs: Vec<u8>,
) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?;
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = BufReader::new(&srs[..]);
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> = get_params(&mut reader)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(EZKLError::InternalError(format!(
"Verification failed: {}",
e
))),
}
}
/// Verify aggregate proof with vk, proof, circuit settings and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn verify_aggr(
proof: Vec<u8>,
vk: Vec<u8>,
logrows: u64,
srs: Vec<u8>,
commitment: &str,
) -> Result<bool, EZKLError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment)
.map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?;
let orig_n = 1 << logrows;
let mut reader = BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)),
)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
result
.map(|_| true)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))
}
/// Prove in browser with compiled circuit, witness json, proving key, and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn prove(
witness: Vec<u8>,
pk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
#[cfg(feature = "det-prove")]
log::set_max_level(log::LevelFilter::Debug);
#[cfg(not(feature = "det-prove"))]
log::set_max_level(log::LevelFilter::Info);
let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
circuit
.load_graph_witness(&data)
.map_err(InnerEZKLError::from)?;
let public_inputs = circuit
.prepare_public_inputs(&data)
.map_err(InnerEZKLError::from)?;
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
let mut reader = BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
KZGCommitmentScheme<Bn256>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
KZGSingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::KZG,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
IPACommitmentScheme<G1Affine>,
_,
ProverIPA<_>,
VerifierIPA<_>,
IPASingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::IPA,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
}
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?)
}
/// Validate the witness json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the compiled circuit
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
})?;
Ok(true)
}
/// Validate the input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::graph::input::GraphData =
serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the proof json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the verifying key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
Ok(true)
}
/// Validate the proving key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
Ok(true)
}
/// Validate the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub(crate) fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let _: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize params: {}", e))
})?;
Ok(true)
}
// HELPER FUNCTIONS
fn get_params<
Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>,
>(
mut reader: &mut BufReader<&[u8]>,
) -> Result<Scheme, EZKLError> {
halo2_proofs::poly::commitment::Params::<G1Affine>::read(&mut reader)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)))
}
/// Creates a [ProvingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_vk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the verifying key
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
Ok(vk)
}
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_pk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
vk: VerifyingKey<Scheme::Curve>,
circuit: &C,
params: &'_ Scheme::ParamsProver,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the proving key
let pk = keygen_pk(params, vk, &empty_circuit)?;
Ok(pk)
}

379
src/bindings/wasm.rs Normal file
View File

@@ -0,0 +1,379 @@
use crate::{
circuit::modules::{
polycommit::PolyCommitChip,
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
},
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
};
use console_error_panic_hook;
use halo2_proofs::{
plonk::*,
poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
};
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use crate::bindings::universal::{
compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness,
input_validation, pk_validation, proof_validation, settings_validation, srs_validation,
verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError,
};
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
impl From<ExternalEZKLError> for JsError {
fn from(e: ExternalEZKLError) -> Self {
JsError::new(&format!("{}", e))
}
}
#[wasm_bindgen]
/// Initialize logger for wasm
pub fn init_logger() {
log::set_logger(&DEFAULT_LOGGER).unwrap();
}
#[wasm_bindgen]
/// Initialize panic hook for wasm
pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
/// Wrapper around the halo2 encode call data method
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn encodeVerifierCalldata(
proof: wasm_bindgen::Clamped<Vec<u8>>,
vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
Ok(b)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_integer_rep(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
/// Converts felts to a floating point element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
mut input: f64,
scale: crate::Scale,
input_type: &str,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
crate::circuit::InputType::roundtrip(
&crate::circuit::InputType::from_str(input_type)
.map_err(|e| JsError::new(&format!("{}", e)))?,
&mut input,
);
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
/// Generate a kzg commitment.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn kzgCommit(
message: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let mut reader = std::io::BufReader::new(&params_ser[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&params,
);
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice
let buffer: &[u8] = &buffer;
// Divide the buffer into chunks of 64 bytes
let chunks = buffer.chunks_exact(16);
// Get the remainder
let remainder = chunks.remainder();
// Add 0s to the remainder to make it 64 bytes
let mut remainder = remainder.to_vec();
// Collect chunks into a Vec<[u8; 16]>.
let chunks: Result<Vec<[u8; 16]>, JsError> = chunks
.map(|slice| {
let array: [u8; 16] = slice
.try_into()
.map_err(|_| JsError::new("failed to slice input chunks"))?;
Ok(array)
})
.collect();
let mut chunks = chunks?;
if remainder.len() != 0 {
remainder.resize(16, 0);
// Convert the Vec<u8> to [u8; 16]
let remainder_array: [u8; 16] = remainder
.try_into()
.map_err(|_| JsError::new("failed to slice remainder"))?;
// append the remainder to the chunks
chunks.push(remainder_array);
}
// Convert each chunk to a field element
let field_elements: Vec<Fr> = chunks
.iter()
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&field_elements)
.map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?,
))
}
/// Generate a poseidon hash in browser. Input message
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn poseidonHash(
message: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
|e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)),
)?))
}
/// Generate a witness file from input.json, compiled model and a settings.json file.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genWitness(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
input: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_witness(compiled_circuit.0, input.0).map_err(JsError::from)
}
/// Generate verifying key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genVk(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
compress_selectors: bool,
) -> Result<Vec<u8>, JsError> {
gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from)
}
/// Generate proving key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genPk(
vk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from)
}
/// Verify proof in browser using wasm
#[wasm_bindgen]
pub fn verify(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
}
/// Verify aggregate proof in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from)
}
/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(
witness: wasm_bindgen::Clamped<Vec<u8>>,
pk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from)
}
// VALIDATION FUNCTIONS
/// Witness file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn witnessValidation(witness: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
witness_validation(witness.0).map_err(JsError::from)
}
/// Compiled circuit validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn compiledCircuitValidation(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from)
}
/// Input file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn inputValidation(input: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
input_validation(input.0).map_err(JsError::from)
}
/// Proof file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn proofValidation(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
proof_validation(proof.0).map_err(JsError::from)
}
/// Vk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vkValidation(
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
vk_validation(vk.0, settings.0).map_err(JsError::from)
}
/// Pk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn pkValidation(
pk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
pk_validation(pk.0, settings.0).map_err(JsError::from)
}
/// Settings file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn settingsValidation(settings: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
settings_validation(settings.0).map_err(JsError::from)
}
/// Srs file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
srs_validation(srs.0).map_err(JsError::from)
}
/// HELPER FUNCTIONS
pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
for &b in arr.iter().rev() {
n <<= 8;
n |= b as u128;
}
n
}

View File

@@ -219,7 +219,7 @@ mod tests {
fn polycommit_chip_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
{
@@ -247,7 +247,7 @@ mod tests {
#[test]
#[ignore]
fn polycommit_chip_much_longer_input() {
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
let rng = rand::rngs::OsRng;

View File

@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
Self::configure_with_cols(
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
let instance = meta.instance_column();
@@ -560,7 +554,7 @@ mod tests {
fn hash_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
{

View File

@@ -8,12 +8,12 @@ use halo2_proofs::{
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
conversion::{FromPyObject, IntoPy},
exceptions::PyValueError,
prelude::*,
types::PyString,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use crate::{
@@ -22,7 +22,7 @@ use crate::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, marker::PhantomData};
@@ -49,6 +49,7 @@ impl std::fmt::Display for CheckMode {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for CheckMode {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
@@ -74,6 +75,16 @@ impl FromStr for CheckMode {
}
}
impl CheckMode {
/// Returns the value of the check mode
pub fn is_safe(&self) -> bool {
match self {
CheckMode::SAFE => true,
CheckMode::UNSAFE => false,
}
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
@@ -88,6 +99,7 @@ impl std::fmt::Display for Tolerance {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
@@ -136,10 +148,9 @@ impl IntoPy<PyObject> for CheckMode {
#[cfg(feature = "python-bindings")]
/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python)
impl<'source> FromPyObject<'source> for CheckMode {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
match trystr.to_lowercase().as_str() {
"safe" => Ok(CheckMode::SAFE),
"unsafe" => Ok(CheckMode::UNSAFE),
_ => Err(PyValueError::new_err("Invalid value for CheckMode")),
@@ -158,8 +169,8 @@ impl IntoPy<PyObject> for Tolerance {
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok((val, scale)) = ob.extract::<(f32, f32)>() {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
Ok(Tolerance {
val,
scale: utils::F32(scale),
@@ -174,7 +185,7 @@ impl<'source> FromPyObject<'source> for Tolerance {
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
pub lookup_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub table_selectors: Vec<Selector>,
/// Inputs:
@@ -204,15 +215,16 @@ impl DynamicLookups {
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, usize), Selector>,
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub reference_selectors: Vec<Selector>,
pub output_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub references: Vec<VarTensor>,
pub outputs: Vec<VarTensor>,
}
impl Shuffles {
@@ -223,9 +235,13 @@ impl Shuffles {
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
output_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
outputs: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
@@ -327,7 +343,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {
@@ -363,6 +379,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
if inputs[0].num_cols() != output.num_cols() {
log::warn!("input and output shapes do not match");
}
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
log::warn!("input number of inner columns do not match");
}
if inputs[0].num_inner_cols() != output.num_inner_cols() {
log::warn!("input and output number of inner columns do not match");
}
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
@@ -570,9 +592,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
// this is 0 if the index is the same as the column index (starting from 1)
let col_expr = sel.clone()
* table
* (table
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier =
table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -604,6 +626,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
@@ -643,57 +699,73 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of inner columns
if tables
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
.windows(2)
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"tables inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
let s_ltable = cs.complex_selector();
for q in 0..tables[0].num_blocks() {
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((x, y))
.or_insert(s_lookup);
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((q, (x, y)))
.or_insert(s_lookup);
}
}
self.dynamic_lookups.table_selectors.push(s_ltable);
}
self.dynamic_lookups.table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookups.tables.is_empty() {
@@ -713,8 +785,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
inputs: &[VarTensor; 3],
outputs: &[VarTensor; 3],
) -> Result<(), CircuitError>
where
F: Field,
@@ -725,63 +797,78 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
for t in outputs.iter() {
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of blocks
if outputs
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
.windows(2)
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"outputs inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
let s_reference = cs.complex_selector();
for q in 0..outputs[0].num_blocks() {
let s_output = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
cs.lookup_any("shuffle", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_outputq = cs.query_selector(s_output);
let mut input_queries = vec![one.clone()];
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let mut output_queries = vec![one.clone()];
for output in outputs {
output_queries.push(match output {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.shuffles
.input_selectors
.entry((x, y))
.or_insert(s_input);
expression
});
self.shuffles
.input_selectors
.entry((q, (x, y)))
.or_insert(s_input);
}
}
self.shuffles.output_selectors.push(s_output);
}
self.shuffles.reference_selectors.push(s_reference);
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
if self.shuffles.outputs.is_empty() {
debug!("assigning shuffles output");
self.shuffles.outputs = outputs.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
@@ -851,9 +938,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
let default_x = range_check.get_first_element(col_idx);
let col_expr = sel.clone()
* range_check
* (range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier = range_check
.selector_constructor
@@ -876,6 +963,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);

View File

@@ -1,6 +1,6 @@
use std::convert::Infallible;
use crate::tensor::TensorError;
use crate::{fieldutils::IntegerRep, tensor::TensorError};
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
@@ -57,7 +57,7 @@ pub enum CircuitError {
InvalidConversion(#[from] Infallible),
/// Invalid min/max lookup range
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
InvalidMinMaxRange(i64, i64),
InvalidMinMaxRange(IntegerRep, IntegerRep),
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
@@ -81,7 +81,7 @@ pub enum CircuitError {
MissingSelectors(String),
/// Table lookup error
#[error("value ({0}) out of range: ({1}, {2})")]
TableOOR(i64, i64, i64),
TableOOR(IntegerRep, IntegerRep, IntegerRep),
/// Loookup not configured
#[error("lookup not configured: {0}")]
LookupNotConfigured(String),
@@ -91,4 +91,16 @@ pub enum CircuitError {
/// Missing layout
#[error("missing layout for op: {0}")]
MissingLayout(String),
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
}

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::i64_to_felt,
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -13,14 +13,38 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Ln {
scale: utils::F32,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
},
Floor {
scale: utils::F32,
legs: usize,
},
Round {
scale: utils::F32,
legs: usize,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
use_range_check_for_int: bool,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
@@ -45,6 +69,8 @@ pub enum HybridOp {
ReduceArgMin {
dim: usize,
},
Max,
Min,
Softmax {
input_scale: utils::F32,
output_scale: utils::F32,
@@ -71,12 +97,19 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
_ => vec![],
}
}
@@ -88,21 +121,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
fn as_string(&self) -> String {
match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
HybridOp::Max => "MAX".to_string(),
HybridOp::Min => "MIN".to_string(),
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => format!(
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
input_scale, output_scale, use_range_check_for_int
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
@@ -135,10 +179,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
HybridOp::GreaterEqual => "GREATEREQUAL".into(),
HybridOp::Less => "LESS".into(),
HybridOp::LessEqual => "LESSEQUAL".into(),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
HybridOp::LessEqual => "LESSEQUAL".to_string(),
HybridOp::Equals => "EQUALS".into(),
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
HybridOp::TopK { k, dim, largest } => {
@@ -157,6 +201,34 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Floor { scale, legs } => {
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Round { scale, legs } => {
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
HybridOp::SumPool {
padding,
stride,
@@ -174,42 +246,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
if input_scale.0.fract() == 0.0
&& output_scale.0.fract() == 0.0
&& *use_range_check_for_int
{
layouts::recip(
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
layouts::div(
config,
region,
values[..].try_into()?,
i64_to_felt(input_scale.0 as i64),
i64_to_felt(output_scale.0 as i64),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Recip {
input_scale: *input_scale,
output_scale: *output_scale,
},
)?
}
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
config,
region,
values[..].try_into()?,
i64_to_felt(denom.0 as i64),
integer_rep_to_felt(denom.0 as IntegerRep),
)?
} else {
layouts::nonlinearity(
@@ -296,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Softmax {
output_scale,
input_scale,
..
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)

File diff suppressed because it is too large Load Diff

View File

@@ -3,9 +3,8 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i64, i64_to_felt},
graph::multiplier_to_scale,
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
tensor::{self, Tensor, TensorError, TensorType},
};
use super::Op;
@@ -15,225 +14,142 @@ use halo2curves::ff::PrimeField;
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Abs,
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
ReLU,
Max {
scale: utils::F32,
a: utils::F32,
},
Min {
scale: utils::F32,
a: utils::F32,
},
Ceil {
scale: utils::F32,
},
Floor {
scale: utils::F32,
},
Round {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
Rsqrt {
scale: utils::F32,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Ln {
scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Cos {
scale: utils::F32,
},
ACos {
scale: utils::F32,
},
Cosh {
scale: utils::F32,
},
ACosh {
scale: utils::F32,
},
Sin {
scale: utils::F32,
},
ASin {
scale: utils::F32,
},
Sinh {
scale: utils::F32,
},
ASinh {
scale: utils::F32,
},
Tan {
scale: utils::F32,
},
ATan {
scale: utils::F32,
},
Tanh {
scale: utils::F32,
},
ATanh {
scale: utils::F32,
},
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
Sign,
KroneckerDelta,
Pow {
scale: utils::F32,
a: utils::F32,
},
HardSwish {
scale: utils::F32,
},
Div { denom: utils::F32 },
IsOdd,
PowersOfTwo { scale: utils::F32 },
Ln { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Exp { scale: utils::F32, base: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
ACosh { scale: utils::F32 },
Sin { scale: utils::F32 },
ASin { scale: utils::F32 },
Sinh { scale: utils::F32 },
ASinh { scale: utils::F32 },
Tan { scale: utils::F32 },
ATan { scale: utils::F32 },
Tanh { scale: utils::F32 },
ATanh { scale: utils::F32 },
Erf { scale: utils::F32 },
Pow { scale: utils::F32, a: utils::F32 },
HardSwish { scale: utils::F32 },
}
impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i64;
let range = range as IntegerRep;
(-range, range)
}
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
LookupOp::IsOdd => "is_odd".to_string(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale, base } => format!("exp_{}_{}", scale, base),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
LookupOp::ACosh { scale } => format!("acosh_{}", scale),
LookupOp::Sin { scale } => format!("sin_{}", scale),
LookupOp::ASin { scale } => format!("asin_{}", scale),
LookupOp::Sinh { scale } => format!("sinh_{}", scale),
LookupOp::ASinh { scale } => format!("asinh_{}", scale),
LookupOp::Tan { scale } => format!("tan_{}", scale),
LookupOp::ATan { scale } => format!("atan_{}", scale),
LookupOp::ATanh { scale } => format!("atanh_{}", scale),
LookupOp::Tanh { scale } => format!("tanh_{}", scale),
LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i64(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())),
LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())),
LookupOp::RoundHalfToEven { scale } => Ok(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)),
LookupOp::Max { scale, a } => Ok(tensor::ops::nonlinearities::max(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Min { scale, a } => Ok(tensor::ops::nonlinearities::min(
&x,
scale.0.into(),
a.0.into(),
)),
LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)),
LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than(
&x,
f32::from(*a).into(),
)),
LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
&x,
f32::from(*a).into(),
)),
LookupOp::GreaterThanEqual { a } => Ok(
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
),
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*denom).into(),
)),
LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
&x,
f32::from(*scale).into(),
)),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res =
match &self {
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::PowersOfTwo { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
}
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale, base } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::exp(&x, scale.into(), base.into()),
),
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
LookupOp::ACos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acos(&x, scale.into()))
}
LookupOp::Cosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cosh(&x, scale.into()))
}
LookupOp::ACosh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::acosh(&x, scale.into()))
}
LookupOp::Sin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sin(&x, scale.into()))
}
LookupOp::ASin { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asin(&x, scale.into()))
}
LookupOp::Sinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sinh(&x, scale.into()))
}
LookupOp::ASinh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::asinh(&x, scale.into()))
}
LookupOp::Tan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tan(&x, scale.into()))
}
LookupOp::ATan { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atan(&x, scale.into()))
}
LookupOp::ATanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::atanh(&x, scale.into()))
}
LookupOp::Tanh { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::tanh(&x, scale.into()))
}
LookupOp::HardSwish { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
LookupOp::LeakyReLU { slope: a } => {
Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())),
LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())),
LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())),
LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())),
LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())),
LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())),
LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())),
LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())),
LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())),
LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())),
LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())),
LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())),
LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())),
LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())),
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
LookupOp::HardSwish { scale } => {
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
let output = res.map(|x| i64_to_felt(x));
let output = res.map(|x| integer_rep_to_felt(x));
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
@@ -242,37 +158,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::Abs => "ABS".into(),
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,
} => format!(
"RECIP(input_scale={}, output_scale={})",
input_scale, output_scale
),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::ReLU => "RELU".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
LookupOp::IsOdd => "IS_ODD".to_string(),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
LookupOp::Exp { scale, base } => format!("EXP(scale={}, base={})", scale, base),
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
LookupOp::Tanh { scale } => format!("TANH(scale={})", scale),
@@ -305,20 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::Sign
| LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
| LookupOp::GreaterThanEqual { .. }
| LookupOp::LessThanEqual { .. }
| LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
let scale = inputs_scale[0];
Ok(scale)
}

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -31,12 +31,12 @@ pub use errors::CircuitError;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Returns a string representation of the operation.
@@ -75,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
fn as_any(&self) -> &dyn Any;
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -105,7 +105,10 @@ impl InputType {
}
///
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone + std::fmt::Debug>(
&self,
input: &mut T,
) {
match self {
InputType::Bool => {
let boolean_input = input.clone().to_i64().unwrap();
@@ -118,7 +121,7 @@ impl InputType {
*input = T::from_f32(f32_input).unwrap();
}
InputType::F32 => {
let f32_input = input.clone().to_f32().unwrap();
let f32_input: f32 = input.clone().to_f32().unwrap();
*input = T::from_f32(f32_input).unwrap();
}
InputType::F64 => {
@@ -133,6 +136,22 @@ impl InputType {
}
}
impl std::str::FromStr for InputType {
type Err = CircuitError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"bool" => Ok(InputType::Bool),
"f16" => Ok(InputType::F16),
"f32" => Ok(InputType::F32),
"f64" => Ok(InputType::F64),
"int" => Ok(InputType::Int),
"tdim" => Ok(InputType::TDim),
e => Err(CircuitError::InvalidInputType(e.to_string())),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -142,7 +161,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.scale)
}
@@ -197,7 +216,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(0)
}
@@ -224,7 +243,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -234,7 +253,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash +
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -255,7 +274,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
self.raw_values = Tensor::new(None, &[0]).unwrap();
}
///
/// Pre-assign a value
pub fn pre_assign(&mut self, val: ValTensor<F>) {
self.pre_assigned_val = Some(val)
}
@@ -267,8 +286,7 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
+ IntoI64,
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {

Some files were not shown because too many files have changed in this diff Show More