mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
55 Commits
release-v1
...
ac/patch-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5290045f06 | ||
|
|
a0078bef6a | ||
|
|
d0ba505baa | ||
|
|
f35688917d | ||
|
|
7ae541ed35 | ||
|
|
675628cd08 | ||
|
|
4fe7290240 | ||
|
|
3e027db9b6 | ||
|
|
e566acc22a | ||
|
|
75ea99e81d | ||
|
|
c5354c382d | ||
|
|
bdcba5ca61 | ||
|
|
6752a05f19 | ||
|
|
03aefb85eb | ||
|
|
e86caca8b6 | ||
|
|
c839a30ae6 | ||
|
|
352812b9ac | ||
|
|
d48d0b0b3e | ||
|
|
8b223354cc | ||
|
|
caa6ef8e16 | ||
|
|
c4354c10a5 | ||
|
|
c1ce8c88d0 | ||
|
|
876a9584a1 | ||
|
|
7d7f049cc4 | ||
|
|
96f3fd94b2 | ||
|
|
6263510c56 | ||
|
|
f5b8ae3213 | ||
|
|
b2e4e414f0 | ||
|
|
0b0199e2b7 | ||
|
|
5e169bdd17 | ||
|
|
64cbcb3f7e | ||
|
|
ee17f0ff9a | ||
|
|
ee55e7dc19 | ||
|
|
5df83886c7 | ||
|
|
061ae89c01 | ||
|
|
0fc1c3eecd | ||
|
|
85302453d9 | ||
|
|
523c77c912 | ||
|
|
948e5cd4b9 | ||
|
|
00155e585f | ||
|
|
0876faa12c | ||
|
|
a3c131dac0 | ||
|
|
fd9c2305ac | ||
|
|
a0060f341d | ||
|
|
17f1d42739 | ||
|
|
ebaee9e2b1 | ||
|
|
d51cba589a | ||
|
|
1cb1b6e143 | ||
|
|
d2b683b527 | ||
|
|
a06b09ef1f | ||
|
|
e5aa48fbd6 | ||
|
|
64fbc8a1c9 | ||
|
|
c9f9d17f16 | ||
|
|
b49b0487c4 | ||
|
|
61b7a8e9b5 |
@@ -1,4 +0,0 @@
|
||||
[target.wasm32-unknown-unknown]
|
||||
runner = 'wasm-bindgen-test-runner'
|
||||
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
|
||||
"link-arg=--max-memory=4294967296"]
|
||||
17
.cargo/config.toml
Normal file
17
.cargo/config.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[target.wasm32-unknown-unknown]
|
||||
runner = 'wasm-bindgen-test-runner'
|
||||
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
|
||||
"link-arg=--max-memory=4294967296"]
|
||||
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
55
.github/workflows/benchmarks.yml
vendored
55
.github/workflows/benchmarks.yml
vendored
@@ -6,22 +6,15 @@ on:
|
||||
description: "Test scenario tags"
|
||||
|
||||
jobs:
|
||||
bench_elgamal:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Bench elgamal
|
||||
run: cargo bench --verbose --bench elgamal
|
||||
|
||||
bench_poseidon:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -31,10 +24,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench poseidon
|
||||
|
||||
bench_einsum_accum_matmul:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -44,10 +41,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_einsum_matmul
|
||||
|
||||
bench_accum_matmul_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -57,10 +58,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu
|
||||
|
||||
bench_accum_matmul_relu_overflow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -70,10 +75,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu_overflow
|
||||
|
||||
bench_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -83,10 +92,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench relu
|
||||
|
||||
bench_accum_dot:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -96,10 +109,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_dot
|
||||
|
||||
bench_accum_conv:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -109,10 +126,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_conv
|
||||
|
||||
bench_accum_sumpool:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -122,10 +143,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sumpool
|
||||
|
||||
bench_pairwise_add:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -135,10 +160,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench pairwise_add
|
||||
|
||||
bench_accum_sum:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -148,10 +177,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sum
|
||||
|
||||
bench_pairwise_pow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
|
||||
19
.github/workflows/engine.yml
vendored
19
.github/workflows/engine.yml
vendored
@@ -15,17 +15,25 @@ defaults:
|
||||
working-directory: .
|
||||
jobs:
|
||||
publish-wasm-bindings:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
@@ -48,7 +56,7 @@ jobs:
|
||||
run: |
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${{ github.ref_name }}",
|
||||
"version": "${RELEASE_TAG}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
@@ -181,21 +189,26 @@ jobs:
|
||||
|
||||
|
||||
in-browser-evm-ver-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
|
||||
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
|
||||
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
|
||||
|
||||
4
.github/workflows/large-tests.yml
vendored
4
.github/workflows/large-tests.yml
vendored
@@ -6,9 +6,13 @@ on:
|
||||
description: "Test scenario tags"
|
||||
jobs:
|
||||
large-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: kaiju
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
|
||||
10
.github/workflows/pypi-gpu.yml
vendored
10
.github/workflows/pypi-gpu.yml
vendored
@@ -18,12 +18,17 @@ defaults:
|
||||
jobs:
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: GPU
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -34,6 +39,7 @@ jobs:
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
@@ -98,14 +104,14 @@ jobs:
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
51
.github/workflows/pypi.yml
vendored
51
.github/workflows/pypi.yml
vendored
@@ -16,6 +16,8 @@ defaults:
|
||||
|
||||
jobs:
|
||||
macos:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -23,6 +25,8 @@ jobs:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -45,6 +49,13 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'universal2-apple-darwin'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
args: --release --out dist --features python-bindings
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'x86_64'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
@@ -62,6 +73,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
windows:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: windows-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -69,6 +82,8 @@ jobs:
|
||||
target: [x64, x86]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -107,6 +122,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -114,6 +131,8 @@ jobs:
|
||||
target: [x86_64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -168,7 +187,7 @@ jobs:
|
||||
name: wheels
|
||||
path: dist
|
||||
|
||||
# TODO: There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# linux-cross:
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
@@ -220,6 +239,8 @@ jobs:
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -228,11 +249,21 @@ jobs:
|
||||
- x86_64-unknown-linux-musl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -242,7 +273,6 @@ jobs:
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -276,6 +306,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
musllinux-cross:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -283,10 +315,10 @@ jobs:
|
||||
platform:
|
||||
- target: aarch64-unknown-linux-musl
|
||||
arch: aarch64
|
||||
- target: armv7-unknown-linux-musleabihf
|
||||
arch: armv7
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -308,7 +340,7 @@ jobs:
|
||||
manylinux: musllinux_1_2
|
||||
args: --release --out dist --features python-bindings
|
||||
|
||||
- uses: uraimo/run-on-arch-action@v2.5.0
|
||||
- uses: uraimo/run-on-arch-action@v2.8.1
|
||||
name: Install built wheel
|
||||
with:
|
||||
arch: ${{ matrix.platform.arch }}
|
||||
@@ -350,25 +382,28 @@ jobs:
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
name: Trigger ReadTheDocs Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: pypi-publish
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Trigger RTDs build
|
||||
uses: dfm/rtds-action@v1
|
||||
with:
|
||||
|
||||
26
.github/workflows/release.yml
vendored
26
.github/workflows/release.yml
vendored
@@ -10,6 +10,9 @@ on:
|
||||
- "*"
|
||||
jobs:
|
||||
create-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: create-release
|
||||
runs-on: ubuntu-22.04
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
@@ -33,6 +36,9 @@ jobs:
|
||||
tag_name: ${{ env.EZKL_VERSION }}
|
||||
|
||||
build-release-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: build-release-gpu
|
||||
needs: ["create-release"]
|
||||
runs-on: GPU
|
||||
@@ -50,6 +56,9 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -91,6 +100,10 @@ jobs:
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
build-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
issues: write
|
||||
name: build-release
|
||||
needs: ["create-release"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -132,6 +145,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -181,9 +196,18 @@ jobs:
|
||||
echo "target flag is: ${{ env.TARGET_FLAGS }}"
|
||||
echo "target dir is: ${{ env.TARGET_DIR }}"
|
||||
|
||||
- name: Build release binary
|
||||
- name: Build release binary (no asm or metal)
|
||||
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
|
||||
|
||||
- name: Strip release binary
|
||||
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
|
||||
run: strip "target/${{ matrix.target }}/release/ezkl"
|
||||
|
||||
429
.github/workflows/rust.yml
vendored
429
.github/workflows/rust.yml
vendored
@@ -19,11 +19,31 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
fr-age-test:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: fr age Mock
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
|
||||
build:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -33,9 +53,13 @@ jobs:
|
||||
run: cargo build --verbose
|
||||
|
||||
docs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -45,9 +69,13 @@ jobs:
|
||||
run: cargo doc --verbose
|
||||
|
||||
library-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -65,45 +93,51 @@ jobs:
|
||||
- name: Library tests (original lookup)
|
||||
run: cargo nextest run --lib --verbose --no-default-features --features ezkl
|
||||
|
||||
ultra-overflow-tests-gpu:
|
||||
runs-on: GPU
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: mwilliamson/setup-wasmtime-action@v2
|
||||
with:
|
||||
wasmtime-version: "3.0.1"
|
||||
- name: Install wasm32-wasi
|
||||
run: rustup target add wasm32-wasi
|
||||
- name: Install cargo-wasi
|
||||
run: cargo install cargo-wasi
|
||||
# - name: Matmul overflow (wasi)
|
||||
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: Conv overflow (wasi)
|
||||
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
- name: lookup overflow
|
||||
run: cargo nextest run --release lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
- name: Matmul overflow
|
||||
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
- name: Conv overflow
|
||||
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
- name: Conv + relu overflow
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# ultra-overflow-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - uses: mwilliamson/setup-wasmtime-action@v2
|
||||
# with:
|
||||
# wasmtime-version: "3.0.1"
|
||||
# - name: Install wasm32-wasi
|
||||
# run: rustup target add wasm32-wasi
|
||||
# - name: Install cargo-wasi
|
||||
# run: cargo install cargo-wasi
|
||||
# # - name: Matmul overflow (wasi)
|
||||
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# # - name: Conv overflow (wasi)
|
||||
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: lookup overflow
|
||||
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Matmul overflow
|
||||
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Conv overflow
|
||||
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Conv + relu overflow
|
||||
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
|
||||
ultra-overflow-tests_og-lookup:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -134,9 +168,13 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
|
||||
ultra-overflow-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -167,9 +205,13 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
|
||||
model-serialization:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -183,16 +225,22 @@ jobs:
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
|
||||
|
||||
wasm32-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
- uses: nanasess/setup-chromedriver@v2
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
@@ -205,28 +253,15 @@ jobs:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
tutorial:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Circuit Render
|
||||
run: cargo nextest run --release --verbose tests::tutorial_
|
||||
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -238,10 +273,12 @@ jobs:
|
||||
locked: true
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
|
||||
- name: kzg inputs
|
||||
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
|
||||
- name: kzg params
|
||||
@@ -284,10 +321,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
|
||||
prove-and-verify-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -298,6 +339,8 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -331,8 +374,8 @@ jobs:
|
||||
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg inputs)
|
||||
@@ -362,23 +405,67 @@ jobs:
|
||||
- name: KZG prove and verify tests (EVM + hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
|
||||
|
||||
# prove-and-verify-tests-metal:
|
||||
# permissions:
|
||||
# contents: read
|
||||
# runs-on: macos-13
|
||||
# # needs: [build, library-tests, docs]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: jetli/wasm-pack-action@v0.4.0
|
||||
# with:
|
||||
# # Pin to version 0.12.1
|
||||
# version: 'v0.12.1'
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2024-07-18
|
||||
# - uses: actions/checkout@v3
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - name: Use pnpm 8
|
||||
# uses: pnpm/action-setup@v2
|
||||
# with:
|
||||
# version: 8
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
|
||||
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -435,46 +522,52 @@ jobs:
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
|
||||
|
||||
prove-and-verify-tests-gpu:
|
||||
runs-on: GPU
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
# prove-and-verify-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
# - uses: actions/checkout@v3
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (kzg outputs)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public inputs)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (fixed params)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (hashed outputs)
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -487,29 +580,35 @@ jobs:
|
||||
- name: Mock aggr tests (KZG)
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
|
||||
prove-and-verify-aggr-tests-gpu:
|
||||
runs-on: GPU
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG )tests
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
# prove-and-verify-aggr-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG tests
|
||||
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -523,10 +622,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -544,10 +647,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
examples:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -561,10 +668,14 @@ jobs:
|
||||
run: cargo nextest run --release tests_examples
|
||||
|
||||
python-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -587,10 +698,14 @@ jobs:
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
accuracy-measurement-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -607,8 +722,6 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Div rebase
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
@@ -619,6 +732,8 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
|
||||
python-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
services:
|
||||
# Label used to access the service container
|
||||
@@ -640,6 +755,8 @@ jobs:
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
@@ -662,6 +779,12 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
|
||||
- name: Postgres tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
|
||||
- name: Tictactoe tutorials
|
||||
@@ -677,7 +800,87 @@ jobs:
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
|
||||
ios-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
33
.github/workflows/static-analysis.yml
vendored
Normal file
33
.github/workflows/static-analysis.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Static Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
# Run Zizmor static analysis
|
||||
|
||||
- name: Install Zizmor
|
||||
run: cargo install --locked zizmor
|
||||
|
||||
- name: Run Zizmor Analysis
|
||||
run: zizmor .
|
||||
|
||||
|
||||
|
||||
134
.github/workflows/swift-pm.yml
vendored
Normal file
134
.github/workflows/swift-pm.yml
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
name: Build and Publish EZKL iOS SPM package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
# Only support SemVer versioning tags
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||
- '[0-9]+.[0-9]+.[0-9]+'
|
||||
|
||||
jobs:
|
||||
build-and-update:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: macos-latest
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout EZKL
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Extract TAG from github.ref_name
|
||||
run: |
|
||||
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
|
||||
TAG="${RELEASE_TAG}"
|
||||
echo "Original TAG: $TAG"
|
||||
# Remove leading 'v' if present to match the Swift Package Manager version format.
|
||||
NEW_TAG=${TAG#v}
|
||||
echo "Stripped TAG: $NEW_TAG"
|
||||
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Rust (nightly)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
override: true
|
||||
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift-package repository
|
||||
run: |
|
||||
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
|
||||
echo "no_changes=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "no_changes=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Set up Xcode environment
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
|
||||
- name: Setup Git
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
git config user.name "GitHub Action"
|
||||
git config user.email "action@github.com"
|
||||
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
|
||||
|
||||
- name: Commit and Push Changes
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
git add Sources/EzklCoreBindings Tests/EzklAssets
|
||||
git commit -m "Automatically updated EzklCoreBindings for EZKL"
|
||||
if ! git push origin; then
|
||||
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Tag the latest commit
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
source $GITHUB_ENV
|
||||
# Tag the latest commit on the current branch
|
||||
if git rev-parse "$TAG" >/dev/null 2>&1; then
|
||||
echo "Tag $TAG already exists locally. Skipping tag creation."
|
||||
else
|
||||
git tag "$TAG"
|
||||
fi
|
||||
|
||||
if ! git push origin "$TAG"; then
|
||||
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
|
||||
exit 1
|
||||
fi
|
||||
2
.github/workflows/tagging.yml
vendored
2
.github/workflows/tagging.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Bump version and push tag
|
||||
id: tag_version
|
||||
uses: mathieudutour/github-tag-action@v6.2
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -27,7 +27,6 @@ __pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.py[cod]
|
||||
bin/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
@@ -46,7 +45,7 @@ var/
|
||||
node_modules
|
||||
/dist
|
||||
timingData.json
|
||||
!tests/wasm/pk.key
|
||||
!tests/wasm/vk.key
|
||||
!tests/assets/pk.key
|
||||
!tests/assets/vk.key
|
||||
docs/python/build
|
||||
!tests/wasm/vk_aggr.key
|
||||
!tests/assets/vk_aggr.key
|
||||
|
||||
1322
Cargo.lock
generated
1322
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
223
Cargo.toml
223
Cargo.toml
@@ -4,6 +4,7 @@ cargo-features = ["profile-rustflags"]
|
||||
name = "ezkl"
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
default-run = "ezkl"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
@@ -11,87 +12,108 @@ edition = "2021"
|
||||
# Name to be imported within python
|
||||
# Example: import ezkl
|
||||
name = "ezkl"
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd", package = "halo2_proofs" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.5.3", features = ["derive"] }
|
||||
clap_complete = "4.5.2"
|
||||
serde = { version = "1.0.126", features = ["derive"], optional = true }
|
||||
serde_json = { version = "1.0.97", default_features = false, features = [
|
||||
"float_roundtrip",
|
||||
"raw_value",
|
||||
], optional = true }
|
||||
log = { version = "0.4.17", default_features = false, optional = true }
|
||||
thiserror = { version = "1.0.38", default_features = false }
|
||||
hex = { version = "0.4.3", default_features = false }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
|
||||
"circuit-params",
|
||||
] }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
itertools = { version = "0.10.3", default-features = false }
|
||||
clap = { version = "4.5.3", features = ["derive"], optional = true }
|
||||
serde = { version = "1.0.126", features = ["derive"] }
|
||||
clap_complete = { version = "4.5.2", optional = true }
|
||||
log = { version = "0.4.17", default-features = false }
|
||||
thiserror = { version = "1.0.38", default-features = false }
|
||||
hex = { version = "0.4.3", default-features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
semver = "1.0.22"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
serde_json = { version = "1.0.97", features = ["float_roundtrip", "raw_value"] }
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
|
||||
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
|
||||
ethabi = "18"
|
||||
indicatif = { version = "0.17.5", features = ["rayon"] }
|
||||
gag = { version = "1.0.0", default_features = false }
|
||||
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
|
||||
"provider-http",
|
||||
"signers",
|
||||
"contract",
|
||||
"rpc-types-eth",
|
||||
"signer-wallet",
|
||||
"node-bindings",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = [
|
||||
"svm-solc",
|
||||
], optional = true }
|
||||
ethabi = { version = "18", optional = true }
|
||||
indicatif = { version = "0.17.5", features = ["rayon"], optional = true }
|
||||
gag = { version = "1.0.0", default-features = false, optional = true }
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"default-tls",
|
||||
"multipart",
|
||||
"stream",
|
||||
] }
|
||||
openssl = { version = "0.10.55", features = ["vendored"] }
|
||||
tokio-postgres = "0.7.10"
|
||||
pg_bigdecimal = "0.1.5"
|
||||
lazy_static = "1.4.0"
|
||||
colored_json = { version = "3.0.1", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
tokio = { version = "1.35", default_features = false, features = [
|
||||
], optional = true }
|
||||
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
|
||||
tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread"
|
||||
] }
|
||||
pyo3 = { version = "0.21.2", features = [
|
||||
"rt-multi-thread",
|
||||
], optional = true }
|
||||
pyo3 = { version = "0.23.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
|
||||
], default-features = false, optional = true }
|
||||
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
|
||||
], default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
pyo3-stub-gen = { version = "0.6.0", optional = true }
|
||||
|
||||
# universal bindings
|
||||
uniffi = { version = "=0.28.0", optional = true }
|
||||
getrandom = { version = "0.2.8", optional = true }
|
||||
uniffi_bindgen = { version = "=0.28.0", optional = true }
|
||||
camino = { version = "^1.1", optional = true }
|
||||
uuid = { version = "1.10.0", features = ["v4"], optional = true }
|
||||
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
|
||||
colored = { version = "2.0.0", default_features = false, optional = true }
|
||||
env_logger = { version = "0.10.0", default_features = false, optional = true }
|
||||
chrono = "0.4.31"
|
||||
sha256 = "1.4.0"
|
||||
mimalloc = "0.1"
|
||||
colored = { version = "2.0.0", default-features = false, optional = true }
|
||||
env_logger = { version = "0.10.0", default-features = false, optional = true }
|
||||
chrono = { version = "0.4.31", optional = true }
|
||||
sha256 = { version = "1.4.0", optional = true }
|
||||
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
serde_json = { version = "1.0.97", default-features = false, features = [
|
||||
"float_roundtrip",
|
||||
"raw_value",
|
||||
] }
|
||||
getrandom = { version = "0.2.8", features = ["js"] }
|
||||
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
|
||||
|
||||
@@ -107,6 +129,10 @@ wasm-bindgen-console-logger = "0.1.1"
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
|
||||
criterion = { version = "0.5.1", features = ["html_reports"] }
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
uniffi = { version = "0.28", features = ["build"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.3.0"
|
||||
lazy_static = "1.4.0"
|
||||
@@ -120,6 +146,10 @@ shellexpand = "3.1.0"
|
||||
runner = 'wasm-bindgen-test-runner'
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "zero_finder"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_dot"
|
||||
harness = false
|
||||
@@ -158,16 +188,20 @@ harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "relu"
|
||||
name = "sigmoid"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu"
|
||||
name = "relu_lookupless"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_sigmoid"
|
||||
harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu_overflow"
|
||||
name = "accum_matmul_sigmoid_overflow"
|
||||
harness = false
|
||||
|
||||
[[bin]]
|
||||
@@ -176,43 +210,94 @@ test = false
|
||||
bench = false
|
||||
required-features = ["ezkl"]
|
||||
|
||||
[[bin]]
|
||||
name = "ios_gen_bindings"
|
||||
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
|
||||
|
||||
[[bin]]
|
||||
name = "py_stub_gen"
|
||||
required-features = ["python-bindings"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = ["ezkl", "mv-lookup", "no-banner", "parallel-poly-read"]
|
||||
default = [
|
||||
"ezkl",
|
||||
"mv-lookup",
|
||||
"precompute-coset",
|
||||
"no-banner",
|
||||
"parallel-poly-read",
|
||||
]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
|
||||
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
|
||||
ezkl = [
|
||||
"onnx",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"log",
|
||||
"colored",
|
||||
"env_logger",
|
||||
"dep:colored",
|
||||
"dep:env_logger",
|
||||
"tabled/color",
|
||||
"serde_json/std",
|
||||
"colored_json",
|
||||
"halo2_proofs/circuit-params",
|
||||
"dep:alloy",
|
||||
"dep:foundry-compilers",
|
||||
"dep:ethabi",
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:openssl",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:regex",
|
||||
"dep:tokio",
|
||||
"dep:mimalloc",
|
||||
"dep:chrono",
|
||||
"dep:sha256",
|
||||
"dep:portable-atomic",
|
||||
"dep:clap_complete",
|
||||
"dep:halo2_solidity_verifier",
|
||||
"dep:semver",
|
||||
"dep:clap",
|
||||
"dep:tosubcommand",
|
||||
]
|
||||
parallel-poly-read = [
|
||||
"halo2_proofs/circuit-params",
|
||||
"halo2_proofs/parallel-poly-read",
|
||||
]
|
||||
parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
|
||||
mv-lookup = [
|
||||
"halo2_proofs/mv-lookup",
|
||||
"snark-verifier/mv-lookup",
|
||||
"halo2_solidity_verifier/mv-lookup",
|
||||
]
|
||||
# precompute-coset = ["halo2_proofs/precompute-coset"]
|
||||
asm = ["halo2curves/asm", "halo2_proofs/asm"]
|
||||
precompute-coset = ["halo2_proofs/precompute-coset"]
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
no-update = []
|
||||
metal = ["dep:metal", "dep:objc"]
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
|
||||
macos-metal = ["halo2_proofs/macos"]
|
||||
ios-metal = ["halo2_proofs/ios"]
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
[profile.release]
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = [
|
||||
"-O4",
|
||||
"--flexible-inline-max-function-size",
|
||||
"4294967295",
|
||||
]
|
||||
147
abis/DataAttestationSingle.json
Normal file
147
abis/DataAttestationSingle.json
Normal file
@@ -0,0 +1,147 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "_callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "_decimals",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "_scales",
|
||||
"type": "uint256[]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "_instanceOffset",
|
||||
"type": "uint8"
|
||||
},
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "constructor"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "accountCall",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "contractAddress",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "decimals",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "admin",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "instanceOffset",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "",
|
||||
"type": "uint8"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "_callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "_decimals",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "updateAccountCalls",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "updateAdmin",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "verifier",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "verifyWithDataAttestation",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -1,4 +1,23 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "int256[]",
|
||||
"name": "quantized_data",
|
||||
"type": "int256[]"
|
||||
}
|
||||
],
|
||||
"name": "check_is_valid_field_element",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "output",
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
@@ -17,12 +36,41 @@
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"name": "quantize_data",
|
||||
"name": "quantize_data_multi",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "int64[]",
|
||||
"internalType": "int256[]",
|
||||
"name": "quantized_data",
|
||||
"type": "int64[]"
|
||||
"type": "int256[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "data",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "decimals",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "scales",
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"name": "quantize_data_single",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "int256[]",
|
||||
"name": "quantized_data",
|
||||
"type": "int256[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|
||||
// sets up a new relu table
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&b,
|
||||
&output,
|
||||
&a,
|
||||
BITS,
|
||||
K,
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
MyConfig { base_config }
|
||||
@@ -75,14 +83,18 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|
||||
// sets up a new relu table
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&b,
|
||||
&output,
|
||||
&a,
|
||||
BITS,
|
||||
k,
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
MyConfig { base_config }
|
||||
@@ -76,14 +84,18 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
|
||||
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1);
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
|
||||
.unwrap();
|
||||
|
||||
@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
|
||||
.unwrap();
|
||||
|
||||
150
benches/relu_lookupless.rs
Normal file
150
benches/relu_lookupless.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::{BaseConfig as Config, CheckMode};
|
||||
use ezkl::fieldutils::IntegerRep;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::Rng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NLCircuit {
|
||||
pub input: ValTensor<Fr>,
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for NLCircuit {
|
||||
type Config = Config<Fr>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = ();
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
unsafe {
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
config
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
|
||||
) -> Result<(), Error> {
|
||||
config.layout_range_checks(&mut layouter).unwrap();
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runrelu(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("relu");
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
|
||||
for &len in [4, 8].iter() {
|
||||
unsafe {
|
||||
LEN = len;
|
||||
};
|
||||
|
||||
let input: Tensor<Value<Fr>> =
|
||||
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
|
||||
|
||||
let circuit = NLCircuit {
|
||||
input: ValTensor::from(input.clone()),
|
||||
};
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
NLCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group! {
|
||||
name = benches;
|
||||
config = Criterion::default().with_plots();
|
||||
targets = runrelu
|
||||
}
|
||||
criterion_main!(benches);
|
||||
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
@@ -63,9 +63,13 @@ impl Circuit<Fr> for NLCircuit {
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
116
benches/zero_finder.rs
Normal file
116
benches/zero_finder.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::thread;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use halo2curves::{bn256::Fr as F, ff::Field};
|
||||
use maybe_rayon::{
|
||||
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
|
||||
slice::ParallelSlice,
|
||||
};
|
||||
use rand::Rng;
|
||||
|
||||
// Assuming these are your types
|
||||
#[derive(Clone)]
|
||||
enum ValType {
|
||||
Constant(F),
|
||||
AssignedConstant(usize, F),
|
||||
Other,
|
||||
}
|
||||
|
||||
// Helper to generate test data
|
||||
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..size)
|
||||
.map(|_i| {
|
||||
if rng.gen::<f64>() < zero_probability {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else {
|
||||
ValType::Constant(F::ONE) // Or some other non-zero value
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn bench_zero_finding(c: &mut Criterion) {
|
||||
let sizes = [
|
||||
1_000, // 1K
|
||||
10_000, // 10K
|
||||
100_000, // 100K
|
||||
256 * 256 * 2, // Our specific case
|
||||
1_000_000, // 1M
|
||||
10_000_000, // 10M
|
||||
];
|
||||
|
||||
let zero_probability = 0.1; // 10% zeros
|
||||
|
||||
let mut group = c.benchmark_group("zero_finding");
|
||||
group.sample_size(10); // Adjust based on your needs
|
||||
|
||||
for &size in &sizes {
|
||||
let data = generate_test_data(size, zero_probability);
|
||||
|
||||
// Benchmark sequential version
|
||||
group.bench_function(format!("sequential_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let result = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark parallel version
|
||||
group.bench_function(format!("parallel_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let result = data
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark chunked parallel version
|
||||
group.bench_function(format!("chunked_parallel_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let num_cores = thread::available_parallelism()
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
let chunk_size = (size / num_cores).max(100);
|
||||
|
||||
let result = data
|
||||
.par_chunks(chunk_size)
|
||||
.enumerate()
|
||||
.flat_map(|(chunk_idx, chunk)| {
|
||||
chunk
|
||||
.par_iter() // Make sure we use par_iter() here
|
||||
.enumerate()
|
||||
.filter_map(move |(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_zero_finding);
|
||||
criterion_main!(benches);
|
||||
7
build.rs
Normal file
7
build.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
fn main() {
|
||||
if cfg!(feature = "ios-bindings-test") {
|
||||
println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
|
||||
}
|
||||
|
||||
println!("cargo::rerun-if-changed=build.rs");
|
||||
}
|
||||
@@ -163,6 +163,253 @@ contract SwapProofCommitments {
|
||||
} /// end checkKzgCommits
|
||||
}
|
||||
|
||||
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
/**
|
||||
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
|
||||
* @param the address of the account to make calls to
|
||||
* @param the abi encoded function calls to make to the `contractAddress`
|
||||
*/
|
||||
struct AccountCall {
|
||||
address contractAddress;
|
||||
bytes callData;
|
||||
uint256 decimals;
|
||||
}
|
||||
AccountCall public accountCall;
|
||||
|
||||
uint[] scales;
|
||||
|
||||
address public admin;
|
||||
|
||||
/**
|
||||
* @notice EZKL P value
|
||||
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
|
||||
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
|
||||
*/
|
||||
uint256 constant ORDER =
|
||||
uint256(
|
||||
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
);
|
||||
|
||||
uint256 constant INPUT_LEN = 0;
|
||||
|
||||
uint256 constant OUTPUT_LEN = 0;
|
||||
|
||||
uint8 public instanceOffset;
|
||||
|
||||
/**
|
||||
* @dev Initialize the contract with account calls the EZKL model will read from.
|
||||
* @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
|
||||
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
|
||||
*/
|
||||
constructor(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals,
|
||||
uint[] memory _scales,
|
||||
uint8 _instanceOffset,
|
||||
address _admin
|
||||
) {
|
||||
admin = _admin;
|
||||
for (uint i; i < _scales.length; i++) {
|
||||
scales.push(1 << _scales[i]);
|
||||
}
|
||||
populateAccountCalls(_contractAddresses, _callData, _decimals);
|
||||
instanceOffset = _instanceOffset;
|
||||
}
|
||||
|
||||
function updateAdmin(address _admin) external {
|
||||
require(msg.sender == admin, "Only admin can update admin");
|
||||
if (_admin == address(0)) {
|
||||
revert();
|
||||
}
|
||||
admin = _admin;
|
||||
}
|
||||
|
||||
function updateAccountCalls(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals
|
||||
) external {
|
||||
require(msg.sender == admin, "Only admin can update account calls");
|
||||
populateAccountCalls(_contractAddresses, _callData, _decimals);
|
||||
}
|
||||
|
||||
function populateAccountCalls(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals
|
||||
) internal {
|
||||
AccountCall memory _accountCall = accountCall;
|
||||
_accountCall.contractAddress = _contractAddresses;
|
||||
_accountCall.callData = _callData;
|
||||
_accountCall.decimals = 10 ** _decimals;
|
||||
accountCall = _accountCall;
|
||||
}
|
||||
|
||||
function mulDiv(
|
||||
uint256 x,
|
||||
uint256 y,
|
||||
uint256 denominator
|
||||
) internal pure returns (uint256 result) {
|
||||
unchecked {
|
||||
uint256 prod0;
|
||||
uint256 prod1;
|
||||
assembly {
|
||||
let mm := mulmod(x, y, not(0))
|
||||
prod0 := mul(x, y)
|
||||
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
|
||||
}
|
||||
|
||||
if (prod1 == 0) {
|
||||
return prod0 / denominator;
|
||||
}
|
||||
|
||||
require(denominator > prod1, "Math: mulDiv overflow");
|
||||
|
||||
uint256 remainder;
|
||||
assembly {
|
||||
remainder := mulmod(x, y, denominator)
|
||||
prod1 := sub(prod1, gt(remainder, prod0))
|
||||
prod0 := sub(prod0, remainder)
|
||||
}
|
||||
|
||||
uint256 twos = denominator & (~denominator + 1);
|
||||
assembly {
|
||||
denominator := div(denominator, twos)
|
||||
prod0 := div(prod0, twos)
|
||||
twos := add(div(sub(0, twos), twos), 1)
|
||||
}
|
||||
|
||||
prod0 |= prod1 * twos;
|
||||
|
||||
uint256 inverse = (3 * denominator) ^ 2;
|
||||
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
|
||||
result = prod0 * inverse;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
|
||||
* @param x - One of the elements of the data returned from the account calls
|
||||
* @param _decimals - Number of base 10 decimals to scale the data by.
|
||||
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
|
||||
*
|
||||
*/
|
||||
function quantizeData(
|
||||
int x,
|
||||
uint256 _decimals,
|
||||
uint256 _scale
|
||||
) internal pure returns (int256 quantized_data) {
|
||||
bool neg = x < 0;
|
||||
if (neg) x = -x;
|
||||
uint output = mulDiv(uint256(x), _scale, _decimals);
|
||||
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
|
||||
output += 1;
|
||||
}
|
||||
quantized_data = neg ? -int256(output) : int256(output);
|
||||
}
|
||||
/**
|
||||
* @dev Make a static call to the account to fetch the data that EZKL reads from.
|
||||
* @param target - The address of the account to make calls to.
|
||||
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
|
||||
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
|
||||
*/
|
||||
function staticCall(
|
||||
address target,
|
||||
bytes memory data
|
||||
) internal view returns (bytes memory) {
|
||||
(bool success, bytes memory returndata) = target.staticcall(data);
|
||||
if (success) {
|
||||
if (returndata.length == 0) {
|
||||
require(
|
||||
target.code.length > 0,
|
||||
"Address: call to non-contract"
|
||||
);
|
||||
}
|
||||
return returndata;
|
||||
} else {
|
||||
revert("Address: low-level call failed");
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @dev Convert the fixed point quantized data into a field element.
|
||||
* @param x - The quantized data.
|
||||
* @return field_element - The field element.
|
||||
*/
|
||||
function toFieldElement(
|
||||
int256 x
|
||||
) internal pure returns (uint256 field_element) {
|
||||
// The casting down to uint256 is safe because the order is about 2^254, and the value
|
||||
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
|
||||
return uint256(x + int(ORDER)) % ORDER;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
|
||||
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
|
||||
*/
|
||||
function attestData(uint256[] memory instances) internal view {
|
||||
require(
|
||||
instances.length >= INPUT_LEN + OUTPUT_LEN,
|
||||
"Invalid public inputs length"
|
||||
);
|
||||
AccountCall memory _accountCall = accountCall;
|
||||
uint[] memory _scales = scales;
|
||||
bytes memory returnData = staticCall(
|
||||
_accountCall.contractAddress,
|
||||
_accountCall.callData
|
||||
);
|
||||
int256[] memory x = abi.decode(returnData, (int256[]));
|
||||
uint _offset;
|
||||
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
|
||||
uint field_element = toFieldElement(output);
|
||||
for (uint i = 0; i < x.length; i++) {
|
||||
if (field_element != instances[i + instanceOffset]) {
|
||||
_offset += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
uint length = x.length - _offset;
|
||||
for (uint i = 1; i < length; i++) {
|
||||
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
|
||||
field_element = toFieldElement(output);
|
||||
require(
|
||||
field_element == instances[i + instanceOffset + _offset],
|
||||
"Public input does not match"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Verify the proof with the data attestation.
|
||||
* @param verifier - The address of the verifier contract.
|
||||
* @param encoded - The verifier calldata.
|
||||
*/
|
||||
function verifyWithDataAttestation(
|
||||
address verifier,
|
||||
bytes calldata encoded
|
||||
) public view returns (bool) {
|
||||
require(verifier.code.length > 0, "Address: call to non-contract");
|
||||
attestData(getInstancesCalldata(encoded));
|
||||
// static call the verifier contract to verify the proof
|
||||
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
|
||||
|
||||
if (success) {
|
||||
return abi.decode(returndata, (bool));
|
||||
} else {
|
||||
revert("low-level call to verifier failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This contract serves as a Data Attestation Verifier for the EZKL model.
|
||||
// It is designed to read and attest to instances of proofs generated from a specified circuit.
|
||||
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
|
||||
@@ -173,11 +420,11 @@ contract SwapProofCommitments {
|
||||
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
|
||||
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
|
||||
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
|
||||
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
|
||||
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
|
||||
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
|
||||
// then calls the `verifyProof` method to verify the proof on the verifier.
|
||||
|
||||
contract DataAttestation is LoadInstances, SwapProofCommitments {
|
||||
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
|
||||
/**
|
||||
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
|
||||
* @param the address of the account to make calls to
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ezkl==0.0.0
|
||||
ezkl
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
|
||||
@@ -146,6 +146,8 @@ where
|
||||
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::configure(
|
||||
@@ -156,15 +158,11 @@ where
|
||||
);
|
||||
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::ReLU,
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
@@ -195,11 +193,16 @@ where
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);
|
||||
|
||||
let op = PolyOp::Conv {
|
||||
padding: vec![(PADDING, PADDING); 2],
|
||||
@@ -221,7 +224,14 @@ where
|
||||
|
||||
let x = config
|
||||
.layer_config
|
||||
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut x = config
|
||||
@@ -281,7 +291,7 @@ where
|
||||
}
|
||||
|
||||
pub fn runconv() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
env_logger::init();
|
||||
|
||||
const KERNEL_HEIGHT: usize = 5;
|
||||
|
||||
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
// tells the config layer to add an affine op to the circuit gate
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::<F>::configure(
|
||||
cs,
|
||||
&[input.clone(), params.clone()],
|
||||
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::ReLU,
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
@@ -104,11 +103,16 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let x = config
|
||||
.layer_config
|
||||
.layout(
|
||||
@@ -141,7 +145,14 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
println!("x shape: {:?}", x.dims());
|
||||
let mut x = config
|
||||
.layer_config
|
||||
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
println!("3");
|
||||
@@ -177,7 +188,14 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
println!("x shape: {:?}", x.dims());
|
||||
let x = config
|
||||
.layer_config
|
||||
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
println!("6");
|
||||
println!("offset: {}", region.row());
|
||||
@@ -212,7 +230,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
}
|
||||
|
||||
pub fn runmlp() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
env_logger::init();
|
||||
// parameters
|
||||
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
|
||||
|
||||
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -648,10 +648,10 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
771
examples/notebooks/ezkl_demo_batch.ipynb
Normal file
771
examples/notebooks/ezkl_demo_batch.ipynb
Normal file
File diff suppressed because one or more lines are too long
130
examples/notebooks/felt_conversion_test.ipynb
Normal file
130
examples/notebooks/felt_conversion_test.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,279 +1,284 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# note that you can also call the following function to generate random data for the model\n",
|
||||
"# it is functionally equivalent to the code above\n",
|
||||
"ezkl.gen_random_data()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,456 +1,459 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"await ezkl.get_srs(settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
766
examples/notebooks/neural_bow.ipynb
Normal file
766
examples/notebooks/neural_bow.ipynb
Normal file
@@ -0,0 +1,766 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
|
||||
"\n",
|
||||
"1 - NBoW\n",
|
||||
"\n",
|
||||
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"Preparing Data\n",
|
||||
"\n",
|
||||
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
|
||||
"\n",
|
||||
"The steps to take are:\n",
|
||||
"\n",
|
||||
" 1. importing modules\n",
|
||||
" 2. loading data\n",
|
||||
" 3. tokenizing data\n",
|
||||
" 4. creating data splits\n",
|
||||
" 5. creating a vocabulary\n",
|
||||
" 6. numericalizing data\n",
|
||||
" 7. creating the data loaders\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install torchtex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import collections\n",
|
||||
"\n",
|
||||
"import datasets\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchtext\n",
|
||||
"import tqdm\n",
|
||||
"\n",
|
||||
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
|
||||
"\n",
|
||||
"seed = 1234\n",
|
||||
"\n",
|
||||
"np.random.seed(seed)\n",
|
||||
"torch.manual_seed(seed)\n",
|
||||
"torch.cuda.manual_seed(seed)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data.features\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're design to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
|
||||
"\n",
|
||||
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tokenize_example(example, tokenizer, max_length):\n",
|
||||
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
|
||||
" return {\"tokens\": tokens}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"max_length = 256\n",
|
||||
"\n",
|
||||
"train_data = train_data.map(\n",
|
||||
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
|
||||
")\n",
|
||||
"test_data = test_data.map(\n",
|
||||
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# create validation data \n",
|
||||
"# Why have both a validation set and a test set? Your test set respresents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leak information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
|
||||
"\n",
|
||||
"test_size = 0.25\n",
|
||||
"\n",
|
||||
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
|
||||
"train_data = train_valid_data[\"train\"]\n",
|
||||
"valid_data = train_valid_data[\"test\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we have to build a vocabulary. This is look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
|
||||
"\n",
|
||||
"We do this as machine learning models cannot operate on strings, only numerical vaslues. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"min_freq = 5\n",
|
||||
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
|
||||
"\n",
|
||||
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
|
||||
" train_data[\"tokens\"],\n",
|
||||
" min_freq=min_freq,\n",
|
||||
" specials=special_tokens,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
|
||||
"\n",
|
||||
"unk_index = vocab[\"<unk>\"]\n",
|
||||
"pad_index = vocab[\"<pad>\"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"vocab.set_default_index(unk_index)\n",
|
||||
"\n",
|
||||
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
|
||||
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which containes the numericalized tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def numericalize_example(example, vocab):\n",
|
||||
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
|
||||
" return {\"ids\": ids}\n",
|
||||
"\n",
|
||||
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"\n",
|
||||
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
|
||||
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
|
||||
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
|
||||
"\n",
|
||||
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
|
||||
"\n",
|
||||
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
|
||||
"\n",
|
||||
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
|
||||
"\n",
|
||||
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
|
||||
"\n",
|
||||
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_collate_fn(pad_index):\n",
|
||||
" def collate_fn(batch):\n",
|
||||
" batch_ids = [i[\"ids\"] for i in batch]\n",
|
||||
" batch_ids = nn.utils.rnn.pad_sequence(\n",
|
||||
" batch_ids, padding_value=pad_index, batch_first=True\n",
|
||||
" )\n",
|
||||
" batch_label = [i[\"label\"] for i in batch]\n",
|
||||
" batch_label = torch.stack(batch_label)\n",
|
||||
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
" return collate_fn\n",
|
||||
"\n",
|
||||
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
|
||||
" collate_fn = get_collate_fn(pad_index)\n",
|
||||
" data_loader = torch.utils.data.DataLoader(\n",
|
||||
" dataset=dataset,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" collate_fn=collate_fn,\n",
|
||||
" shuffle=shuffle,\n",
|
||||
" )\n",
|
||||
" return data_loader\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"batch_size = 512\n",
|
||||
"\n",
|
||||
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
|
||||
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
|
||||
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"class NBoW(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
|
||||
" super().__init__()\n",
|
||||
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
|
||||
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, ids):\n",
|
||||
" # ids = [batch size, seq len]\n",
|
||||
" embedded = self.embedding(ids)\n",
|
||||
" # embedded = [batch size, seq len, embedding dim]\n",
|
||||
" pooled = embedded.mean(dim=1)\n",
|
||||
" # pooled = [batch size, embedding dim]\n",
|
||||
" prediction = self.fc(pooled)\n",
|
||||
" # prediction = [batch size, output dim]\n",
|
||||
" return prediction\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"vocab_size = len(vocab)\n",
|
||||
"embedding_dim = 300\n",
|
||||
"output_dim = len(train_data.unique(\"label\"))\n",
|
||||
"\n",
|
||||
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
|
||||
"\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
|
||||
"\n",
|
||||
"vectors = torchtext.vocab.GloVe()\n",
|
||||
"\n",
|
||||
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters())\n",
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"model = model.to(device)\n",
|
||||
"criterion = criterion.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(data_loader, model, criterion, optimizer, device):\n",
|
||||
" model.train()\n",
|
||||
" epoch_losses = []\n",
|
||||
" epoch_accs = []\n",
|
||||
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
|
||||
" ids = batch[\"ids\"].to(device)\n",
|
||||
" label = batch[\"label\"].to(device)\n",
|
||||
" prediction = model(ids)\n",
|
||||
" loss = criterion(prediction, label)\n",
|
||||
" accuracy = get_accuracy(prediction, label)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" epoch_losses.append(loss.item())\n",
|
||||
" epoch_accs.append(accuracy.item())\n",
|
||||
" return np.mean(epoch_losses), np.mean(epoch_accs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate(data_loader, model, criterion, device):\n",
|
||||
" model.eval()\n",
|
||||
" epoch_losses = []\n",
|
||||
" epoch_accs = []\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
|
||||
" ids = batch[\"ids\"].to(device)\n",
|
||||
" label = batch[\"label\"].to(device)\n",
|
||||
" prediction = model(ids)\n",
|
||||
" loss = criterion(prediction, label)\n",
|
||||
" accuracy = get_accuracy(prediction, label)\n",
|
||||
" epoch_losses.append(loss.item())\n",
|
||||
" epoch_accs.append(accuracy.item())\n",
|
||||
" return np.mean(epoch_losses), np.mean(epoch_accs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_accuracy(prediction, label):\n",
|
||||
" batch_size, _ = prediction.shape\n",
|
||||
" predicted_classes = prediction.argmax(dim=-1)\n",
|
||||
" correct_predictions = predicted_classes.eq(label).sum()\n",
|
||||
" accuracy = correct_predictions / batch_size\n",
|
||||
" return accuracy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_epochs = 10\n",
|
||||
"best_valid_loss = float(\"inf\")\n",
|
||||
"\n",
|
||||
"metrics = collections.defaultdict(list)\n",
|
||||
"\n",
|
||||
"for epoch in range(n_epochs):\n",
|
||||
" train_loss, train_acc = train(\n",
|
||||
" train_data_loader, model, criterion, optimizer, device\n",
|
||||
" )\n",
|
||||
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
|
||||
" metrics[\"train_losses\"].append(train_loss)\n",
|
||||
" metrics[\"train_accs\"].append(train_acc)\n",
|
||||
" metrics[\"valid_losses\"].append(valid_loss)\n",
|
||||
" metrics[\"valid_accs\"].append(valid_acc)\n",
|
||||
" if valid_loss < best_valid_loss:\n",
|
||||
" best_valid_loss = valid_loss\n",
|
||||
" torch.save(model.state_dict(), \"nbow.pt\")\n",
|
||||
" print(f\"epoch: {epoch}\")\n",
|
||||
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
|
||||
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(10, 6))\n",
|
||||
"ax = fig.add_subplot(1, 1, 1)\n",
|
||||
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
|
||||
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
|
||||
"ax.set_xlabel(\"epoch\")\n",
|
||||
"ax.set_ylabel(\"loss\")\n",
|
||||
"ax.set_xticks(range(n_epochs))\n",
|
||||
"ax.legend()\n",
|
||||
"ax.grid()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(10, 6))\n",
|
||||
"ax = fig.add_subplot(1, 1, 1)\n",
|
||||
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
|
||||
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
|
||||
"ax.set_xlabel(\"epoch\")\n",
|
||||
"ax.set_ylabel(\"loss\")\n",
|
||||
"ax.set_xticks(range(n_epochs))\n",
|
||||
"ax.legend()\n",
|
||||
"ax.grid()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
|
||||
"\n",
|
||||
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
|
||||
"\n",
|
||||
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
|
||||
" tokens = tokenizer(text)\n",
|
||||
" ids = vocab.lookup_indices(tokens)\n",
|
||||
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
|
||||
" prediction = model(tensor).squeeze(dim=0)\n",
|
||||
" probability = torch.softmax(prediction, dim=-1)\n",
|
||||
" predicted_class = prediction.argmax(dim=-1).item()\n",
|
||||
" predicted_probability = probability[predicted_class].item()\n",
|
||||
" return predicted_class, predicted_probability"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is terrible!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is great!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is not terrible, it's great!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is not great, it's terrible!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def text_to_tensor(text, tokenizer, vocab, device):\n",
|
||||
" tokens = tokenizer(text)\n",
|
||||
" ids = vocab.lookup_indices(tokens)\n",
|
||||
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
|
||||
" return tensor\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we do onnx stuff to get the data ready for the zk-circuit."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"text = \"This film is terrible!\"\n",
|
||||
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"model.eval()\n",
|
||||
"model.to('cpu')\n",
|
||||
"\n",
|
||||
"model_path = \"network.onnx\"\n",
|
||||
"data_path = \"input.json\"\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
"torch.onnx.export(model, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" model_path, # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data_json = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"print(data_json)\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(data_json, open(data_path, 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.logrows = 23\n",
|
||||
"run_args.scale_rebase_multiplier = 10\n",
|
||||
"# inputs should be auditable by all\n",
|
||||
"run_args.input_visibility = \"public\"\n",
|
||||
"# same with outputs\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(py_run_args=run_args)\n",
|
||||
"assert res == True\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit()\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"res = await ezkl.gen_witness()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.mock()\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"res = ezkl.setup()\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"res = ezkl.prove(proof_path=\"proof.json\")\n",
|
||||
"\n",
|
||||
"print(res)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"res = ezkl.verify()\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also verify it on chain by creating an onchain verifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
|
||||
" !solc-select install 0.8.20\n",
|
||||
" !solc-select use 0.8.20\n",
|
||||
" !solc --version\n",
|
||||
" import os\n",
|
||||
"\n",
|
||||
"# rely on local installation if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" import os\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = await ezkl.create_evm_verifier()\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
|
||||
"\n",
|
||||
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
|
||||
"\n",
|
||||
"Create a new file within remix and copy the verifier code over.\n",
|
||||
"\n",
|
||||
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
|
||||
"\n",
|
||||
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -232,7 +232,7 @@
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.input_scale = 2\n",
|
||||
"run_args.logrows = 8\n",
|
||||
"run_args.logrows = 15\n",
|
||||
"\n",
|
||||
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
|
||||
]
|
||||
@@ -404,7 +404,7 @@
|
||||
"run_args.output_visibility = \"polycommit\"\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]\n",
|
||||
"run_args.input_scale = 2\n",
|
||||
"run_args.logrows = 8\n"
|
||||
"run_args.logrows = 15\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -466,7 +466,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
339
examples/notebooks/reusable_verifier.ipynb
Normal file
339
examples/notebooks/reusable_verifier.ipynb
Normal file
@@ -0,0 +1,339 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reusable Verifiers \n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
|
||||
"\n",
|
||||
"- `1l_mlp sigmoid`\n",
|
||||
"- `1l_mlp relu`\n",
|
||||
"- `1l_conv sigmoid`\n",
|
||||
"- `1l_conv relu`\n",
|
||||
"\n",
|
||||
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
|
||||
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
|
||||
"\n",
|
||||
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we deploy a unique and much smaller verifying key artifact (VKA) contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well circuit specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.onnx\n",
|
||||
"\n",
|
||||
"# Define the models\n",
|
||||
"class MLP_Sigmoid(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MLP_Sigmoid, self).__init__()\n",
|
||||
" self.fc = nn.Linear(3, 3)\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc(x)\n",
|
||||
" x = self.sigmoid(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class MLP_Relu(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MLP_Relu, self).__init__()\n",
|
||||
" self.fc = nn.Linear(3, 3)\n",
|
||||
" self.relu = nn.ReLU()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.fc(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class Conv_Sigmoid(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Conv_Sigmoid, self).__init__()\n",
|
||||
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.conv(x)\n",
|
||||
" x = self.sigmoid(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"class Conv_Relu(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Conv_Relu, self).__init__()\n",
|
||||
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
|
||||
" self.relu = nn.ReLU()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.conv(x)\n",
|
||||
" x = self.relu(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Instantiate the models\n",
|
||||
"mlp_sigmoid = MLP_Sigmoid()\n",
|
||||
"mlp_relu = MLP_Relu()\n",
|
||||
"conv_sigmoid = Conv_Sigmoid()\n",
|
||||
"conv_relu = Conv_Relu()\n",
|
||||
"\n",
|
||||
"# Dummy input tensor for mlp\n",
|
||||
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
|
||||
"input_mlp_path = 'mlp_input.json'\n",
|
||||
"\n",
|
||||
"# Dummy input tensor for conv\n",
|
||||
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
|
||||
"input_conv_path = 'conv_input.json'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
|
||||
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
|
||||
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
|
||||
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
|
||||
" # Create a new directory for the model if it doesn't exist\n",
|
||||
" if not os.path.exists(name):\n",
|
||||
" os.mkdir(name)\n",
|
||||
" # Store the paths in each of their respective directories\n",
|
||||
" model_path = os.path.join(name, \"network.onnx\")\n",
|
||||
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
|
||||
" pk_path = os.path.join(name, \"test.pk\")\n",
|
||||
" vk_path = os.path.join(name, \"test.vk\")\n",
|
||||
" settings_path = os.path.join(name, \"settings.json\")\n",
|
||||
"\n",
|
||||
" witness_path = os.path.join(name, \"witness.json\")\n",
|
||||
" sol_code_path = os.path.join(name, 'test.sol')\n",
|
||||
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
|
||||
" abi_path = os.path.join(name, 'test.abi')\n",
|
||||
" proof_path = os.path.join(name, \"proof.json\")\n",
|
||||
"\n",
|
||||
" # Flips the neural net into inference mode\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
|
||||
" do_constant_folding=True, input_names=['input'],\n",
|
||||
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
" data = dict(input_data=[data_array])\n",
|
||||
" json.dump(data, open(input_path, 'w'))\n",
|
||||
"\n",
|
||||
" py_run_args = ezkl.PyRunArgs()\n",
|
||||
" py_run_args.input_visibility = \"private\"\n",
|
||||
" py_run_args.output_visibility = \"public\"\n",
|
||||
" py_run_args.param_visibility = \"fixed\" # private by default\n",
|
||||
"\n",
|
||||
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" await ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
|
||||
"\n",
|
||||
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" res = await ezkl.get_srs(settings_path)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" # now generate the witness file\n",
|
||||
" res = await ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
|
||||
" assert os.path.isfile(witness_path) == True\n",
|
||||
"\n",
|
||||
" # SETUP \n",
|
||||
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
|
||||
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
|
||||
" assert res == True\n",
|
||||
" assert os.path.isfile(vk_path)\n",
|
||||
" assert os.path.isfile(pk_path)\n",
|
||||
" assert os.path.isfile(settings_path)\n",
|
||||
"\n",
|
||||
" # GENERATE A PROOF\n",
|
||||
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
|
||||
" assert os.path.isfile(proof_path)\n",
|
||||
"\n",
|
||||
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" res = await ezkl.create_evm_vka(vk_path, settings_path, sol_key_code_path, abi_path)\n",
|
||||
" assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"RPC_URL = \"http://localhost:3030\"\n",
|
||||
"\n",
|
||||
"# Save process globally\n",
|
||||
"anvil_process = None\n",
|
||||
"\n",
|
||||
"def start_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is None:\n",
|
||||
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
|
||||
" if anvil_process.returncode is not None:\n",
|
||||
" raise Exception(\"failed to start anvil process\")\n",
|
||||
" time.sleep(3)\n",
|
||||
"\n",
|
||||
"def stop_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is not None:\n",
|
||||
" anvil_process.terminate()\n",
|
||||
" anvil_process = None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Check that the generated verifiers are identical for all models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"start_anvil()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import filecmp\n",
|
||||
"\n",
|
||||
"def compare_files(file1, file2):\n",
|
||||
" return filecmp.cmp(file1, file2, shallow=False)\n",
|
||||
"\n",
|
||||
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
|
||||
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
|
||||
"\n",
|
||||
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
|
||||
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
|
||||
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we deploy separate verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os \n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" \"verifier/reusable\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"with open(addr_path_verifier, 'r') as file:\n",
|
||||
" addr = file.read().rstrip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in names:\n",
|
||||
" addr_path_vk = \"addr_vk.txt\"\n",
|
||||
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
|
||||
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" with open(addr_path_vk, 'r') as file:\n",
|
||||
" addr_vk = file.read().rstrip()\n",
|
||||
" \n",
|
||||
" proof_path = os.path.join(name, \"proof.json\")\n",
|
||||
" sol_code_path = os.path.join(name, 'vk.sol')\n",
|
||||
" res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" addr_vk = addr_vk\n",
|
||||
" )\n",
|
||||
" assert res == True"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -171,7 +171,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,7 +328,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "171702d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 27,
|
||||
"id": "671dfdd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 28,
|
||||
"id": "50eba2f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -399,9 +399,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
763
examples/notebooks/univ3-da.ipynb
Normal file
763
examples/notebooks/univ3-da.ipynb
Normal file
@@ -0,0 +1,763 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# univ3-da-ezkl\n",
|
||||
"\n",
|
||||
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source. For this setup we make a single call to a view function that returns an array of UniV3 historical TWAP price data that we will attest to on-chain. \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First we import the necessary dependencies and set up logging to be as informative as possible. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from torch import nn\n",
|
||||
"import ezkl\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"# uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"# Defines the model\n",
|
||||
"\n",
|
||||
"class MyModel(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(MyModel, self).__init__()\n",
|
||||
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return self.layer(x)[0]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = MyModel()\n",
|
||||
"\n",
|
||||
"# this is where you'd train your model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
|
||||
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
|
||||
"\n",
|
||||
"You can replace the random `x` with real data if you so wish. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w' ))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import time\n",
|
||||
"import threading\n",
|
||||
"\n",
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"RPC_URL = \"http://localhost:3030\"\n",
|
||||
"\n",
|
||||
"# Save process globally\n",
|
||||
"anvil_process = None\n",
|
||||
"\n",
|
||||
"def start_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is None:\n",
|
||||
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--fork-url\", \"https://arb1.arbitrum.io/rpc\", \"--code-size-limit=41943040\"])\n",
|
||||
" if anvil_process.returncode is not None:\n",
|
||||
" raise Exception(\"failed to start anvil process\")\n",
|
||||
" time.sleep(3)\n",
|
||||
"\n",
|
||||
"def stop_anvil():\n",
|
||||
" global anvil_process\n",
|
||||
" if anvil_process is not None:\n",
|
||||
" anvil_process.terminate()\n",
|
||||
" anvil_process = None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
|
||||
"- `input_visibility` defines the visibility of the model inputs\n",
|
||||
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
|
||||
"- `output_visibility` defines the visibility of the model outputs\n",
|
||||
"\n",
|
||||
"Here we create the following setup:\n",
|
||||
"- `input_visibility`: \"public\"\n",
|
||||
"- `param_visibility`: \"private\"\n",
|
||||
"- `output_visibility`: public\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"srs_path = os.path.join('kzg.srs')\n",
|
||||
"data_path = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.input_visibility = \"public\"\n",
|
||||
"run_args.param_visibility = \"private\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.num_inner_cols = 1\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
|
||||
"\n",
|
||||
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate a bunch of dummy calibration data\n",
|
||||
"cal_data = {\n",
|
||||
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"cal_path = os.path.join('val_data.json')\n",
|
||||
"# save as json file\n",
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
|
||||
"\n",
|
||||
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
|
||||
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
|
||||
" \n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": {\n",
|
||||
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
|
||||
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
|
||||
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from web3 import Web3, HTTPProvider\n",
|
||||
"from solcx import compile_standard\n",
|
||||
"from decimal import Decimal\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"# This function counts the decimal places of a floating point number\n",
|
||||
"def count_decimal_places(num):\n",
|
||||
" num_str = str(num)\n",
|
||||
" if '.' in num_str:\n",
|
||||
" return len(num_str) - 1 - num_str.index('.')\n",
|
||||
" else:\n",
|
||||
" return 0\n",
|
||||
"\n",
|
||||
"# setup web3 instance\n",
|
||||
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
|
||||
"\n",
|
||||
"def set_next_block_timestamp(anvil_url, timestamp):\n",
|
||||
" # Send the JSON-RPC request to Anvil\n",
|
||||
" payload = {\n",
|
||||
" \"jsonrpc\": \"2.0\",\n",
|
||||
" \"id\": 1,\n",
|
||||
" \"method\": \"evm_setNextBlockTimestamp\",\n",
|
||||
" \"params\": [timestamp]\n",
|
||||
" }\n",
|
||||
" response = requests.post(anvil_url, json=payload)\n",
|
||||
" if response.status_code == 200:\n",
|
||||
" print(f\"Next block timestamp set to: {timestamp}\")\n",
|
||||
" else:\n",
|
||||
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
|
||||
"\n",
|
||||
"def on_chain_data(tensor):\n",
|
||||
" # Step 0: Convert the tensor to a flat list\n",
|
||||
" data = tensor.view(-1).tolist()\n",
|
||||
"\n",
|
||||
" # Step 1: Prepare the calldata\n",
|
||||
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
|
||||
"\n",
|
||||
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
|
||||
" contract_source_code = '''\n",
|
||||
" // SPDX-License-Identifier: MIT\n",
|
||||
" pragma solidity ^0.8.20;\n",
|
||||
"\n",
|
||||
" /// @title Pool state that is not stored\n",
|
||||
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
|
||||
" /// blockchain. The functions here may have variable gas costs.\n",
|
||||
" interface IUniswapV3PoolDerivedState {\n",
|
||||
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
|
||||
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
|
||||
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
|
||||
" /// you must call it with secondsAgos = [3600, 0].\n",
|
||||
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
|
||||
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
|
||||
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
|
||||
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
|
||||
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
|
||||
" /// timestamp\n",
|
||||
" function observe(\n",
|
||||
" uint32[] calldata secondsAgos\n",
|
||||
" )\n",
|
||||
" external\n",
|
||||
" view\n",
|
||||
" returns (\n",
|
||||
" int56[] memory tickCumulatives,\n",
|
||||
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
|
||||
" );\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
|
||||
" /// @notice Provides functions to integrate with V3 pool oracle\n",
|
||||
" contract UniTickAttestor {\n",
|
||||
" /**\n",
|
||||
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
|
||||
" * @param pool Address of the pool that we want to observe\n",
|
||||
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
|
||||
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
|
||||
" */\n",
|
||||
" function consult(\n",
|
||||
" IUniswapV3PoolDerivedState pool,\n",
|
||||
" uint32[] memory secondsAgo\n",
|
||||
" ) public view returns (int256[] memory tickCumulatives) {\n",
|
||||
" tickCumulatives = new int256[](secondsAgo.length);\n",
|
||||
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
|
||||
" for (uint256 i = 0; i < secondsAgo.length; i++) {\n",
|
||||
" tickCumulatives[i] = int256(_ticks[i]);\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" '''\n",
|
||||
"\n",
|
||||
" compiled_sol = compile_standard({\n",
|
||||
" \"language\": \"Solidity\",\n",
|
||||
" \"sources\": {\"UniTickAttestor.sol\": {\"content\": contract_source_code}},\n",
|
||||
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
" # Get bytecode\n",
|
||||
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
|
||||
"\n",
|
||||
" # Get ABI\n",
|
||||
" # In production if you are reading from really large contracts you can just use\n",
|
||||
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
|
||||
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
|
||||
"\n",
|
||||
" # Step 3: Deploy the contract\n",
|
||||
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
|
||||
" tx_hash = UniTickAttestor.constructor().transact()\n",
|
||||
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
|
||||
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
|
||||
" # passing the address and abi of the contract you are fetching data from.\n",
|
||||
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
|
||||
"\n",
|
||||
" # Step 4: Interact with the contract\n",
|
||||
" call = contract.functions.consult(\n",
|
||||
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
|
||||
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
|
||||
" secondsAgo\n",
|
||||
" ).build_transaction()\n",
|
||||
" result = contract.functions.consult(\n",
|
||||
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
|
||||
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
|
||||
" secondsAgo\n",
|
||||
" ).call()\n",
|
||||
" \n",
|
||||
" print(f'result: {result}')\n",
|
||||
" calldata = call['data'][2:]\n",
|
||||
"\n",
|
||||
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
|
||||
"\n",
|
||||
" print(f'time_stamp: {time_stamp}')\n",
|
||||
"\n",
|
||||
" # Set the next block timestamp using the fetched time_stamp\n",
|
||||
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Prepare the calls_to_account object\n",
|
||||
" # If you were calling view functions across multiple contracts,\n",
|
||||
" # you would have multiple entries in the calls_to_account array,\n",
|
||||
" # one for each contract.\n",
|
||||
" call_to_account = {\n",
|
||||
" 'call_data': calldata,\n",
|
||||
" 'decimals': 0,\n",
|
||||
" 'address': contract.address[2:], # remove the '0x' prefix\n",
|
||||
" 'len': len(data),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" print(f'call_to_account: {call_to_account}')\n",
|
||||
"\n",
|
||||
" return call_to_account\n",
|
||||
"\n",
|
||||
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
|
||||
"start_anvil()\n",
|
||||
"\n",
|
||||
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
|
||||
"calls_to_account = on_chain_data(x)\n",
|
||||
"\n",
|
||||
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
|
||||
"\n",
|
||||
"# Serialize on-chain data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
|
||||
"\n",
|
||||
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !export RUST_BACKTRACE=1\n",
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we generate a full proof. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And verify it as a sanity check. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now create and then deploy a vanilla evm verifier."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path,\n",
|
||||
" )\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"input_path = 'input.json'\n",
|
||||
"\n",
|
||||
"res = await ezkl.create_evm_data_attestation(\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
|
||||
"So should only be used for testing purposes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"addr_path_da = \"addr_da.txt\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we need to regenerate the witness, prove and then verify all within the same cell. This is because we want to reduce the amount of latency between reading on-chain state and verifying it on-chain. This is because the attest input values read from the oracle are time sensitive (their values are derived from computing on block.timestamp) and can change between the time of reading and the time of verifying.\n",
|
||||
"\n",
|
||||
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !export RUST_BACKTRACE=1\n",
|
||||
"\n",
|
||||
"calls_to_account = on_chain_data(x)\n",
|
||||
"\n",
|
||||
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
|
||||
"\n",
|
||||
"# Serialize on-chain data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
"\n",
|
||||
"# setup web3 instance\n",
|
||||
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
|
||||
"\n",
|
||||
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
|
||||
"\n",
|
||||
"print(f'time_stamp: {time_stamp}')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"# read the verifier address\n",
|
||||
"addr_verifier = None\n",
|
||||
"with open(addr_path_verifier, 'r') as f:\n",
|
||||
" addr = f.read()\n",
|
||||
"#read the data attestation address\n",
|
||||
"addr_da = None\n",
|
||||
"with open(addr_path_da, 'r') as f:\n",
|
||||
" addr_da = f.read()\n",
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -9,7 +9,9 @@ class MyModel(nn.Module):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, w, x, y, z):
|
||||
return [((x & y)) == (x & (y | (z ^ w)))]
|
||||
a = (x & y)
|
||||
b = (y & (z ^ w))
|
||||
return [a & b]
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
|
||||
{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}
|
||||
@@ -1,21 +1,17 @@
|
||||
pytorch1.12.1:«
|
||||
+
|
||||
pytorch2.2.2:„
|
||||
*
|
||||
input1
|
||||
input2
|
||||
onnx::Equal_4And_0"And
|
||||
input2
|
||||
/And_output_0/And"And
|
||||
)
|
||||
input3
|
||||
input
|
||||
onnx::Or_5Xor_1"Xor
|
||||
input3
|
||||
input
|
||||
/Xor_output_0/Xor"Xor
|
||||
input2
|
||||
|
||||
onnx::Or_5onnx::And_6Or_2"Or
|
||||
0
|
||||
input1
|
||||
onnx::And_6
|
||||
onnx::Equal_7And_3"And
|
||||
6
|
||||
5
|
||||
input2
|
||||
|
||||
/Xor_output_0/And_1_output_0/And_1"And
|
||||
5
|
||||
|
||||
/And_output_0
|
||||
/And_1_output_0output/And_2"And
|
||||
|
||||
42
examples/onnx/exp/gen.py
Normal file
42
examples/onnx/exp/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.exp(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/exp/input.json
Normal file
1
examples/onnx/exp/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.5801457762718201, 0.6019012331962585, 0.8695418238639832, 0.17170941829681396, 0.500616729259491, 0.353726327419281, 0.6726185083389282, 0.5936906337738037]]}
|
||||
14
examples/onnx/exp/network.onnx
Normal file
14
examples/onnx/exp/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Exp"Exp
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
1
examples/onnx/fr_age/input.json
Normal file
1
examples/onnx/fr_age/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/fr_age/network.onnx
Normal file
BIN
examples/onnx/fr_age/network.onnx
Normal file
Binary file not shown.
41
examples/onnx/general_exp/gen.py
Normal file
41
examples/onnx/general_exp/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = 10**x
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/general_exp/input.json
Normal file
1
examples/onnx/general_exp/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.9837989807128906, 0.026381194591522217, 0.3403851389884949, 0.14531707763671875, 0.24652725458145142, 0.7945117354393005, 0.4076554775238037, 0.23064672946929932]]}
|
||||
BIN
examples/onnx/general_exp/network.onnx
Normal file
BIN
examples/onnx/general_exp/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/log/gen.py
Normal file
42
examples/onnx/log/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.log(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 3)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/log/input.json
Normal file
1
examples/onnx/log/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
|
||||
14
examples/onnx/log/network.onnx
Normal file
14
examples/onnx/log/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Log"Log
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -21,9 +21,9 @@ def main():
|
||||
torch_model = Circuit()
|
||||
# Input to the model
|
||||
shape = [3, 2, 3]
|
||||
w = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
x = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
y = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
torch_out = torch_model(w, x, y)
|
||||
# Export the model
|
||||
torch.onnx.export(torch_model, # model being run
|
||||
|
||||
@@ -1 +1,148 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
|
||||
{
|
||||
"input_shapes": [
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
]
|
||||
],
|
||||
"input_data": [
|
||||
[
|
||||
0.5,
|
||||
1.5,
|
||||
-0.04514765739440918,
|
||||
0.5936200618743896,
|
||||
0.9271858930587769,
|
||||
0.6688600778579712,
|
||||
-0.20331168174743652,
|
||||
-0.7016235589981079,
|
||||
0.025863051414489746,
|
||||
-0.19426143169403076,
|
||||
0.9827852249145508,
|
||||
0.4897397756576538,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.9278832674026489,
|
||||
0.5943725109100342,
|
||||
-0.573331356048584,
|
||||
0.3675816059112549
|
||||
],
|
||||
[
|
||||
0.7803324460983276,
|
||||
-0.9616303443908691,
|
||||
0.6070173978805542,
|
||||
-0.028337717056274414,
|
||||
-0.5080242156982422,
|
||||
-0.9280107021331787,
|
||||
0.6150380373001099,
|
||||
0.3865993022918701,
|
||||
-0.43668973445892334,
|
||||
0.17152702808380127,
|
||||
0.5144252777099609,
|
||||
-0.28881049156188965,
|
||||
0.8932310342788696,
|
||||
0.059034109115600586,
|
||||
0.6865451335906982,
|
||||
0.009820222854614258,
|
||||
0.23011493682861328,
|
||||
-0.9492779970169067
|
||||
],
|
||||
[
|
||||
-0.21352827548980713,
|
||||
-0.16015326976776123,
|
||||
-0.38964390754699707,
|
||||
0.13464701175689697,
|
||||
-0.8814496994018555,
|
||||
0.5037975311279297,
|
||||
-0.804405927658081,
|
||||
0.9858957529067993,
|
||||
0.19567716121673584,
|
||||
0.9777265787124634,
|
||||
0.6151977777481079,
|
||||
0.568595290184021,
|
||||
0.10584986209869385,
|
||||
-0.8975653648376465,
|
||||
0.6235959529876709,
|
||||
-0.547879695892334,
|
||||
0.9289869070053101,
|
||||
0.7567293643951416
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.0
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0
|
||||
],
|
||||
[
|
||||
-0.0,
|
||||
-0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
]
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
pytorch2.0.1:â
|
||||
pytorch2.2.2:ă
|
||||
|
||||
woutput_w/Round"Round
|
||||
|
||||
xoutput_x/Floor"Floor
|
||||
|
||||
youtput_y/Ceil"Ceil torch_jitZ%
|
||||
youtput_y/Ceil"Ceil
|
||||
main_graphZ%
|
||||
w
|
||||
|
||||
|
||||
|
||||
42
examples/onnx/rsqrt/gen.py
Normal file
42
examples/onnx/rsqrt/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# reciprocal sqrt
|
||||
m = 1 / torch.sqrt(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/rsqrt/input.json
Normal file
1
examples/onnx/rsqrt/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}
|
||||
17
examples/onnx/rsqrt/network.onnx
Normal file
17
examples/onnx/rsqrt/network.onnx
Normal file
@@ -0,0 +1,17 @@
|
||||
pytorch2.2.2:Ź
|
||||
$
|
||||
input/Sqrt_output_0/Sqrt"Sqrt
|
||||
1
|
||||
/Sqrt_output_0output/Reciprocal"
|
||||
Reciprocal
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -1,75 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
from torch import nn
|
||||
import json
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Just one Linear layer
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Use this line if you want to visualize the weights
|
||||
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
||||
self.channels = configs.enc_in
|
||||
self.individual = configs.individual
|
||||
if self.individual:
|
||||
self.Linear = nn.ModuleList()
|
||||
for i in range(self.channels):
|
||||
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
|
||||
else:
|
||||
self.Linear = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [Batch, Input length, Channel]
|
||||
if self.individual:
|
||||
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
|
||||
for i in range(self.channels):
|
||||
output[:,:,i] = self.Linear[i](x[:,:,i])
|
||||
x = output
|
||||
else:
|
||||
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
|
||||
return x # [Batch, Output length, Channel]
|
||||
|
||||
class Configs:
|
||||
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
|
||||
self.seq_len = seq_len
|
||||
self.pred_len = pred_len
|
||||
self.enc_in = enc_in
|
||||
self.individual = individual
|
||||
|
||||
model = 'Linear'
|
||||
seq_len = 10
|
||||
pred_len = 4
|
||||
enc_in = 3
|
||||
|
||||
configs = Configs(seq_len, pred_len, enc_in, True)
|
||||
circuit = Model(configs)
|
||||
|
||||
x = torch.randn(1, seq_len, pred_len)
|
||||
import numpy as np
|
||||
import tf2onnx
|
||||
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=15, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
# the model's input names
|
||||
input_names=['input'],
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
# gather_nd in tf then export to onnx
|
||||
x = in1 = Input((4, 1), dtype=tf.int32)
|
||||
w = in2 = Input((4, ), dtype=tf.int32)
|
||||
|
||||
class MyLayer(Layer):
|
||||
def call(self, x, w):
|
||||
shape = tf.constant([8])
|
||||
return tf.scatter_nd(x, w, shape)
|
||||
|
||||
x = MyLayer()(x, w)
|
||||
|
||||
|
||||
|
||||
tm = Model((in1, in2), x)
|
||||
tm.summary()
|
||||
tm.compile(optimizer='adam', loss='mse')
|
||||
|
||||
shape = [1, 4, 1]
|
||||
index_shape = [1, 4]
|
||||
# After training, export to onnx (network.onnx) and create a data file (input.json)
|
||||
x = np.random.randint(0, 4, shape)
|
||||
# w = random int tensor
|
||||
w = np.random.randint(0, 4, index_shape)
|
||||
|
||||
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
|
||||
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
|
||||
|
||||
model_path = "network.onnx"
|
||||
|
||||
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
|
||||
|
||||
|
||||
d = x.reshape([-1]).tolist()
|
||||
d1 = w.reshape([-1]).tolist()
|
||||
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
input_data=[d, d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
|
||||
@@ -1 +1,16 @@
|
||||
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
1,
|
||||
0,
|
||||
2,
|
||||
1
|
||||
]
|
||||
]
|
||||
}
|
||||
Binary file not shown.
851
ezkl.pyi
Normal file
851
ezkl.pyi
Normal file
@@ -0,0 +1,851 @@
|
||||
# This file is automatically generated by pyo3_stub_gen
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import typing
|
||||
from enum import Enum, auto
|
||||
|
||||
class PyG1:
|
||||
r"""
|
||||
pyclass containing the struct used for G1, this is mostly a helper class
|
||||
"""
|
||||
...
|
||||
|
||||
class PyG1Affine:
|
||||
r"""
|
||||
pyclass containing the struct used for G1
|
||||
"""
|
||||
...
|
||||
|
||||
class PyRunArgs:
|
||||
r"""
|
||||
Python class containing the struct used for run_args
|
||||
|
||||
Returns
|
||||
-------
|
||||
PyRunArgs
|
||||
"""
|
||||
...
|
||||
|
||||
class PyCommitments(Enum):
|
||||
r"""
|
||||
pyclass representing an enum, denoting the type of commitment
|
||||
"""
|
||||
KZG = auto()
|
||||
IPA = auto()
|
||||
|
||||
class PyInputType(Enum):
|
||||
Bool = auto()
|
||||
F16 = auto()
|
||||
F32 = auto()
|
||||
F64 = auto()
|
||||
Int = auto()
|
||||
TDim = auto()
|
||||
|
||||
class PyTestDataSource(Enum):
|
||||
r"""
|
||||
pyclass representing an enum
|
||||
"""
|
||||
File = auto()
|
||||
OnChain = auto()
|
||||
|
||||
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
|
||||
r"""
|
||||
Creates an aggregated proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_snarks: list[str]
|
||||
List of paths to the various proofs
|
||||
|
||||
proof_path: str
|
||||
Path to output the aggregated proof
|
||||
|
||||
vk_path: str
|
||||
Path to the VK file
|
||||
|
||||
transcript:
|
||||
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
|
||||
|
||||
logrows:
|
||||
Logrows used for aggregation circuit
|
||||
|
||||
check_mode: str
|
||||
Run sanity checks during calculations. Accepts `safe` or `unsafe`
|
||||
|
||||
split-proofs: bool
|
||||
Whether the accumulated proofs are segments of a larger circuit
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS used
|
||||
|
||||
commitment: str
|
||||
Accepts "kzg" or "ipa"
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
|
||||
r"""
|
||||
Converts a buffer to vector of field elements
|
||||
|
||||
Arguments
|
||||
-------
|
||||
buffer: list[int]
|
||||
List of integers representing a buffer
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
List of field elements represented as strings
|
||||
"""
|
||||
...
|
||||
|
||||
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
|
||||
r"""
|
||||
Calibrates the circuit settings
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data: str
|
||||
Path to the calibration data
|
||||
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
settings: str
|
||||
Path to the settings file
|
||||
|
||||
lookup_safety_margin: int
|
||||
the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
|
||||
|
||||
scales: list[int]
|
||||
Optional scales to specifically try for calibration
|
||||
|
||||
scale_rebase_multiplier: list[int]
|
||||
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
|
||||
|
||||
max_logrows: int
|
||||
Optional max logrows to use for calibration
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Compiles the circuit for use in other steps
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx model file
|
||||
|
||||
compiled_circuit: str
|
||||
Path to output the compiled circuit
|
||||
|
||||
settings_path: str
|
||||
Path to the settings files
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
input_data: str
|
||||
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifier
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
|
||||
r"""
|
||||
Creates an EVM compatible verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifier
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
reusable: bool
|
||||
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
|
||||
r"""
|
||||
Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_settings: str
|
||||
path to the settings file
|
||||
|
||||
vk_path: str
|
||||
The path to load the desired verification key file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the Solidity code
|
||||
|
||||
abi_path: str
|
||||
The path to output the Solidity verifier ABI
|
||||
|
||||
logrows: int
|
||||
Number of logrows used during aggregated setup
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
reusable: bool
|
||||
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
|
||||
This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifying key.
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
deploys the solidity da verifier
|
||||
"""
|
||||
...
|
||||
|
||||
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
deploys the solidity verifier
|
||||
"""
|
||||
...
|
||||
|
||||
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
|
||||
r"""
|
||||
Creates encoded evm calldata from a proof file
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof: str
|
||||
Path to the proof file
|
||||
|
||||
calldata: str
|
||||
Path to the calldata file to save
|
||||
|
||||
addr_vk: str
|
||||
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
|
||||
|
||||
Returns
|
||||
-------
|
||||
vec[u8]
|
||||
The encoded calldata
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_big_endian(felt:str) -> str:
|
||||
r"""
|
||||
Converts a field element hex string to big endian
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
field element represented as a string
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_float(felt:str,scale:int) -> float:
|
||||
r"""
|
||||
Converts a field element hex string to a floating point number
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
scale: float
|
||||
The scaling factor used to convert the field element into a floating point representation
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_int(felt:str) -> int:
|
||||
r"""
|
||||
Converts a field element hex string to an integer
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
...
|
||||
|
||||
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
|
||||
r"""
|
||||
Converts a floating point element to a field element hex string
|
||||
|
||||
Arguments
|
||||
-------
|
||||
input: float
|
||||
The field element represented as a string
|
||||
|
||||
scale: float
|
||||
The scaling factor used to quantize the float into a field element
|
||||
|
||||
input_type: PyInputType
|
||||
The type of the input
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The field element represented as a string
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
|
||||
r"""
|
||||
Generates the circuit settings
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
output: str
|
||||
Path to create the settings file
|
||||
|
||||
py_run_args: PyRunArgs
|
||||
PyRunArgs object to initialize the settings
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
|
||||
r"""
|
||||
Generates the Structured Reference String (SRS), use this only for testing purposes
|
||||
|
||||
Arguments
|
||||
---------
|
||||
srs_path: str
|
||||
Path to the create the SRS file
|
||||
|
||||
logrows: int
|
||||
The number of logrows for the SRS file
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Generates a vk from a pk for an aggregate circuit and saves it to a file
|
||||
|
||||
Arguments
|
||||
-------
|
||||
path_to_pk: str
|
||||
Path to the proving key
|
||||
|
||||
vk_output_path: str
|
||||
Path to create the vk file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Generates a vk from a pk for a model circuit and saves it to a file
|
||||
|
||||
Arguments
|
||||
-------
|
||||
path_to_pk: str
|
||||
Path to the proving key
|
||||
|
||||
circuit_settings_path: str
|
||||
Path to the witness file
|
||||
|
||||
vk_output_path: str
|
||||
Path to create the vk file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Runs the forward pass operation to generate a witness
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data: str
|
||||
Path to the data file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
output: str
|
||||
Path to create the witness file
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Python object containing the witness values
|
||||
"""
|
||||
...
|
||||
|
||||
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
|
||||
r"""
|
||||
Gets a public srs
|
||||
|
||||
Arguments
|
||||
---------
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
logrows: int
|
||||
The number of logrows for the SRS file
|
||||
|
||||
srs_path: str
|
||||
Path to the create the SRS file
|
||||
|
||||
commitment: str
|
||||
Specify the commitment used ("kzg", "ipa")
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
|
||||
r"""
|
||||
Generate an ipa commitment.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represnted as strings
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
srs_path: str
|
||||
Path to the Structure Reference String (SRS) file
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PyG1Affine]
|
||||
"""
|
||||
...
|
||||
|
||||
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
|
||||
r"""
|
||||
Generate a kzg commitment.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represnted as strings
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
srs_path: str
|
||||
Path to the Structure Reference String (SRS) file
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PyG1Affine]
|
||||
"""
|
||||
...
|
||||
|
||||
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Mocks the prover
|
||||
|
||||
Arguments
|
||||
---------
|
||||
witness: str
|
||||
Path to the witness file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
|
||||
r"""
|
||||
Mocks the aggregate prover
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_snarks: list[str]
|
||||
List of paths to the relevant proof files
|
||||
|
||||
logrows: int
|
||||
Number of logrows to use for the aggregation circuit
|
||||
|
||||
split_proofs: bool
|
||||
Indicates whether the accumulated are segments of a larger proof
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
|
||||
r"""
|
||||
Generate a poseidon hash.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represented as strings
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
List of field elements represented as strings
|
||||
"""
|
||||
...
|
||||
|
||||
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Runs the prover on a set of inputs
|
||||
|
||||
Arguments
|
||||
---------
|
||||
witness: str
|
||||
Path to the witness file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
pk_path: str
|
||||
Path to the proving key file
|
||||
|
||||
proof_path: str
|
||||
Path to create the proof file
|
||||
|
||||
proof_type: str
|
||||
Accepts `single`, `for-aggr`
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
|
||||
r"""
|
||||
Runs the setup process
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
vk_path: str
|
||||
Path to create the verification key file
|
||||
|
||||
pk_path: str
|
||||
Path to create the proving key file
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
witness_path: str
|
||||
Path to the witness file
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress the selectors or not
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
|
||||
r"""
|
||||
Runs the setup process for an aggregate setup
|
||||
|
||||
Arguments
|
||||
---------
|
||||
sample_snarks: list[str]
|
||||
List of paths to the various proofs
|
||||
|
||||
vk_path: str
|
||||
Path to create the aggregated VK
|
||||
|
||||
pk_path: str
|
||||
Path to create the aggregated PK
|
||||
|
||||
logrows: int
|
||||
Number of logrows to use
|
||||
|
||||
split_proofs: bool
|
||||
Whether the accumulated are segments of a larger proof
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress selectors
|
||||
|
||||
commitment: str
|
||||
Accepts `kzg`, `ipa`
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
Setup test evm witness
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data_path: str
|
||||
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
|
||||
compiled_circuit_path: str
|
||||
The path to the compiled model file (generated using the compile-circuit command)
|
||||
|
||||
test_data: str
|
||||
For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
|
||||
input_sources: str
|
||||
Where the input data comes from
|
||||
|
||||
output_source: str
|
||||
Where the output data comes from
|
||||
|
||||
rpc_url: str
|
||||
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
|
||||
r"""
|
||||
Swap the commitments in a proof
|
||||
|
||||
Arguments
|
||||
-------
|
||||
proof_path: str
|
||||
Path to the proof file
|
||||
|
||||
witness_path: str
|
||||
Path to the witness file
|
||||
"""
|
||||
...
|
||||
|
||||
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
|
||||
r"""
|
||||
Displays the table as a string in python
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
Returns
|
||||
---------
|
||||
str
|
||||
Table of the nodes in the onnx file
|
||||
"""
|
||||
...
|
||||
|
||||
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
|
||||
r"""
|
||||
Verifies a given proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof_path: str
|
||||
Path to create the proof file
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key file
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
non_reduced_srs: bool
|
||||
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
|
||||
r"""
|
||||
Verifies and aggregate proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof_path: str
|
||||
The path to the proof file
|
||||
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
logrows: int
|
||||
logrows used for aggregation circuit
|
||||
|
||||
commitment: str
|
||||
Accepts "kzg" or "ipa"
|
||||
|
||||
reduced_srs: bool
|
||||
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
verifies an evm compatible proof, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
addr_verifier: str
|
||||
The verifier contract's address as a hex string
|
||||
|
||||
proof_path: str
|
||||
The path to the proof file (generated using the prove command)
|
||||
|
||||
rpc_url: str
|
||||
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
|
||||
addr_da: str
|
||||
does the verifier use data attestation ?
|
||||
|
||||
addr_vk: str
|
||||
The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -9,7 +9,6 @@ import { EVM } from '@ethereumjs/evm'
|
||||
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
|
||||
import { getAccountNonce, insertAccount } from './utils/account-utils'
|
||||
import { encodeVerifierCalldata } from '../nodejs/ezkl';
|
||||
import { error } from 'console'
|
||||
|
||||
async function deployContract(
|
||||
vm: VM,
|
||||
@@ -66,7 +65,7 @@ async function verify(
|
||||
vkAddress = new Uint8Array(uint8Array.buffer);
|
||||
|
||||
// convert uitn8array of length
|
||||
error('vkAddress', vkAddress)
|
||||
console.error('vkAddress', vkAddress)
|
||||
}
|
||||
const data = encodeVerifierCalldata(proof, vkAddress)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ asyncio_mode = "auto"
|
||||
|
||||
[project]
|
||||
name = "ezkl"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.7"
|
||||
classifiers = [
|
||||
"Programming Language :: Rust",
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
// ignore file if compiling for wasm
|
||||
#[global_allocator]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::{CommandFactory, Parser};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use colored_json::ToColoredJson;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use ezkl::commands::Cli;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use ezkl::execute::run;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use ezkl::logger::init_logger;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use log::{error, info};
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
use rand::prelude::SliceRandom;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[cfg(feature = "icicle")]
|
||||
use std::env;
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub async fn main() {
|
||||
let args = Cli::parse();
|
||||
|
||||
@@ -59,7 +59,7 @@ pub async fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
pub fn main() {}
|
||||
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
|
||||
269
src/bin/ios_gen_bindings.rs
Normal file
269
src/bin/ios_gen_bindings.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
use camino::Utf8Path;
|
||||
use std::fs;
|
||||
use std::fs::remove_dir_all;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use uniffi_bindgen::bindings::SwiftBindingGenerator;
|
||||
use uniffi_bindgen::library_mode::generate_bindings;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn main() {
|
||||
let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set");
|
||||
let mode = determine_build_mode();
|
||||
build_bindings(&library_name, mode);
|
||||
}
|
||||
|
||||
/// Determines the build mode based on the CONFIGURATION environment variable.
|
||||
/// Defaults to "release" if not set or unrecognized.
|
||||
/// "release" mode takes longer to build but produces optimized code, which has smaller size and is faster.
|
||||
fn determine_build_mode() -> &'static str {
|
||||
match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) {
|
||||
Ok(ref config) if config == "debug" => "debug",
|
||||
_ => "release",
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the Swift bindings and XCFramework for the specified library and build mode.
|
||||
fn build_bindings(library_name: &str, mode: &str) {
|
||||
// Get the root directory of this Cargo project
|
||||
let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| std::env::current_dir().unwrap());
|
||||
|
||||
// Define the build directory inside the manifest directory
|
||||
let build_dir = manifest_dir.join("build");
|
||||
|
||||
// Create a temporary directory to store the bindings and combined library
|
||||
let tmp_dir = mktemp_local(&build_dir);
|
||||
|
||||
// Define directories for Swift bindings and output bindings
|
||||
let swift_bindings_dir = tmp_dir.join("SwiftBindings");
|
||||
let bindings_out = create_bindings_out_dir(&tmp_dir);
|
||||
let framework_out = bindings_out.join("EzklCore.xcframework");
|
||||
|
||||
// Define target architectures for building
|
||||
// We currently only support iOS devices and simulators running on ARM Macs
|
||||
// This is due to limiting the library size to under 100MB for GitHub Commit Size Limit
|
||||
// To support older Macs (Intel), follow the instructions in the comments below
|
||||
#[allow(clippy::useless_vec)]
|
||||
let target_archs = vec![
|
||||
vec!["aarch64-apple-ios"], // iOS device
|
||||
vec!["aarch64-apple-ios-sim"], // iOS simulator ARM Mac
|
||||
// vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older Macs (Intel)
|
||||
];
|
||||
|
||||
// Build the library for each architecture and combine them
|
||||
let out_lib_paths: Vec<PathBuf> = target_archs
|
||||
.iter()
|
||||
.map(|archs| build_combined_archs(library_name, archs, &build_dir, mode))
|
||||
.collect();
|
||||
|
||||
// Generate the path to the built dynamic library (.dylib)
|
||||
let out_dylib_path = build_dir.join(format!(
|
||||
"{}/{}/lib{}.dylib",
|
||||
target_archs[0][0], mode, library_name
|
||||
));
|
||||
|
||||
// Generate Swift bindings using uniffi_bindgen
|
||||
generate_ios_bindings(&out_dylib_path, &swift_bindings_dir)
|
||||
.expect("Failed to generate iOS bindings");
|
||||
|
||||
// Move the generated Swift file to the bindings output directory
|
||||
fs::rename(
|
||||
swift_bindings_dir.join(format!("{}.swift", library_name)),
|
||||
bindings_out.join("EzklCore.swift"),
|
||||
)
|
||||
.expect("Failed to copy swift bindings file");
|
||||
|
||||
// Rename the `ios_ezklFFI.modulemap` file to `module.modulemap`
|
||||
fs::rename(
|
||||
swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)),
|
||||
swift_bindings_dir.join("module.modulemap"),
|
||||
)
|
||||
.expect("Failed to rename modulemap file");
|
||||
|
||||
// Create the XCFramework from the combined libraries and Swift bindings
|
||||
create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out);
|
||||
|
||||
// Define the destination directory for the bindings
|
||||
let bindings_dest = build_dir.join("EzklCoreBindings");
|
||||
if bindings_dest.exists() {
|
||||
fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory");
|
||||
}
|
||||
|
||||
// Move the bindings output to the destination directory
|
||||
fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place");
|
||||
|
||||
// Clean up temporary directories
|
||||
cleanup_temp_dirs(&build_dir);
|
||||
}
|
||||
|
||||
/// Creates the output directory for the bindings.
|
||||
/// Returns the path to the bindings output directory.
|
||||
fn create_bindings_out_dir(base_dir: &Path) -> PathBuf {
|
||||
let bindings_out = base_dir.join("EzklCoreBindings");
|
||||
fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory");
|
||||
bindings_out
|
||||
}
|
||||
|
||||
/// Builds the library for each architecture and combines them into a single library using lipo.
|
||||
/// Returns the path to the combined library.
|
||||
fn build_combined_archs(
|
||||
library_name: &str,
|
||||
archs: &[&str],
|
||||
build_dir: &Path,
|
||||
mode: &str,
|
||||
) -> PathBuf {
|
||||
// Build the library for each architecture
|
||||
let out_lib_paths: Vec<PathBuf> = archs
|
||||
.iter()
|
||||
.map(|&arch| {
|
||||
build_for_arch(arch, build_dir, mode);
|
||||
build_dir
|
||||
.join(arch)
|
||||
.join(mode)
|
||||
.join(format!("lib{}.a", library_name))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create a unique temporary directory for the combined library
|
||||
let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name));
|
||||
|
||||
// Combine the libraries using lipo
|
||||
let mut lipo_cmd = Command::new("lipo");
|
||||
lipo_cmd
|
||||
.arg("-create")
|
||||
.arg("-output")
|
||||
.arg(lib_out.to_str().unwrap());
|
||||
for lib_path in &out_lib_paths {
|
||||
lipo_cmd.arg(lib_path.to_str().unwrap());
|
||||
}
|
||||
|
||||
let status = lipo_cmd.status().expect("Failed to run lipo command");
|
||||
if !status.success() {
|
||||
panic!("lipo command failed with status: {}", status);
|
||||
}
|
||||
|
||||
lib_out
|
||||
}
|
||||
|
||||
/// Builds the library for a specific architecture.
|
||||
fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) {
|
||||
// Ensure the target architecture is installed
|
||||
install_arch(arch);
|
||||
|
||||
// Run cargo build for the specified architecture and mode
|
||||
let mut build_cmd = Command::new("cargo");
|
||||
build_cmd
|
||||
.arg("build")
|
||||
.arg("--no-default-features")
|
||||
.arg("--features")
|
||||
.arg("ios-bindings");
|
||||
|
||||
if mode == "release" {
|
||||
build_cmd.arg("--release");
|
||||
}
|
||||
build_cmd
|
||||
.arg("--lib")
|
||||
.env("CARGO_BUILD_TARGET_DIR", build_dir)
|
||||
.env("CARGO_BUILD_TARGET", arch);
|
||||
|
||||
let status = build_cmd.status().expect("Failed to run cargo build");
|
||||
if !status.success() {
|
||||
panic!("cargo build failed for architecture: {}", arch);
|
||||
}
|
||||
}
|
||||
|
||||
/// Installs the specified target architecture using rustup.
|
||||
fn install_arch(arch: &str) {
|
||||
let status = Command::new("rustup")
|
||||
.arg("target")
|
||||
.arg("add")
|
||||
.arg(arch)
|
||||
.status()
|
||||
.expect("Failed to run rustup command");
|
||||
|
||||
if !status.success() {
|
||||
panic!("Failed to install target architecture: {}", arch);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates Swift bindings for the iOS library using uniffi_bindgen.
|
||||
fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> {
|
||||
// Remove existing binding directory if it exists
|
||||
if binding_dir.exists() {
|
||||
remove_dir_all(binding_dir)?;
|
||||
}
|
||||
|
||||
// Generate the Swift bindings using uniffi_bindgen
|
||||
generate_bindings(
|
||||
Utf8Path::from_path(dylib_path).ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path")
|
||||
})?,
|
||||
None,
|
||||
&SwiftBindingGenerator,
|
||||
None,
|
||||
Utf8Path::from_path(binding_dir).ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
"Invalid Swift bindings directory",
|
||||
)
|
||||
})?,
|
||||
true,
|
||||
)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates an XCFramework from the combined libraries and Swift bindings.
|
||||
fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) {
|
||||
let mut xcbuild_cmd = Command::new("xcodebuild");
|
||||
xcbuild_cmd.arg("-create-xcframework");
|
||||
|
||||
// Add each library and its corresponding headers to the xcodebuild command
|
||||
for lib_path in lib_paths {
|
||||
println!("Including library: {:?}", lib_path);
|
||||
xcbuild_cmd.arg("-library");
|
||||
xcbuild_cmd.arg(lib_path.to_str().unwrap());
|
||||
xcbuild_cmd.arg("-headers");
|
||||
xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap());
|
||||
}
|
||||
|
||||
xcbuild_cmd.arg("-output");
|
||||
xcbuild_cmd.arg(framework_out.to_str().unwrap());
|
||||
|
||||
let status = xcbuild_cmd.status().expect("Failed to run xcodebuild");
|
||||
if !status.success() {
|
||||
panic!("xcodebuild failed with status: {}", status);
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a temporary directory inside the build path with a unique UUID.
|
||||
/// This ensures unique build artifacts for concurrent builds.
|
||||
fn mktemp_local(build_path: &Path) -> PathBuf {
|
||||
let dir = tmp_local(build_path).join(Uuid::new_v4().to_string());
|
||||
fs::create_dir(&dir).expect("Failed to create temporary directory");
|
||||
dir
|
||||
}
|
||||
|
||||
/// Gets the path to the local temporary directory inside the build path.
|
||||
fn tmp_local(build_path: &Path) -> PathBuf {
|
||||
let tmp_path = build_path.join("tmp");
|
||||
if let Ok(metadata) = fs::metadata(&tmp_path) {
|
||||
if !metadata.is_dir() {
|
||||
panic!("Expected 'tmp' to be a directory");
|
||||
}
|
||||
} else {
|
||||
fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory");
|
||||
}
|
||||
tmp_path
|
||||
}
|
||||
|
||||
/// Cleans up temporary directories inside the build path.
|
||||
fn cleanup_temp_dirs(build_dir: &Path) {
|
||||
let tmp_dir = build_dir.join("tmp");
|
||||
if tmp_dir.exists() {
|
||||
fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories");
|
||||
}
|
||||
}
|
||||
9
src/bin/py_stub_gen.rs
Normal file
9
src/bin/py_stub_gen.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use pyo3_stub_gen::Result;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// `stub_info` is a function defined by `define_stub_info_gatherer!` macro.
|
||||
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
|
||||
let stub = ezkl::bindings::python::stub_info()?;
|
||||
stub.generate()?;
|
||||
Ok(())
|
||||
}
|
||||
12
src/bindings/mod.rs
Normal file
12
src/bindings/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
/// Python bindings
|
||||
#[cfg(feature = "python-bindings")]
|
||||
pub mod python;
|
||||
/// Universal bindings for all platforms
|
||||
#[cfg(any(
|
||||
feature = "ios-bindings",
|
||||
all(target_arch = "wasm32", target_os = "unknown")
|
||||
))]
|
||||
pub mod universal;
|
||||
/// wasm prover and verifier
|
||||
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
|
||||
pub mod wasm;
|
||||
@@ -4,6 +4,7 @@ use crate::circuit::modules::poseidon::{
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
@@ -26,7 +27,12 @@ use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use pyo3_stub_gen::{
|
||||
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction, TypeInfo,
|
||||
};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::collections::HashSet;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
@@ -35,6 +41,7 @@ type PyFelt = String;
|
||||
/// pyclass representing an enum
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
enum PyTestDataSource {
|
||||
/// The data is loaded from a file
|
||||
File,
|
||||
@@ -54,6 +61,7 @@ impl From<PyTestDataSource> for TestDataSource {
|
||||
/// pyclass containing the struct used for G1, this is mostly a helper class
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
struct PyG1 {
|
||||
#[pyo3(get, set)]
|
||||
/// Field Element representing x
|
||||
@@ -100,6 +108,7 @@ impl pyo3::ToPyObject for PyG1 {
|
||||
/// pyclass containing the struct used for G1
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
pub struct PyG1Affine {
|
||||
#[pyo3(get, set)]
|
||||
///
|
||||
@@ -145,6 +154,7 @@ impl pyo3::ToPyObject for PyG1Affine {
|
||||
///
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// float: The tolerance for error on model outputs
|
||||
@@ -180,9 +190,6 @@ struct PyRunArgs {
|
||||
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Should constants with 0.0 fraction be rebased to scale 0
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
@@ -191,6 +198,15 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// str: commitment type, accepts `kzg`, `ipa`
|
||||
pub commitment: PyCommitments,
|
||||
/// int: The base used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_base: usize,
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -206,6 +222,7 @@ impl PyRunArgs {
|
||||
impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
@@ -217,10 +234,11 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
decomp_base: py_run_args.decomp_base,
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,6 +246,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
@@ -239,16 +258,18 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
decomp_base: self.decomp_base,
|
||||
decomp_legs: self.decomp_legs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
/// pyclass representing an enum, denoting the type of commitment
|
||||
pub enum PyCommitments {
|
||||
/// KZG commitment
|
||||
@@ -296,6 +317,65 @@ impl FromStr for PyCommitments {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
enum PyInputType {
|
||||
///
|
||||
Bool,
|
||||
///
|
||||
F16,
|
||||
///
|
||||
F32,
|
||||
///
|
||||
F64,
|
||||
///
|
||||
Int,
|
||||
///
|
||||
TDim,
|
||||
}
|
||||
|
||||
impl From<InputType> for PyInputType {
|
||||
fn from(input_type: InputType) -> Self {
|
||||
match input_type {
|
||||
InputType::Bool => PyInputType::Bool,
|
||||
InputType::F16 => PyInputType::F16,
|
||||
InputType::F32 => PyInputType::F32,
|
||||
InputType::F64 => PyInputType::F64,
|
||||
InputType::Int => PyInputType::Int,
|
||||
InputType::TDim => PyInputType::TDim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyInputType> for InputType {
|
||||
fn from(py_input_type: PyInputType) -> Self {
|
||||
match py_input_type {
|
||||
PyInputType::Bool => InputType::Bool,
|
||||
PyInputType::F16 => InputType::F16,
|
||||
PyInputType::F32 => InputType::F32,
|
||||
PyInputType::F64 => InputType::F64,
|
||||
PyInputType::Int => InputType::Int,
|
||||
PyInputType::TDim => InputType::TDim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PyInputType {
|
||||
type Err = String;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"bool" => Ok(PyInputType::Bool),
|
||||
"f16" => Ok(PyInputType::F16),
|
||||
"f32" => Ok(PyInputType::F32),
|
||||
"f64" => Ok(PyInputType::F64),
|
||||
"int" => Ok(PyInputType::Int),
|
||||
"tdim" => Ok(PyInputType::TDim),
|
||||
_ => Err("Invalid value for InputType".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a field element hex string to big endian
|
||||
///
|
||||
/// Arguments
|
||||
@@ -312,6 +392,7 @@ impl FromStr for PyCommitments {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
Ok(format!("{:?}", felt))
|
||||
@@ -331,6 +412,7 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
@@ -355,6 +437,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
felt,
|
||||
scale
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
@@ -373,6 +456,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
/// scale: float
|
||||
/// The scaling factor used to quantize the float into a field element
|
||||
///
|
||||
/// input_type: PyInputType
|
||||
/// The type of the input
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
/// str
|
||||
@@ -380,9 +466,12 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
input,
|
||||
scale
|
||||
scale,
|
||||
input_type=PyInputType::F64
|
||||
))]
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
#[gen_stub_pyfunction]
|
||||
fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult<PyFelt> {
|
||||
InputType::roundtrip(&input_type.into(), &mut input);
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
@@ -404,6 +493,7 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
#[pyfunction(signature = (
|
||||
buffer
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
|
||||
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
|
||||
let mut n: u128 = 0;
|
||||
@@ -476,6 +566,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
|
||||
#[pyfunction(signature = (
|
||||
message,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
let message: Vec<Fr> = message
|
||||
.iter()
|
||||
@@ -521,6 +612,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn kzg_commit(
|
||||
message: Vec<PyFelt>,
|
||||
vk_path: PathBuf,
|
||||
@@ -579,6 +671,7 @@ fn kzg_commit(
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn ipa_commit(
|
||||
message: Vec<PyFelt>,
|
||||
vk_path: PathBuf,
|
||||
@@ -625,6 +718,7 @@ fn ipa_commit(
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF),
|
||||
witness_path=PathBuf::from(DEFAULT_WITNESS),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
|
||||
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
|
||||
@@ -654,6 +748,7 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
|
||||
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_vk_from_pk_single(
|
||||
path_to_pk: PathBuf,
|
||||
circuit_settings_path: PathBuf,
|
||||
@@ -691,6 +786,7 @@ fn gen_vk_from_pk_single(
|
||||
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
|
||||
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
|
||||
@@ -720,6 +816,7 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
|
||||
model = PathBuf::from(DEFAULT_MODEL),
|
||||
py_run_args = None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
|
||||
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
|
||||
let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?;
|
||||
@@ -745,6 +842,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
|
||||
srs_path,
|
||||
logrows,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
let params = ezkl_gen_srs::<KZGCommitmentScheme<Bn256>>(logrows as u32);
|
||||
save_params::<KZGCommitmentScheme<Bn256>>(&srs_path, ¶ms)?;
|
||||
@@ -777,6 +875,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
srs_path=None,
|
||||
commitment=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn get_srs(
|
||||
py: Python,
|
||||
settings_path: Option<PathBuf>,
|
||||
@@ -789,7 +888,7 @@ fn get_srs(
|
||||
None => None,
|
||||
};
|
||||
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -823,6 +922,7 @@ fn get_srs(
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
py_run_args = None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_settings(
|
||||
model: PathBuf,
|
||||
output: PathBuf,
|
||||
@@ -838,6 +938,45 @@ fn gen_settings(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Generates random data for the model
|
||||
///
|
||||
/// Arguments
|
||||
/// ---------
|
||||
/// model: str
|
||||
/// Path to the onnx file
|
||||
///
|
||||
/// output: str
|
||||
/// Path to create the data file
|
||||
///
|
||||
/// seed: int
|
||||
/// Random seed to use for generated data
|
||||
///
|
||||
/// variables
|
||||
/// Returns
|
||||
/// -------
|
||||
/// bool
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
model=PathBuf::from(DEFAULT_MODEL),
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
variables=Vec::from([("batch_size".to_string(), 1)]),
|
||||
seed=DEFAULT_SEED.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_random_data(
|
||||
model: PathBuf,
|
||||
output: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
|
||||
let err_str = format!("Failed to generate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Calibrates the circuit settings
|
||||
///
|
||||
/// Arguments
|
||||
@@ -863,8 +1002,6 @@ fn gen_settings(
|
||||
/// max_logrows: int
|
||||
/// Optional max logrows to use for calibration
|
||||
///
|
||||
/// only_range_check_rebase: bool
|
||||
/// Check ranges when rebasing
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -879,8 +1016,8 @@ fn gen_settings(
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
@@ -891,9 +1028,8 @@ fn calibrate_settings(
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -902,7 +1038,6 @@ fn calibrate_settings(
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -946,6 +1081,7 @@ fn calibrate_settings(
|
||||
vk_path=None,
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_witness(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
@@ -954,7 +1090,7 @@ fn gen_witness(
|
||||
vk_path: Option<PathBuf>,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -983,6 +1119,7 @@ fn gen_witness(
|
||||
witness=PathBuf::from(DEFAULT_WITNESS),
|
||||
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
|
||||
crate::execute::mock(model, witness).map_err(|e| {
|
||||
let err_str = format!("Failed to run mock: {}", e);
|
||||
@@ -1013,6 +1150,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
split_proofs = false,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn mock_aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
@@ -1060,6 +1198,7 @@ fn mock_aggregate(
|
||||
witness_path = None,
|
||||
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup(
|
||||
model: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
@@ -1118,6 +1257,7 @@ fn setup(
|
||||
proof_type=ProofType::default(),
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn prove(
|
||||
witness: PathBuf,
|
||||
model: PathBuf,
|
||||
@@ -1173,6 +1313,7 @@ fn prove(
|
||||
srs_path=None,
|
||||
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify(
|
||||
proof_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1232,6 +1373,7 @@ fn verify(
|
||||
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
@@ -1282,6 +1424,7 @@ fn setup_aggregate(
|
||||
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn compile_circuit(
|
||||
model: PathBuf,
|
||||
compiled_circuit: PathBuf,
|
||||
@@ -1341,6 +1484,7 @@ fn compile_circuit(
|
||||
srs_path=None,
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
proof_path: PathBuf,
|
||||
@@ -1406,6 +1550,7 @@ fn aggregate(
|
||||
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify_aggr(
|
||||
proof_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
@@ -1453,6 +1598,7 @@ fn verify_aggr(
|
||||
calldata=PathBuf::from(DEFAULT_CALLDATA),
|
||||
addr_vk=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn encode_evm_calldata<'a>(
|
||||
proof: PathBuf,
|
||||
calldata: PathBuf,
|
||||
@@ -1490,8 +1636,8 @@ fn encode_evm_calldata<'a>(
|
||||
/// srs_path: str
|
||||
/// The path to the SRS file
|
||||
///
|
||||
/// render_vk_separately: bool
|
||||
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
|
||||
/// reusable: bool
|
||||
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -1503,8 +1649,9 @@ fn encode_evm_calldata<'a>(
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
@@ -1512,16 +1659,16 @@ fn create_evm_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1533,7 +1680,8 @@ fn create_evm_verifier(
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an Evm verifer key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
|
||||
/// Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
|
||||
/// This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
|
||||
///
|
||||
/// Arguments
|
||||
/// ---------
|
||||
@@ -1563,7 +1711,8 @@ fn create_evm_verifier(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None
|
||||
))]
|
||||
fn create_evm_vk(
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_vka(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1571,8 +1720,8 @@ fn create_evm_vk(
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
@@ -1610,6 +1759,7 @@ fn create_evm_vk(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
|
||||
witness_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_data_attestation(
|
||||
py: Python,
|
||||
input_data: PathBuf,
|
||||
@@ -1618,7 +1768,7 @@ fn create_evm_data_attestation(
|
||||
abi_path: PathBuf,
|
||||
witness_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_data_attestation(
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -1670,6 +1820,7 @@ fn create_evm_data_attestation(
|
||||
output_source,
|
||||
rpc_url=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_test_evm_witness(
|
||||
py: Python,
|
||||
data_path: PathBuf,
|
||||
@@ -1679,7 +1830,7 @@ fn setup_test_evm_witness(
|
||||
output_source: PyTestDataSource,
|
||||
rpc_url: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::setup_test_evm_witness(
|
||||
data_path,
|
||||
compiled_circuit_path,
|
||||
@@ -1703,60 +1854,28 @@ fn setup_test_evm_witness(
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
rpc_url=None,
|
||||
contract_type=ContractType::default(),
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn deploy_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
contract_type: ContractType,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// deploys the solidity vk verifier
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_vk_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
contract_type,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1778,6 +1897,7 @@ fn deploy_vk_evm(
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn deploy_da_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
@@ -1788,7 +1908,7 @@ fn deploy_da_evm(
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_da_evm(
|
||||
input_data,
|
||||
settings_path,
|
||||
@@ -1836,6 +1956,7 @@ fn deploy_da_evm(
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify_evm<'a>(
|
||||
py: Python<'a>,
|
||||
addr_verifier: &'a str,
|
||||
@@ -1858,7 +1979,7 @@ fn verify_evm<'a>(
|
||||
None
|
||||
};
|
||||
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1892,8 +2013,8 @@ fn verify_evm<'a>(
|
||||
/// srs_path: str
|
||||
/// The path to the SRS file
|
||||
///
|
||||
/// render_vk_separately: bool
|
||||
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
|
||||
/// reusable: bool
|
||||
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -1906,8 +2027,9 @@ fn verify_evm<'a>(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
srs_path=None,
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier_aggr(
|
||||
py: Python,
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
@@ -1916,9 +2038,9 @@ fn create_evm_verifier_aggr(
|
||||
abi_path: PathBuf,
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -1926,7 +2048,7 @@ fn create_evm_verifier_aggr(
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1938,15 +2060,19 @@ fn create_evm_verifier_aggr(
|
||||
})
|
||||
}
|
||||
|
||||
// Define a function to gather stub information.
|
||||
define_stub_info_gatherer!(stub_info);
|
||||
|
||||
// Python Module
|
||||
#[pymodule]
|
||||
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
pyo3_log::init();
|
||||
m.add_class::<PyRunArgs>()?;
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add_class::<PyCommitments>()?;
|
||||
m.add_class::<PyInputType>()?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
|
||||
@@ -1968,6 +2094,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
|
||||
@@ -1975,9 +2102,8 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
|
||||
@@ -1986,3 +2112,48 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for CalibrationTarget {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for ProofType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for TranscriptType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for CheckMode {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for ContractType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
580
src/bindings/universal.rs
Normal file
580
src/bindings/universal.rs
Normal file
@@ -0,0 +1,580 @@
|
||||
use halo2_proofs::{
|
||||
plonk::*,
|
||||
poly::{
|
||||
commitment::{CommitmentScheme, ParamsProver},
|
||||
ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
multiopen::{ProverIPA, VerifierIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
},
|
||||
kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
},
|
||||
VerificationStrategy,
|
||||
},
|
||||
};
|
||||
use std::fmt::Display;
|
||||
use std::io::BufReader;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
circuit::region::RegionSettings,
|
||||
graph::GraphSettings,
|
||||
pfsys::{
|
||||
create_proof_circuit,
|
||||
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
|
||||
verify_proof_circuit, TranscriptType,
|
||||
},
|
||||
tensor::TensorType,
|
||||
CheckMode, Commitments, EZKLError as InnerEZKLError,
|
||||
};
|
||||
|
||||
use crate::graph::{GraphCircuit, GraphWitness};
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::{FromUniformBytes, PrimeField},
|
||||
};
|
||||
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
|
||||
|
||||
/// Wrapper around the Error Message
|
||||
#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))]
|
||||
#[derive(Debug)]
|
||||
pub enum EZKLError {
|
||||
/// Some Comment
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
impl Display for EZKLError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EZKLError::InternalError(e) => write!(f, "Internal error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InnerEZKLError> for EZKLError {
|
||||
fn from(e: InnerEZKLError) -> Self {
|
||||
EZKLError::InternalError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode verifier calldata from proof and ethereum vk_address
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn encode_verifier_calldata(
|
||||
// TODO - shuold it be pub(crate) or pub or pub(super)?
|
||||
proof: Vec<u8>,
|
||||
vk_address: Option<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let snark: crate::pfsys::Snark<Fr, G1Affine> =
|
||||
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address {
|
||||
let array: [u8; 20] =
|
||||
serde_json::from_slice(&vk_address[..]).map_err(InnerEZKLError::from)?;
|
||||
Some(array)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let flattened_instances = snark.instances.into_iter().flatten();
|
||||
|
||||
let encoded = encode_calldata(
|
||||
vk_address,
|
||||
&snark.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
Ok(encoded)
|
||||
}
|
||||
|
||||
/// Generate witness from compiled circuit and input json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
|
||||
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
.map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
|
||||
})?;
|
||||
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
|
||||
|
||||
let mut input = circuit
|
||||
.load_graph_input(&input)
|
||||
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(
|
||||
circuit.settings().run_args.decomp_base,
|
||||
circuit.settings().run_args.decomp_legs,
|
||||
),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
|
||||
}
|
||||
|
||||
/// Generate verifying key from compiled circuit, and parameters srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn gen_vk(
|
||||
compiled_circuit: Vec<u8>,
|
||||
srs: Vec<u8>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
|
||||
|
||||
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
|
||||
|
||||
let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
|
||||
&circuit,
|
||||
¶ms,
|
||||
compress_selectors,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
|
||||
|
||||
let mut serialized_vk = Vec::new();
|
||||
vk.write(
|
||||
&mut serialized_vk,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
|
||||
|
||||
Ok(serialized_vk)
|
||||
}
|
||||
|
||||
/// Generate proving key from vk, compiled circuit and parameters srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn gen_pk(
|
||||
vk: Vec<u8>,
|
||||
compiled_circuit: Vec<u8>,
|
||||
srs: Vec<u8>,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
|
||||
|
||||
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
|
||||
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
|
||||
let pk = create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, ¶ms)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?;
|
||||
|
||||
let mut serialized_pk = Vec::new();
|
||||
pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?;
|
||||
|
||||
Ok(serialized_pk)
|
||||
}
|
||||
|
||||
/// Verify proof with vk, proof json, circuit settings json and srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn verify(
|
||||
proof: Vec<u8>,
|
||||
vk: Vec<u8>,
|
||||
settings: Vec<u8>,
|
||||
srs: Vec<u8>,
|
||||
) -> Result<bool, EZKLError> {
|
||||
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?;
|
||||
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
|
||||
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings.clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let orig_n = 1 << circuit_settings.run_args.logrows;
|
||||
let commitment = circuit_settings.run_args.commitment.into();
|
||||
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let result = match commitment {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> = get_params(&mut reader)?;
|
||||
let strategy = IPASingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => Err(EZKLError::InternalError(format!(
|
||||
"Verification failed: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify aggregate proof with vk, proof, circuit settings and srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn verify_aggr(
|
||||
proof: Vec<u8>,
|
||||
vk: Vec<u8>,
|
||||
logrows: u64,
|
||||
srs: Vec<u8>,
|
||||
commitment: &str,
|
||||
) -> Result<bool, EZKLError> {
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
|
||||
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let commit = Commitments::from_str(commitment)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?;
|
||||
|
||||
let orig_n = 1 << logrows;
|
||||
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let result = match commit {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|
||||
|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)),
|
||||
)?;
|
||||
let strategy = IPASingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
result
|
||||
.map(|_| true)
|
||||
.map_err(|e| EZKLError::InternalError(format!("{}", e)))
|
||||
}
|
||||
|
||||
/// Prove in browser with compiled circuit, witness json, proving key, and srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn prove(
|
||||
witness: Vec<u8>,
|
||||
pk: Vec<u8>,
|
||||
compiled_circuit: Vec<u8>,
|
||||
srs: Vec<u8>,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
#[cfg(feature = "det-prove")]
|
||||
log::set_max_level(log::LevelFilter::Debug);
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
|
||||
let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
|
||||
|
||||
let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
|
||||
circuit
|
||||
.load_graph_witness(&data)
|
||||
.map_err(InnerEZKLError::from)?;
|
||||
let public_inputs = circuit
|
||||
.prepare_public_inputs(&data)
|
||||
.map_err(InnerEZKLError::from)?;
|
||||
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
|
||||
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let commitment = circuit.settings().run_args.commitment.into();
|
||||
|
||||
let proof = match commitment {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|
||||
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
|
||||
)?;
|
||||
|
||||
create_proof_circuit::<
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit,
|
||||
vec![public_inputs],
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
proof_split_commits,
|
||||
None,
|
||||
)
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|
||||
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
|
||||
)?;
|
||||
|
||||
create_proof_circuit::<
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
_,
|
||||
ProverIPA<_>,
|
||||
VerifierIPA<_>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit,
|
||||
vec![public_inputs],
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
Commitments::IPA,
|
||||
TranscriptType::EVM,
|
||||
proof_split_commits,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
.map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?)
|
||||
}
|
||||
|
||||
/// Validate the witness json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the compiled circuit
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the input json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let _: crate::graph::input::GraphData =
|
||||
serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the proof json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let _: crate::pfsys::Snark<Fr, G1Affine> =
|
||||
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the verifying key given the settings json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let circuit_settings: GraphSettings =
|
||||
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the proving key given the settings json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let circuit_settings: GraphSettings =
|
||||
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the settings json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Validate the srs
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub(crate) fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
|
||||
let mut reader = BufReader::new(&srs[..]);
|
||||
let _: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to deserialize params: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
|
||||
fn get_params<
|
||||
Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>,
|
||||
>(
|
||||
mut reader: &mut BufReader<&[u8]>,
|
||||
) -> Result<Scheme, EZKLError> {
|
||||
halo2_proofs::poly::commitment::Params::<G1Affine>::read(&mut reader)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)))
|
||||
}
|
||||
|
||||
/// Creates a [ProvingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
|
||||
pub fn create_vk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
compress_selectors: bool,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
|
||||
|
||||
// Initialize the verifying key
|
||||
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
|
||||
Ok(vk)
|
||||
}
|
||||
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
|
||||
pub fn create_pk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
vk: VerifyingKey<Scheme::Curve>,
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
|
||||
|
||||
// Initialize the proving key
|
||||
let pk = keygen_pk(params, vk, &empty_circuit)?;
|
||||
Ok(pk)
|
||||
}
|
||||
379
src/bindings/wasm.rs
Normal file
379
src/bindings/wasm.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use crate::{
|
||||
circuit::modules::{
|
||||
polycommit::PolyCommitChip,
|
||||
poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
},
|
||||
Module,
|
||||
},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
|
||||
graph::{
|
||||
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
|
||||
GraphSettings,
|
||||
},
|
||||
};
|
||||
use console_error_panic_hook;
|
||||
use halo2_proofs::{
|
||||
plonk::*,
|
||||
poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
};
|
||||
use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::PrimeField,
|
||||
};
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
|
||||
use crate::bindings::universal::{
|
||||
compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness,
|
||||
input_validation, pk_validation, proof_validation, settings_validation, srs_validation,
|
||||
verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError,
|
||||
};
|
||||
#[cfg(feature = "web")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
|
||||
impl From<ExternalEZKLError> for JsError {
|
||||
fn from(e: ExternalEZKLError) -> Self {
|
||||
JsError::new(&format!("{}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
/// Initialize logger for wasm
|
||||
pub fn init_logger() {
|
||||
log::set_logger(&DEFAULT_LOGGER).unwrap();
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
/// Initialize panic hook for wasm
|
||||
pub fn init_panic_hook() {
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
/// Wrapper around the halo2 encode call data method
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn encodeVerifierCalldata(
|
||||
proof: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk_address: Option<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Converts a hex string to a byte array
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(format!("{:?}", felt))
|
||||
}
|
||||
|
||||
/// Converts a felt to a little endian string
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
let repr = serde_json::to_string(&felt).unwrap();
|
||||
let b: String = serde_json::from_str(&repr).unwrap();
|
||||
Ok(b)
|
||||
}
|
||||
|
||||
/// Converts a hex string to a byte array
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn feltToInt(
|
||||
array: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&felt_to_integer_rep(felt))
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Converts felts to a floating point element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn feltToFloat(
|
||||
array: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
scale: crate::Scale,
|
||||
) -> Result<f64, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
Ok(int_rep as f64 / multiplier)
|
||||
}
|
||||
|
||||
/// Converts a floating point number to a hex string representing a fixed point field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn floatToFelt(
|
||||
mut input: f64,
|
||||
scale: crate::Scale,
|
||||
input_type: &str,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
crate::circuit::InputType::roundtrip(
|
||||
&crate::circuit::InputType::from_str(input_type)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?,
|
||||
&mut input,
|
||||
);
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Generate a kzg commitment.
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn kzgCommit(
|
||||
message: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let message: Vec<Fr> = serde_json::from_slice(&message[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(¶ms_ser[..]);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&vk[..]);
|
||||
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
¶ms,
|
||||
);
|
||||
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn bufferToVecOfFelt(
|
||||
buffer: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
// Convert the buffer to a slice
|
||||
let buffer: &[u8] = &buffer;
|
||||
|
||||
// Divide the buffer into chunks of 64 bytes
|
||||
let chunks = buffer.chunks_exact(16);
|
||||
|
||||
// Get the remainder
|
||||
let remainder = chunks.remainder();
|
||||
|
||||
// Add 0s to the remainder to make it 64 bytes
|
||||
let mut remainder = remainder.to_vec();
|
||||
|
||||
// Collect chunks into a Vec<[u8; 16]>.
|
||||
let chunks: Result<Vec<[u8; 16]>, JsError> = chunks
|
||||
.map(|slice| {
|
||||
let array: [u8; 16] = slice
|
||||
.try_into()
|
||||
.map_err(|_| JsError::new("failed to slice input chunks"))?;
|
||||
Ok(array)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut chunks = chunks?;
|
||||
|
||||
if remainder.len() != 0 {
|
||||
remainder.resize(16, 0);
|
||||
// Convert the Vec<u8> to [u8; 16]
|
||||
let remainder_array: [u8; 16] = remainder
|
||||
.try_into()
|
||||
.map_err(|_| JsError::new("failed to slice remainder"))?;
|
||||
// append the remainder to the chunks
|
||||
chunks.push(remainder_array);
|
||||
}
|
||||
|
||||
// Convert each chunk to a field element
|
||||
let field_elements: Vec<Fr> = chunks
|
||||
.iter()
|
||||
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
|
||||
.collect();
|
||||
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&field_elements)
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate a poseidon hash in browser. Input message
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn poseidonHash(
|
||||
message: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let message: Vec<Fr> = serde_json::from_slice(&message[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
|
||||
|
||||
let output =
|
||||
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
|
||||
message.clone(),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)),
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Generate a witness file from input.json, compiled model and a settings.json file.
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn genWitness(
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
input: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
gen_witness(compiled_circuit.0, input.0).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Generate verifying key in browser
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn genVk(
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
compress_selectors: bool,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Generate proving key in browser
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn genPk(
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Verify proof in browser using wasm
|
||||
#[wasm_bindgen]
|
||||
pub fn verify(
|
||||
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<bool, JsError> {
|
||||
super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Verify aggregate proof in browser using wasm
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn verifyAggr(
|
||||
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
logrows: u64,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
commitment: &str,
|
||||
) -> Result<bool, JsError> {
|
||||
verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// Prove in browser using wasm
|
||||
#[wasm_bindgen]
|
||||
pub fn prove(
|
||||
witness: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
pk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from)
|
||||
}
|
||||
|
||||
// VALIDATION FUNCTIONS
|
||||
|
||||
/// Witness file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn witnessValidation(witness: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
|
||||
witness_validation(witness.0).map_err(JsError::from)
|
||||
}
|
||||
/// Compiled circuit validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn compiledCircuitValidation(
|
||||
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<bool, JsError> {
|
||||
compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from)
|
||||
}
|
||||
/// Input file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn inputValidation(input: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
|
||||
input_validation(input.0).map_err(JsError::from)
|
||||
}
|
||||
/// Proof file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn proofValidation(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
|
||||
proof_validation(proof.0).map_err(JsError::from)
|
||||
}
|
||||
/// Vk file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn vkValidation(
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<bool, JsError> {
|
||||
vk_validation(vk.0, settings.0).map_err(JsError::from)
|
||||
}
|
||||
/// Pk file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn pkValidation(
|
||||
pk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<bool, JsError> {
|
||||
pk_validation(pk.0, settings.0).map_err(JsError::from)
|
||||
}
|
||||
/// Settings file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn settingsValidation(settings: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
|
||||
settings_validation(settings.0).map_err(JsError::from)
|
||||
}
|
||||
/// Srs file validation
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
|
||||
srs_validation(srs.0).map_err(JsError::from)
|
||||
}
|
||||
|
||||
/// HELPER FUNCTIONS
|
||||
pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
|
||||
let mut n: u128 = 0;
|
||||
for &b in arr.iter().rev() {
|
||||
n <<= 8;
|
||||
n |= b as u128;
|
||||
}
|
||||
n
|
||||
}
|
||||
@@ -219,7 +219,7 @@ mod tests {
|
||||
fn polycommit_chip_for_a_range_of_input_sizes() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
env_logger::init();
|
||||
|
||||
{
|
||||
@@ -247,7 +247,7 @@ mod tests {
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn polycommit_chip_much_longer_input() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
env_logger::init();
|
||||
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Self::configure_with_cols(
|
||||
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
let instance = meta.instance_column();
|
||||
@@ -560,7 +554,7 @@ mod tests {
|
||||
fn hash_for_a_range_of_input_sizes() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
env_logger::init();
|
||||
|
||||
{
|
||||
|
||||
@@ -8,12 +8,12 @@ use halo2_proofs::{
|
||||
use log::debug;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, PyTryFrom},
|
||||
conversion::{FromPyObject, IntoPy},
|
||||
exceptions::PyValueError,
|
||||
prelude::*,
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use crate::{
|
||||
@@ -49,6 +49,7 @@ impl std::fmt::Display for CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for CheckMode {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
@@ -74,6 +75,16 @@ impl FromStr for CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckMode {
|
||||
/// Returns the value of the check mode
|
||||
pub fn is_safe(&self) -> bool {
|
||||
match self {
|
||||
CheckMode::SAFE => true,
|
||||
CheckMode::UNSAFE => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
|
||||
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
|
||||
@@ -88,6 +99,7 @@ impl std::fmt::Display for Tolerance {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for Tolerance {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
@@ -136,10 +148,9 @@ impl IntoPy<PyObject> for CheckMode {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for CheckMode {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
match strval.to_lowercase().as_str() {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let trystr = String::extract_bound(ob)?;
|
||||
match trystr.to_lowercase().as_str() {
|
||||
"safe" => Ok(CheckMode::SAFE),
|
||||
"unsafe" => Ok(CheckMode::UNSAFE),
|
||||
_ => Err(PyValueError::new_err("Invalid value for CheckMode")),
|
||||
@@ -158,8 +169,8 @@ impl IntoPy<PyObject> for Tolerance {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Tolerance {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
if let Ok((val, scale)) = ob.extract::<(f32, f32)>() {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
|
||||
Ok(Tolerance {
|
||||
val,
|
||||
scale: utils::F32(scale),
|
||||
@@ -174,7 +185,7 @@ impl<'source> FromPyObject<'source> for Tolerance {
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DynamicLookups {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
|
||||
pub lookup_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub table_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
@@ -204,15 +215,16 @@ impl DynamicLookups {
|
||||
|
||||
/// A struct representing the selectors for the dynamic lookup tables
|
||||
#[derive(Clone, Debug, Default)]
|
||||
|
||||
pub struct Shuffles {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub input_selectors: BTreeMap<(usize, usize), Selector>,
|
||||
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub reference_selectors: Vec<Selector>,
|
||||
pub output_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
pub inputs: Vec<VarTensor>,
|
||||
/// tables
|
||||
pub references: Vec<VarTensor>,
|
||||
pub outputs: Vec<VarTensor>,
|
||||
}
|
||||
|
||||
impl Shuffles {
|
||||
@@ -223,9 +235,13 @@ impl Shuffles {
|
||||
|
||||
Self {
|
||||
input_selectors: BTreeMap::new(),
|
||||
reference_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone()],
|
||||
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
|
||||
output_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
|
||||
outputs: vec![
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -363,6 +379,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
if inputs[0].num_cols() != output.num_cols() {
|
||||
log::warn!("input and output shapes do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
|
||||
log::warn!("input number of inner columns do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != output.num_inner_cols() {
|
||||
log::warn!("input and output number of inner columns do not match");
|
||||
}
|
||||
|
||||
for i in 0..output.num_blocks() {
|
||||
for j in 0..output.num_inner_cols() {
|
||||
@@ -570,9 +592,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// this is 0 if the index is the same as the column index (starting from 1)
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* table
|
||||
* (table
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier =
|
||||
table.selector_constructor.get_selector_val_at_idx(col_idx);
|
||||
@@ -604,6 +626,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.static_lookups
|
||||
.selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
@@ -643,57 +699,73 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
for t in tables.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of inner columns
|
||||
if tables
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
.windows(2)
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"tables inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
let s_ltable = cs.complex_selector();
|
||||
for q in 0..tables[0].num_blocks() {
|
||||
let s_ltable = cs.complex_selector();
|
||||
|
||||
for x in 0..lookups[0].num_blocks() {
|
||||
for y in 0..lookups[0].num_inner_cols() {
|
||||
let s_lookup = cs.complex_selector();
|
||||
for x in 0..lookups[0].num_blocks() {
|
||||
for y in 0..lookups[0].num_inner_cols() {
|
||||
let s_lookup = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_lookupq = cs.query_selector(s_lookup);
|
||||
let mut expression = vec![];
|
||||
let s_ltableq = cs.query_selector(s_ltable);
|
||||
let mut lookup_queries = vec![one.clone()];
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_lookupq = cs.query_selector(s_lookup);
|
||||
let mut expression = vec![];
|
||||
let s_ltableq = cs.query_selector(s_ltable);
|
||||
let mut lookup_queries = vec![one.clone()];
|
||||
|
||||
for lookup in lookups {
|
||||
lookup_queries.push(match lookup {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
for lookup in lookups {
|
||||
lookup_queries.push(match lookup {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut table_queries = vec![one.clone()];
|
||||
for table in tables {
|
||||
table_queries.push(match table {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[0][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
let mut table_queries = vec![one.clone()];
|
||||
for table in tables {
|
||||
table_queries.push(match table {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
|
||||
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
|
||||
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
});
|
||||
self.dynamic_lookups
|
||||
.lookup_selectors
|
||||
.entry((x, y))
|
||||
.or_insert(s_lookup);
|
||||
expression
|
||||
});
|
||||
self.dynamic_lookups
|
||||
.lookup_selectors
|
||||
.entry((q, (x, y)))
|
||||
.or_insert(s_lookup);
|
||||
}
|
||||
}
|
||||
|
||||
self.dynamic_lookups.table_selectors.push(s_ltable);
|
||||
}
|
||||
self.dynamic_lookups.table_selectors.push(s_ltable);
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.dynamic_lookups.tables.is_empty() {
|
||||
@@ -713,8 +785,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
pub fn configure_shuffles(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
inputs: &[VarTensor; 2],
|
||||
references: &[VarTensor; 2],
|
||||
inputs: &[VarTensor; 3],
|
||||
outputs: &[VarTensor; 3],
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
@@ -725,63 +797,78 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
|
||||
for t in outputs.iter() {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of blocks
|
||||
if outputs
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
.windows(2)
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"outputs inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
let s_reference = cs.complex_selector();
|
||||
for q in 0..outputs[0].num_blocks() {
|
||||
let s_output = cs.complex_selector();
|
||||
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
cs.lookup_any("shuffle", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_outputq = cs.query_selector(s_output);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
|
||||
for input in inputs {
|
||||
input_queries.push(match input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
for input in inputs {
|
||||
input_queries.push(match input {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[0][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
let mut output_queries = vec![one.clone()];
|
||||
for output in outputs {
|
||||
output_queries.push(match output {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
});
|
||||
self.shuffles
|
||||
.input_selectors
|
||||
.entry((x, y))
|
||||
.or_insert(s_input);
|
||||
expression
|
||||
});
|
||||
self.shuffles
|
||||
.input_selectors
|
||||
.entry((q, (x, y)))
|
||||
.or_insert(s_input);
|
||||
}
|
||||
}
|
||||
self.shuffles.output_selectors.push(s_output);
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.shuffles.references.is_empty() {
|
||||
debug!("assigning shuffles reference");
|
||||
self.shuffles.references = references.to_vec();
|
||||
if self.shuffles.outputs.is_empty() {
|
||||
debug!("assigning shuffles output");
|
||||
self.shuffles.outputs = outputs.to_vec();
|
||||
}
|
||||
if self.shuffles.inputs.is_empty() {
|
||||
debug!("assigning shuffles input");
|
||||
@@ -851,9 +938,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let default_x = range_check.get_first_element(col_idx);
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* range_check
|
||||
* (range_check
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier = range_check
|
||||
.selector_constructor
|
||||
@@ -876,6 +963,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.range_checks
|
||||
.selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
|
||||
@@ -94,4 +94,13 @@ pub enum CircuitError {
|
||||
#[error("[io] {0}")]
|
||||
/// IO error
|
||||
IoError(#[from] std::io::Error),
|
||||
/// Invalid scale
|
||||
#[error("negative scale for an op that requires positive inputs {0}")]
|
||||
NegativeScale(String),
|
||||
#[error("invalid input type {0}")]
|
||||
/// Invalid input type
|
||||
InvalidInputType(String),
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::integer_rep_to_felt,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -13,14 +13,38 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
@@ -45,6 +69,8 @@ pub enum HybridOp {
|
||||
ReduceArgMin {
|
||||
dim: usize,
|
||||
},
|
||||
Max,
|
||||
Min,
|
||||
Softmax {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -71,12 +97,19 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
|
||||
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
|
||||
HybridOp::Greater { .. }
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
vec![0, 1]
|
||||
}
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
@@ -88,21 +121,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
|
||||
HybridOp::Max => "MAX".to_string(),
|
||||
HybridOp::Min => "MIN".to_string(),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -135,10 +179,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
)
|
||||
}
|
||||
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
|
||||
HybridOp::Greater => "GREATER".into(),
|
||||
HybridOp::GreaterEqual => "GREATEREQUAL".into(),
|
||||
HybridOp::Less => "LESS".into(),
|
||||
HybridOp::LessEqual => "LESSEQUAL".into(),
|
||||
HybridOp::Greater => "GREATER".to_string(),
|
||||
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
|
||||
HybridOp::Less => "LESS".to_string(),
|
||||
HybridOp::LessEqual => "LESSEQUAL".to_string(),
|
||||
HybridOp::Equals => "EQUALS".into(),
|
||||
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
|
||||
HybridOp::TopK { k, dim, largest } => {
|
||||
@@ -157,6 +201,34 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Floor { scale, legs } => {
|
||||
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Round { scale, legs } => {
|
||||
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -174,42 +246,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::loop_div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
integer_rep_to_felt(denom.0 as IntegerRep),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
@@ -296,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
|
||||
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
output_scale,
|
||||
input_scale,
|
||||
..
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -15,116 +14,27 @@ use halo2curves::ff::PrimeField;
|
||||
/// An enum representing the operations that can be used to express more complex operations via accumulation
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Abs,
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ReLU,
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
GreaterThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThan {
|
||||
a: utils::F32,
|
||||
},
|
||||
GreaterThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
LessThanEqual {
|
||||
a: utils::F32,
|
||||
},
|
||||
Sign,
|
||||
KroneckerDelta,
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Div { denom: utils::F32 },
|
||||
IsOdd,
|
||||
PowersOfTwo { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Exp { scale: utils::F32, base: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
HardSwish { scale: utils::F32 },
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -138,34 +48,14 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Abs => "abs".into(),
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Sign => "sign".into(),
|
||||
LookupOp::LessThan { a } => format!("less_than_{}", a),
|
||||
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
|
||||
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::ReLU => "relu".to_string(),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
|
||||
LookupOp::IsOdd => "is_odd".to_string(),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale, base } => format!("exp_{}_{}", scale, base),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
@@ -188,91 +78,70 @@ impl LookupOp {
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
|
||||
LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())),
|
||||
LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())),
|
||||
LookupOp::RoundHalfToEven { scale } => Ok(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)),
|
||||
LookupOp::Max { scale, a } => Ok(tensor::ops::nonlinearities::max(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::Min { scale, a } => Ok(tensor::ops::nonlinearities::min(
|
||||
&x,
|
||||
scale.0.into(),
|
||||
a.0.into(),
|
||||
)),
|
||||
LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)),
|
||||
LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than(
|
||||
&x,
|
||||
f32::from(*a).into(),
|
||||
)),
|
||||
LookupOp::GreaterThanEqual { a } => Ok(
|
||||
tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div(
|
||||
&x,
|
||||
f32::from(*denom).into(),
|
||||
)),
|
||||
LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
|
||||
&x,
|
||||
f32::from(*scale).into(),
|
||||
)),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
|
||||
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())),
|
||||
LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())),
|
||||
LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())),
|
||||
LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())),
|
||||
LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())),
|
||||
LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())),
|
||||
LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())),
|
||||
LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())),
|
||||
LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())),
|
||||
LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())),
|
||||
LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())),
|
||||
LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())),
|
||||
LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())),
|
||||
LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())),
|
||||
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
|
||||
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
|
||||
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::PowersOfTwo { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
|
||||
}
|
||||
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Erf { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Exp { scale, base } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::exp(&x, scale.into(), base.into()),
|
||||
),
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ACos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::acos(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cosh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cosh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ACosh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::acosh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sin { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sin(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ASin { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::asin(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sinh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sinh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ASinh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::asinh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Tan { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::tan(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ATan { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::atan(&x, scale.into()))
|
||||
}
|
||||
LookupOp::ATanh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::atanh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Tanh { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::tanh(&x, scale.into()))
|
||||
}
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| integer_rep_to_felt(x));
|
||||
|
||||
@@ -289,37 +158,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Abs => "ABS".into(),
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Sign => "SIGN".into(),
|
||||
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
|
||||
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
|
||||
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::ReLU => "RELU".to_string(),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
|
||||
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
|
||||
LookupOp::IsOdd => "IS_ODD".to_string(),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
|
||||
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
|
||||
LookupOp::Exp { scale, base } => format!("EXP(scale={}, base={})", scale, base),
|
||||
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
|
||||
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
|
||||
LookupOp::Tanh { scale } => format!("TANH(scale={})", scale),
|
||||
@@ -352,20 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::Sign
|
||||
| LookupOp::GreaterThan { .. }
|
||||
| LookupOp::LessThan { .. }
|
||||
| LookupOp::GreaterThanEqual { .. }
|
||||
| LookupOp::LessThanEqual { .. }
|
||||
| LookupOp::KroneckerDelta => 0,
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
let scale = inputs_scale[0];
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
|
||||
@@ -105,7 +105,10 @@ impl InputType {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
|
||||
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone + std::fmt::Debug>(
|
||||
&self,
|
||||
input: &mut T,
|
||||
) {
|
||||
match self {
|
||||
InputType::Bool => {
|
||||
let boolean_input = input.clone().to_i64().unwrap();
|
||||
@@ -118,7 +121,7 @@ impl InputType {
|
||||
*input = T::from_f32(f32_input).unwrap();
|
||||
}
|
||||
InputType::F32 => {
|
||||
let f32_input = input.clone().to_f32().unwrap();
|
||||
let f32_input: f32 = input.clone().to_f32().unwrap();
|
||||
*input = T::from_f32(f32_input).unwrap();
|
||||
}
|
||||
InputType::F64 => {
|
||||
@@ -133,6 +136,22 @@ impl InputType {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for InputType {
|
||||
type Err = CircuitError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"bool" => Ok(InputType::Bool),
|
||||
"f16" => Ok(InputType::F16),
|
||||
"f32" => Ok(InputType::F32),
|
||||
"f64" => Ok(InputType::F64),
|
||||
"int" => Ok(InputType::Int),
|
||||
"tdim" => Ok(InputType::TDim),
|
||||
e => Err(CircuitError::InvalidInputType(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Input {
|
||||
@@ -255,7 +274,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
self.raw_values = Tensor::new(None, &[0]).unwrap();
|
||||
}
|
||||
|
||||
///
|
||||
/// Pre-assign a value
|
||||
pub fn pre_assign(&mut self, val: ValTensor<F>) {
|
||||
self.pre_assigned_val = Some(val)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
circuit::{
|
||||
layouts,
|
||||
utils::{self, F32},
|
||||
},
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -9,6 +12,12 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
Abs,
|
||||
Sign,
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
scale: i32,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
@@ -99,8 +108,7 @@ impl<
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>
|
||||
,
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
@@ -110,6 +118,9 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
|
||||
PolyOp::Abs => "ABS".to_string(),
|
||||
PolyOp::Sign => "SIGN".to_string(),
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
@@ -191,6 +202,11 @@ impl<
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::LeakyReLU { slope, scale } => {
|
||||
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
|
||||
}
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
@@ -321,6 +337,12 @@ impl<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
// this corresponds to the relu operation
|
||||
PolyOp::LeakyReLU {
|
||||
slope: F32(0.0), ..
|
||||
} => in_scales[0],
|
||||
// this corresponds to the leaky relu operation with a slope which induces a change in scale
|
||||
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Iff => in_scales[1],
|
||||
@@ -368,6 +390,7 @@ impl<
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
|
||||
PolyOp::Sign { .. } => 0,
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
fieldutils::IntegerRep,
|
||||
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use colored::Colorize;
|
||||
use halo2_proofs::{
|
||||
circuit::Region,
|
||||
@@ -93,6 +93,10 @@ pub struct RegionSettings {
|
||||
pub witness_gen: bool,
|
||||
/// whether we should check range checks for validity
|
||||
pub check_range: bool,
|
||||
/// base for decompositions
|
||||
pub base: usize,
|
||||
/// number of legs for decompositions
|
||||
pub legs: usize,
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
@@ -102,26 +106,32 @@ unsafe impl Send for RegionSettings {}
|
||||
|
||||
impl RegionSettings {
|
||||
/// Create a new region settings
|
||||
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
|
||||
pub fn new(witness_gen: bool, check_range: bool, base: usize, legs: usize) -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen,
|
||||
check_range,
|
||||
base,
|
||||
legs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all true
|
||||
pub fn all_true() -> RegionSettings {
|
||||
pub fn all_true(base: usize, legs: usize) -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: true,
|
||||
check_range: true,
|
||||
base,
|
||||
legs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all false
|
||||
pub fn all_false() -> RegionSettings {
|
||||
pub fn all_false(base: usize, legs: usize) -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: false,
|
||||
check_range: false,
|
||||
base,
|
||||
legs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,14 +180,30 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
statistics: RegionStatistics,
|
||||
settings: RegionSettings,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
max_dynamic_input_len: usize,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// get the region's decomposition base
|
||||
pub fn base(&self) -> usize {
|
||||
self.settings.base
|
||||
}
|
||||
|
||||
/// get the region's decomposition legs
|
||||
pub fn legs(&self) -> usize {
|
||||
self.settings.legs
|
||||
}
|
||||
|
||||
/// get the max dynamic input len
|
||||
pub fn max_dynamic_input_len(&self) -> usize {
|
||||
self.max_dynamic_input_len
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
///
|
||||
pub fn debug_report(&self) {
|
||||
log::debug!(
|
||||
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
|
||||
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={}, max_dynamic_input_len={})",
|
||||
self.row().to_string().blue(),
|
||||
self.linear_coord().to_string().yellow(),
|
||||
self.total_constants().to_string().red(),
|
||||
@@ -185,7 +211,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.min_lookup_inputs().to_string().green(),
|
||||
self.max_range_size().to_string().green(),
|
||||
self.dynamic_lookup_col_coord().to_string().green(),
|
||||
self.shuffle_col_coord().to_string().green());
|
||||
self.shuffle_col_coord().to_string().green(),
|
||||
self.max_dynamic_input_len().to_string().green()
|
||||
);
|
||||
}
|
||||
|
||||
///
|
||||
@@ -203,6 +231,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.dynamic_lookup_index.index += n;
|
||||
}
|
||||
|
||||
/// increment the max dynamic input len
|
||||
pub fn update_max_dynamic_input_len(&mut self, n: usize) {
|
||||
self.max_dynamic_input_len = self.max_dynamic_input_len.max(n);
|
||||
}
|
||||
|
||||
///
|
||||
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
|
||||
self.dynamic_lookup_index.col_coord += n;
|
||||
@@ -234,7 +267,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
|
||||
pub fn new(
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
decomp_base: usize,
|
||||
decomp_legs: usize,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = Some(RefCell::new(region));
|
||||
let linear_coord = row * num_inner_cols;
|
||||
|
||||
@@ -246,8 +285,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(),
|
||||
settings: RegionSettings::all_true(decomp_base, decomp_legs),
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,9 +296,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
decomp_base: usize,
|
||||
decomp_legs: usize,
|
||||
constants: ConstantsMap<F>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let mut new_self = Self::new(region, row, num_inner_cols);
|
||||
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
|
||||
new_self.assigned_constants = constants;
|
||||
new_self
|
||||
}
|
||||
@@ -282,6 +324,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,6 +346,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
max: IntegerRep,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
@@ -555,9 +610,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(ValTensor<F>, usize), CircuitError> {
|
||||
self.update_max_dynamic_input_len(values.len());
|
||||
|
||||
if let Some(region) = &self.region {
|
||||
Ok(var.assign(
|
||||
Ok(var.assign_exact_column(
|
||||
&mut region.borrow_mut(),
|
||||
self.combined_dynamic_shuffle_coord(),
|
||||
values,
|
||||
@@ -568,7 +625,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.par_extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
|
||||
let flush_len = var.get_column_flush(self.combined_dynamic_shuffle_coord(), values)?;
|
||||
|
||||
// get the diff between the current column and the next row
|
||||
Ok((values.clone(), flush_len))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,7 +638,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(ValTensor<F>, usize), CircuitError> {
|
||||
self.assign_dynamic_lookup(var, values)
|
||||
}
|
||||
|
||||
@@ -610,22 +671,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication(
|
||||
pub fn assign_with_duplication_unconstrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &crate::circuit::CheckMode,
|
||||
single_inner_col: bool,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication(
|
||||
let (res, len) = var.assign_with_duplication_unconstrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
check_mode,
|
||||
single_inner_col,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
@@ -634,7 +690,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
single_inner_col,
|
||||
false,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication_constrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &crate::circuit::CheckMode,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication_constrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
check_mode,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
} else {
|
||||
let (_, len) = var.dummy_assign_with_duplication(
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
true,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::{
|
||||
tensor::{Tensor, TensorType},
|
||||
};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::execute::EZKL_REPO_PATH;
|
||||
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
@@ -28,14 +28,14 @@ pub const RANGE_MULTIPLIER: IntegerRep = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
lazy_static::lazy_static! {
|
||||
/// an optional directory to read and write the lookup table cache
|
||||
pub static ref LOOKUP_CACHE: String = format!("{}/cache", *EZKL_REPO_PATH);
|
||||
}
|
||||
|
||||
/// The lookup table cache is disabled on wasm32 target.
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
pub const LOOKUP_CACHE: &str = "";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -132,30 +132,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size given the number of rows and reserved blinding rows
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
(range_len / col_size as IntegerRep) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get largest element represented by the range
|
||||
pub fn largest(&self) -> IntegerRep {
|
||||
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
|
||||
}
|
||||
fn name(&self) -> String {
|
||||
format!(
|
||||
"{}_{}_{}",
|
||||
self.nonlinearity.as_path(),
|
||||
self.range.0,
|
||||
self.range.1
|
||||
self.largest()
|
||||
)
|
||||
}
|
||||
/// Configures the table.
|
||||
@@ -222,7 +221,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
let largest = self.largest();
|
||||
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
@@ -291,6 +290,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
row_offset += chunk_idx * self.col_size;
|
||||
let (x, y) = self.cartesian_coord(row_offset);
|
||||
|
||||
if !preassigned_input {
|
||||
table.assign_cell(
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
@@ -350,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
|
||||
@@ -8,6 +8,10 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::Fr as F;
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
#[cfg(not(any(
|
||||
all(target_arch = "wasm32", target_os = "unknown"),
|
||||
not(feature = "ezkl")
|
||||
)))]
|
||||
use ops::lookup::LookupOp;
|
||||
use ops::region::RegionCtx;
|
||||
use rand::rngs::OsRng;
|
||||
@@ -55,7 +59,7 @@ mod matmul {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -132,7 +136,7 @@ mod matmul_col_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -206,7 +210,7 @@ mod matmul_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -243,7 +247,10 @@ mod matmul_col_overflow {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
mod matmul_col_ultra_overflow_double_col {
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
@@ -256,7 +263,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 20;
|
||||
const LEN: usize = 10;
|
||||
const NUM_INNER_COLS: usize = 2;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -290,7 +297,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -361,7 +368,10 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
@@ -374,7 +384,7 @@ mod matmul_col_ultra_overflow {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 20;
|
||||
const LEN: usize = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -407,7 +417,7 @@ mod matmul_col_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -518,7 +528,7 @@ mod dot {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -595,7 +605,7 @@ mod dot_col_overflow_triple_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 3);
|
||||
let mut region = RegionCtx::new(region, 0, 3, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -668,7 +678,7 @@ mod dot_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -741,7 +751,7 @@ mod sum {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -811,7 +821,7 @@ mod sum_col_overflow_double_col {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -880,7 +890,7 @@ mod sum_col_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -951,7 +961,7 @@ mod composition {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
let _ = config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1030,6 +1040,10 @@ mod conv {
|
||||
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
|
||||
// column for constants
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1042,7 +1056,7 @@ mod conv {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1144,7 +1158,10 @@ mod conv {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
mod conv_col_ultra_overflow {
|
||||
|
||||
use halo2_proofs::poly::{
|
||||
@@ -1158,8 +1175,8 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 28;
|
||||
const K: usize = 6;
|
||||
const LEN: usize = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -1178,9 +1195,10 @@ mod conv_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1193,7 +1211,7 @@ mod conv_col_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1285,7 +1303,10 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
#[cfg(test)]
|
||||
// not wasm 32 unknown
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
mod conv_relu_col_ultra_overflow {
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
@@ -1297,8 +1318,8 @@ mod conv_relu_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 28;
|
||||
const K: usize = 8;
|
||||
const LEN: usize = 15;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -1317,15 +1338,23 @@ mod conv_relu_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4);
|
||||
let mut base_config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// sets up a new relu table
|
||||
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU)
|
||||
.configure_range_check(cs, &a, &b, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
base_config
|
||||
.configure_range_check(cs, &a, &b, (0, 1), K)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
base_config.clone()
|
||||
}
|
||||
|
||||
@@ -1334,12 +1363,12 @@ mod conv_relu_col_ultra_overflow {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
config.layout_range_checks(&mut layouter).unwrap();
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
let output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
@@ -1355,7 +1384,10 @@ mod conv_relu_col_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(LookupOp::ReLU),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -1476,7 +1508,7 @@ mod add_w_shape_casting {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1492,7 +1524,7 @@ mod add_w_shape_casting {
|
||||
// parameters
|
||||
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
let b = Tensor::from((0..1).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
@@ -1543,7 +1575,7 @@ mod add {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1627,7 +1659,7 @@ mod dynamic_lookup {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
for i in 0..NUM_LOOP {
|
||||
layouts::dynamic_lookup(
|
||||
&config,
|
||||
@@ -1749,13 +1781,18 @@ mod shuffle {
|
||||
|
||||
let d = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let e = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
|
||||
|
||||
let mut config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
|
||||
config
|
||||
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
|
||||
.configure_shuffles(
|
||||
cs,
|
||||
&[a.clone(), b.clone(), c.clone()],
|
||||
&[d.clone(), e.clone(), f.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
@@ -1769,7 +1806,7 @@ mod shuffle {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
for i in 0..NUM_LOOP {
|
||||
layouts::shuffles(
|
||||
&config,
|
||||
@@ -1884,7 +1921,7 @@ mod add_with_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1986,7 +2023,7 @@ mod add_with_overflow_and_poseidon {
|
||||
layouter.assign_region(
|
||||
|| "model",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.base
|
||||
.layout(&mut region, &inputs, Box::new(PolyOp::Add))
|
||||
@@ -2092,7 +2129,7 @@ mod sub {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2159,7 +2196,7 @@ mod mult {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2226,7 +2263,7 @@ mod pow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -2258,7 +2295,6 @@ mod matmul_relu {
|
||||
|
||||
const K: usize = 18;
|
||||
const LEN: usize = 32;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -2288,11 +2324,17 @@ mod matmul_relu {
|
||||
|
||||
let mut base_config =
|
||||
BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
|
||||
// sets up a new relu table
|
||||
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU)
|
||||
.configure_range_check(cs, &a, &b, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
base_config
|
||||
.configure_range_check(cs, &a, &b, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
MyConfig { base_config }
|
||||
}
|
||||
|
||||
@@ -2301,11 +2343,14 @@ mod matmul_relu {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.base_config.layout_tables(&mut layouter).unwrap();
|
||||
config
|
||||
.base_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let op = PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
};
|
||||
@@ -2315,7 +2360,14 @@ mod matmul_relu {
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -2354,6 +2406,8 @@ mod relu {
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
|
||||
const K: u32 = 8;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
@@ -2370,16 +2424,26 @@ mod relu {
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let advices = (0..3)
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.map(|_| VarTensor::new_advice(cs, 8, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
let mut config = BaseConfig::configure(
|
||||
cs,
|
||||
&[advices[0].clone(), advices[1].clone()],
|
||||
&advices[2],
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
config
|
||||
.configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl)
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K as usize)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
.configure_range_check(cs, &advices[0], &advices[1], (0, 1), K as usize)
|
||||
.unwrap();
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K as usize, 8, false);
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
@@ -2388,15 +2452,22 @@ mod relu {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>, // layouter is our 'write buffer' for the circuit
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
config.layout_range_checks(&mut layouter).unwrap();
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.map_err(|_| Error::Synthesis)
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
Ok(config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@@ -2414,13 +2485,16 @@ mod relu {
|
||||
input: ValTensor::from(input),
|
||||
};
|
||||
|
||||
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
|
||||
let prover = MockProver::run(K, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
mod lookup_ultra_overflow {
|
||||
use super::*;
|
||||
use halo2_proofs::{
|
||||
@@ -2435,11 +2509,11 @@ mod lookup_ultra_overflow {
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ReLUCircuit<F> {
|
||||
impl Circuit<F> for SigmoidCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
@@ -2453,7 +2527,7 @@ mod lookup_ultra_overflow {
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::ReLU;
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
|
||||
@@ -2481,9 +2555,13 @@ mod lookup_ultra_overflow {
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -2495,13 +2573,13 @@ mod lookup_ultra_overflow {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn relucircuit() {
|
||||
fn sigmoidcircuit() {
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
// parameters
|
||||
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = ReLUCircuit::<F> {
|
||||
let circuit = SigmoidCircuit::<F> {
|
||||
input: ValTensor::from(a),
|
||||
};
|
||||
|
||||
@@ -2511,7 +2589,7 @@ mod lookup_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
ReLUCircuit<F>,
|
||||
SigmoidCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -141,23 +141,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn f32_eq() {
|
||||
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
|
||||
assert!(F32(std::f32::NAN) != F32(5.0));
|
||||
assert!(F32(5.0) != F32(std::f32::NAN));
|
||||
assert!(F32(f32::NAN) == F32(f32::NAN));
|
||||
assert!(F32(f32::NAN) != F32(5.0));
|
||||
assert!(F32(5.0) != F32(f32::NAN));
|
||||
assert!(F32(0.0) == F32(-0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f32_cmp() {
|
||||
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
|
||||
assert!(F32(std::f32::NAN) < F32(5.0));
|
||||
assert!(F32(5.0) > F32(std::f32::NAN));
|
||||
assert!(F32(f32::NAN) == F32(f32::NAN));
|
||||
assert!(F32(f32::NAN) < F32(5.0));
|
||||
assert!(F32(5.0) > F32(f32::NAN));
|
||||
assert!(F32(0.0) == F32(-0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f32_hash() {
|
||||
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
|
||||
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
|
||||
assert!(calculate_hash(&F32(f32::NAN)) == calculate_hash(&F32(-f32::NAN)));
|
||||
}
|
||||
}
|
||||
|
||||
275
src/commands.rs
275
src/commands.rs
@@ -1,14 +1,8 @@
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use alloy::primitives::Address as H160;
|
||||
use clap::{Command, Parser, Subcommand};
|
||||
use clap_complete::{generate, Generator, Shell};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, PyTryFrom},
|
||||
exceptions::PyValueError,
|
||||
prelude::*,
|
||||
types::PyString,
|
||||
};
|
||||
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
@@ -17,7 +11,6 @@ use tosubcommand::{ToFlags, ToSubcommand};
|
||||
use crate::{pfsys::ProofType, Commitments, RunArgs};
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::pfsys::TranscriptType;
|
||||
|
||||
@@ -81,20 +74,24 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
|
||||
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
|
||||
/// Default Compress selectors
|
||||
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
|
||||
/// Default render vk separately
|
||||
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
|
||||
/// Default render reusable verifier
|
||||
pub const DEFAULT_RENDER_REUSABLE: &str = "false";
|
||||
/// Default contract deployment type
|
||||
pub const DEFAULT_CONTRACT_DEPLOYMENT_TYPE: &str = "verifier";
|
||||
/// Default VK sol path
|
||||
pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
/// Default seed used to generate random data
|
||||
pub const DEFAULT_SEED: &str = "21242";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -109,8 +106,8 @@ impl IntoPy<PyObject> for TranscriptType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for TranscriptType {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let trystr = String::extract_bound(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
match strval.to_lowercase().as_str() {
|
||||
"poseidon" => Ok(TranscriptType::Poseidon),
|
||||
@@ -181,28 +178,80 @@ impl From<&str> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// Determines what type of contract (verifier, verifier/reusable, vka) should be deployed
|
||||
pub enum ContractType {
|
||||
/// Deploys a verifier contrat tailored to the circuit and not reusable
|
||||
Verifier {
|
||||
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
|
||||
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
|
||||
reusable: bool,
|
||||
},
|
||||
/// Deploys a verifying key artifact that the reusable verifier loads into memory during runtime. Encodes the circuit specific data that was otherwise hardcoded onto the stack.
|
||||
VerifyingKeyArtifact,
|
||||
}
|
||||
|
||||
impl Default for ContractType {
|
||||
fn default() -> Self {
|
||||
ContractType::Verifier { reusable: false }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ContractType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
"verifier/reusable".to_string()
|
||||
}
|
||||
ContractType::Verifier { reusable: false } => "verifier".to_string(),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_string(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for ContractType {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ContractType {
|
||||
fn from(s: &str) -> Self {
|
||||
match s {
|
||||
"verifier" => ContractType::Verifier { reusable: false },
|
||||
"verifier/reusable" => ContractType::Verifier { reusable: true },
|
||||
"vka" => ContractType::VerifyingKeyArtifact,
|
||||
_ => {
|
||||
log::error!("Invalid value for ContractType");
|
||||
log::warn!("Defaulting to verifier");
|
||||
ContractType::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
inner: H160,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<H160Flag> for H160 {
|
||||
fn from(val: H160Flag) -> H160 {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl ToFlags for H160Flag {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{:#x}", self.inner)]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<&str> for H160Flag {
|
||||
fn from(s: &str) -> Self {
|
||||
Self {
|
||||
@@ -230,9 +279,8 @@ impl IntoPy<PyObject> for CalibrationTarget {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains CalibrationTarget from PyObject (Required for CalibrationTarget to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"resources" => Ok(CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
@@ -243,17 +291,31 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
}
|
||||
}
|
||||
}
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
lazy_static! {
|
||||
/// The version of the ezkl library
|
||||
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
|
||||
"source - no compatibility guaranteed"
|
||||
} else {
|
||||
env!("CARGO_PKG_VERSION")
|
||||
};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts ContractType into a PyObject (Required for ContractType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for ContractType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
|
||||
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for ContractType {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"verifier" => Ok(ContractType::Verifier { reusable: false }),
|
||||
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
|
||||
"vka" => Ok(ContractType::VerifyingKeyArtifact),
|
||||
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the styles for the CLI
|
||||
@@ -263,49 +325,49 @@ pub fn get_styles() -> clap::builder::Styles {
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Cyan,
|
||||
))),
|
||||
)
|
||||
.header(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
|
||||
)
|
||||
.literal(
|
||||
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
|
||||
)
|
||||
.invalid(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
|
||||
)
|
||||
.error(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Cyan,
|
||||
))),
|
||||
)
|
||||
.literal(clap::builder::styling::Style::new().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta),
|
||||
)))
|
||||
.invalid(clap::builder::styling::Style::new().bold().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
|
||||
)))
|
||||
.error(clap::builder::styling::Style::new().bold().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
|
||||
)))
|
||||
.valid(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
|
||||
)
|
||||
.placeholder(
|
||||
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Green,
|
||||
))),
|
||||
)
|
||||
.placeholder(clap::builder::styling::Style::new().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White),
|
||||
)))
|
||||
}
|
||||
|
||||
|
||||
/// Print completions for the given generator
|
||||
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
|
||||
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
|
||||
}
|
||||
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, about, long_about = None)]
|
||||
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
|
||||
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
|
||||
pub struct Cli {
|
||||
/// If provided, outputs the completion file for given shell
|
||||
#[clap(long = "generate", value_parser)]
|
||||
@@ -315,7 +377,6 @@ pub struct Cli {
|
||||
pub command: Option<Commands>,
|
||||
}
|
||||
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
|
||||
pub enum Commands {
|
||||
@@ -363,9 +424,22 @@ pub enum Commands {
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
},
|
||||
|
||||
/// Generate random data for a model
|
||||
GenRandomData {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
variables: Vec<(String, usize)>,
|
||||
/// random seed for reproducibility (optional)
|
||||
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
|
||||
seed: u64,
|
||||
},
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
|
||||
@@ -397,9 +471,6 @@ pub enum Commands {
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
|
||||
only_range_check_rebase: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -416,11 +487,10 @@ pub enum Commands {
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
#[command(name = "get-srs")]
|
||||
GetSrs {
|
||||
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
|
||||
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
|
||||
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
|
||||
@@ -467,7 +537,7 @@ pub enum Commands {
|
||||
/// The path to save the proving key to
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -494,7 +564,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -502,7 +572,7 @@ pub enum Commands {
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = TranscriptType::default(),
|
||||
value_enum,
|
||||
value_enum,
|
||||
value_hint = clap::ValueHint::Other
|
||||
)]
|
||||
transcript: TranscriptType,
|
||||
@@ -536,7 +606,7 @@ pub enum Commands {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
@@ -552,7 +622,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
|
||||
disable_selector_compression: Option<bool>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEvmData {
|
||||
@@ -577,7 +646,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
|
||||
output_source: TestDataSource,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
|
||||
#[command(arg_required_else_help = true)]
|
||||
TestUpdateAccountCalls {
|
||||
@@ -591,7 +659,6 @@ pub enum Commands {
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Swaps the positions in the transcript that correspond to commitments
|
||||
SwapProofCommitments {
|
||||
/// The path to the proof file
|
||||
@@ -602,7 +669,6 @@ pub enum Commands {
|
||||
witness_path: Option<PathBuf>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Loads model, data, and creates proof
|
||||
Prove {
|
||||
/// The path to the .json witness file (generated using the gen-witness command)
|
||||
@@ -617,7 +683,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -625,7 +691,7 @@ pub enum Commands {
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = ProofType::Single,
|
||||
value_enum,
|
||||
value_enum,
|
||||
value_hint = clap::ValueHint::Other
|
||||
)]
|
||||
proof_type: ProofType,
|
||||
@@ -633,7 +699,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
|
||||
check_mode: Option<CheckMode>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Encodes a proof into evm calldata
|
||||
#[command(name = "encode-evm-calldata")]
|
||||
EncodeEvmCalldata {
|
||||
@@ -647,11 +712,10 @@ pub enum Commands {
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_vk: Option<H160Flag>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -666,17 +730,14 @@ pub enum Commands {
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
|
||||
abi_path: Option<PathBuf>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
/// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
|
||||
reusable: Option<bool>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-vk")]
|
||||
CreateEvmVK {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -692,7 +753,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
|
||||
abi_path: Option<PathBuf>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
#[command(name = "create-evm-da")]
|
||||
CreateEvmDataAttestation {
|
||||
@@ -717,11 +777,10 @@ pub enum Commands {
|
||||
witness: Option<PathBuf>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
@@ -739,11 +798,9 @@ pub enum Commands {
|
||||
// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
|
||||
logrows: Option<u32>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
/// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
|
||||
reusable: Option<bool>,
|
||||
},
|
||||
/// Verifies a proof, returning accept or reject
|
||||
Verify {
|
||||
@@ -756,7 +813,7 @@ pub enum Commands {
|
||||
/// The path to the verification key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
@@ -774,7 +831,7 @@ pub enum Commands {
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -784,9 +841,8 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVerifier {
|
||||
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
|
||||
DeployEvm {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
@@ -802,27 +858,10 @@ pub enum Commands {
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
/// Contract type to be deployed
|
||||
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
|
||||
contract: ContractType,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
DeployEvmVK {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
|
||||
/// The path to output the contract address
|
||||
addr_path: Option<PathBuf>,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
|
||||
optimizer_runs: usize,
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
@@ -848,7 +887,6 @@ pub enum Commands {
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Verifies a proof using a local Evm executor, returning accept or reject
|
||||
#[command(name = "verify-evm")]
|
||||
VerifyEvm {
|
||||
@@ -877,7 +915,6 @@ pub enum Commands {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
impl Commands {
|
||||
/// Converts the commands to a json string
|
||||
pub fn as_json(&self) -> String {
|
||||
@@ -888,4 +925,4 @@ impl Commands {
|
||||
pub fn from_json(json: &str) -> Self {
|
||||
serde_json::from_str(json).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
400
src/eth.rs
400
src/eth.rs
File diff suppressed because one or more lines are too long
318
src/execute.rs
318
src/execute.rs
@@ -1,18 +1,16 @@
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::CalibrationTarget;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{
|
||||
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
|
||||
fix_da_single_sol,
|
||||
};
|
||||
#[allow(unused_imports)]
|
||||
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
|
||||
use crate::graph::input::GraphData;
|
||||
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
|
||||
use crate::graph::input::{Calls, GraphData};
|
||||
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::{TestDataSource, TestSources};
|
||||
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::pfsys::{
|
||||
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
|
||||
};
|
||||
@@ -21,11 +19,9 @@ use crate::pfsys::{
|
||||
};
|
||||
use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::{commands::*, EZKLError};
|
||||
use crate::{Commitments, RunArgs};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored::Colorize;
|
||||
#[cfg(unix)]
|
||||
use gag::Gag;
|
||||
@@ -45,17 +41,13 @@ use halo2_proofs::poly::kzg::{
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use halo2_solidity_verifier;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, WithSmallOrderMulGroup};
|
||||
use halo2curves::serde::SerdeObject;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use instant::Instant;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use itertools::Itertools;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::debug;
|
||||
use log::{info, trace, warn};
|
||||
use serde::de::DeserializeOwned;
|
||||
@@ -65,9 +57,7 @@ use snark_verifier::system::halo2::compile;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use snark_verifier::system::halo2::Config;
|
||||
use std::fs::File;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::BufWriter;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::{Cursor, Write};
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
@@ -75,6 +65,8 @@ use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tabled::Tabled;
|
||||
use thiserror::Error;
|
||||
use tract_onnx::prelude::IntoTensor;
|
||||
use tract_onnx::prelude::Tensor as TractTensor;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -128,7 +120,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
logrows as u32,
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetSrs {
|
||||
srs_path,
|
||||
settings_path,
|
||||
@@ -145,7 +136,17 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
args,
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GenRandomData {
|
||||
model,
|
||||
data,
|
||||
variables,
|
||||
seed,
|
||||
} => gen_random_data(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
variables,
|
||||
seed,
|
||||
),
|
||||
Commands::CalibrateSettings {
|
||||
model,
|
||||
settings_path,
|
||||
@@ -155,7 +156,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
@@ -164,7 +164,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -188,14 +187,13 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
witness.unwrap_or(DEFAULT_WITNESS.into()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmVerifier {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
} => {
|
||||
create_evm_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
@@ -203,11 +201,10 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
reusable.unwrap_or(DEFAULT_RENDER_REUSABLE.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::EncodeEvmCalldata {
|
||||
proof_path,
|
||||
calldata_path,
|
||||
@@ -219,14 +216,14 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
|
||||
Commands::CreateEvmVK {
|
||||
Commands::CreateEvmVKArtifact {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => {
|
||||
create_evm_vk(
|
||||
create_evm_vka(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
@@ -235,7 +232,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmDataAttestation {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -252,7 +248,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmVerifierAggr {
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -260,7 +255,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
reusable,
|
||||
} => {
|
||||
create_evm_aggregate_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
@@ -269,7 +264,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
|
||||
aggregation_settings,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
|
||||
reusable.unwrap_or(DEFAULT_RENDER_REUSABLE.parse().unwrap()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -298,7 +293,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
disable_selector_compression
|
||||
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SetupTestEvmData {
|
||||
data,
|
||||
compiled_circuit,
|
||||
@@ -317,13 +311,11 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::TestUpdateAccountCalls {
|
||||
addr,
|
||||
data,
|
||||
rpc_url,
|
||||
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SwapProofCommitments {
|
||||
proof_path,
|
||||
witness_path,
|
||||
@@ -333,7 +325,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::Prove {
|
||||
witness,
|
||||
compiled_circuit,
|
||||
@@ -433,13 +424,13 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVerifier {
|
||||
Commands::DeployEvm {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
contract,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
@@ -447,29 +438,10 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
contract,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmVK {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
|
||||
rpc_url,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_VK.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmDataAttestation {
|
||||
data,
|
||||
settings_path,
|
||||
@@ -490,7 +462,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::VerifyEvm {
|
||||
proof_path,
|
||||
addr_verifier,
|
||||
@@ -610,7 +581,6 @@ pub(crate) fn gen_srs_cmd(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, EZKLError> {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
@@ -630,7 +600,6 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, EZKLError> {
|
||||
Ok(std::mem::take(&mut buf))
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, EZKLError> {
|
||||
use std::io::Read;
|
||||
let file = std::fs::File::open(path)?;
|
||||
@@ -649,7 +618,6 @@ pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, EZKLError> {
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -675,7 +643,6 @@ fn check_srs_hash(
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn get_srs_cmd(
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: Option<PathBuf>,
|
||||
@@ -718,20 +685,21 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
|
||||
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
|
||||
let mut file = std::fs::File::create(&computed_srs_path)?;
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
info!(
|
||||
"Saved SRS to {}.",
|
||||
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
|
||||
);
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -777,16 +745,16 @@ pub(crate) async fn gen_witness(
|
||||
None
|
||||
};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
// if any of the settings have kzg visibility then we need to load the srs
|
||||
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
|
||||
let region_settings = RegionSettings::all_true();
|
||||
let region_settings =
|
||||
RegionSettings::all_true(settings.run_args.decomp_base, settings.run_args.decomp_legs);
|
||||
|
||||
let start_time = Instant::now();
|
||||
let witness = if settings.module_requires_polycommit() {
|
||||
@@ -873,8 +841,72 @@ pub(crate) fn gen_circuit_settings(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
/// Generate a circuit settings file
|
||||
pub(crate) fn gen_random_data(
|
||||
model_path: PathBuf,
|
||||
data_path: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut file = std::fs::File::open(&model_path).map_err(|e| {
|
||||
crate::graph::errors::GraphError::ReadWriteFileError(
|
||||
model_path.display().to_string(),
|
||||
e.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
|
||||
|
||||
let input_facts = tract_model
|
||||
.input_outlets()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?
|
||||
.iter()
|
||||
.map(|&i| tract_model.outlet_fact(i))
|
||||
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?;
|
||||
|
||||
/// Generates a random tensor of a given size and type.
|
||||
fn random(
|
||||
sizes: &[usize],
|
||||
datum_type: tract_onnx::prelude::DatumType,
|
||||
seed: u64,
|
||||
) -> TractTensor {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
|
||||
let slice = tensor.as_slice_mut::<f32>().unwrap();
|
||||
slice.iter_mut().for_each(|x| *x = rng.gen());
|
||||
tensor.cast_to_dt(datum_type).unwrap().into_owned()
|
||||
}
|
||||
|
||||
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
|
||||
if let Some(value) = &fact.konst {
|
||||
return value.clone().into_tensor();
|
||||
}
|
||||
|
||||
random(
|
||||
fact.shape
|
||||
.as_concrete()
|
||||
.expect("Expected concrete shape, found: {fact:?}"),
|
||||
fact.datum_type,
|
||||
seed,
|
||||
)
|
||||
}
|
||||
|
||||
let generated = input_facts
|
||||
.iter()
|
||||
.map(|v| tensor_for_fact(v, seed))
|
||||
.collect_vec();
|
||||
|
||||
let data = GraphData::from_tract_data(&generated)?;
|
||||
|
||||
data.save(data_path)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
// not for wasm targets
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn init_spinner() -> ProgressBar {
|
||||
let pb = indicatif::ProgressBar::new_spinner();
|
||||
pb.set_draw_target(indicatif::ProgressDrawTarget::stdout());
|
||||
@@ -896,7 +928,6 @@ pub(crate) fn init_spinner() -> ProgressBar {
|
||||
}
|
||||
|
||||
// not for wasm targets
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn init_bar(len: u64) -> ProgressBar {
|
||||
let pb = ProgressBar::new(len);
|
||||
pb.set_draw_target(indicatif::ProgressDrawTarget::stdout());
|
||||
@@ -910,7 +941,6 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar {
|
||||
pb
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored_json::ToColoredJson;
|
||||
|
||||
#[derive(Debug, Clone, Tabled)]
|
||||
@@ -1010,7 +1040,6 @@ impl AccuracyResults {
|
||||
}
|
||||
|
||||
/// Calibrate the circuit parameters to a given a dataset
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(trivial_casts)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn calibrate(
|
||||
@@ -1021,7 +1050,6 @@ pub(crate) async fn calibrate(
|
||||
lookup_safety_margin: f64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, EZKLError> {
|
||||
use log::error;
|
||||
@@ -1057,12 +1085,6 @@ pub(crate) async fn calibrate(
|
||||
(11..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
@@ -1100,12 +1122,6 @@ pub(crate) async fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
@@ -1114,41 +1130,34 @@ pub(crate) async fn calibrate(
|
||||
let mut num_failed = 0;
|
||||
let mut num_passed = 0;
|
||||
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
|
||||
input_scale.to_string().blue(),
|
||||
param_scale.to_string().blue(),
|
||||
scale_rebase_multiplier.to_string().blue(),
|
||||
div_rebasing.to_string().yellow(),
|
||||
scale_rebase_multiplier.to_string().yellow(),
|
||||
num_failed.to_string().red(),
|
||||
num_passed.to_string().green()
|
||||
));
|
||||
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
// if unix get a gag
|
||||
#[cfg(unix)]
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _r = match Gag::stdout() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
};
|
||||
#[cfg(unix)]
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _g = match Gag::stderr() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
@@ -1178,7 +1187,10 @@ pub(crate) async fn calibrate(
|
||||
&mut data.clone(),
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(),
|
||||
RegionSettings::all_true(
|
||||
settings.run_args.decomp_base,
|
||||
settings.run_args.decomp_legs,
|
||||
),
|
||||
)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
@@ -1204,9 +1216,9 @@ pub(crate) async fn calibrate(
|
||||
}
|
||||
|
||||
// drop the gag
|
||||
#[cfg(unix)]
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
drop(_r);
|
||||
#[cfg(unix)]
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
drop(_g);
|
||||
|
||||
let result = forward_pass_res.get(&key).ok_or("key not found")?;
|
||||
@@ -1238,7 +1250,6 @@ pub(crate) async fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -1254,6 +1265,7 @@ pub(crate) async fn calibrate(
|
||||
num_rows: new_settings.num_rows,
|
||||
total_assignments: new_settings.total_assignments,
|
||||
total_const_size: new_settings.total_const_size,
|
||||
total_dynamic_col_size: new_settings.total_dynamic_col_size,
|
||||
..settings.clone()
|
||||
};
|
||||
|
||||
@@ -1345,7 +1357,6 @@ pub(crate) async fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -1371,9 +1382,13 @@ pub(crate) async fn calibrate(
|
||||
let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
|
||||
let module_log_row = best_params.module_constraint_logrows_with_blinding();
|
||||
let instance_logrows = best_params.log2_total_instances_with_blinding();
|
||||
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
let dynamic_lookup_logrows =
|
||||
best_params.min_dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
|
||||
let range_check_logrows = best_params.range_check_log_rows_with_blinding();
|
||||
|
||||
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
|
||||
reduction = std::cmp::max(reduction, range_check_logrows);
|
||||
reduction = std::cmp::max(reduction, instance_logrows);
|
||||
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
|
||||
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);
|
||||
@@ -1419,14 +1434,13 @@ pub(crate) fn mock(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_verifier(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> Result<String, EZKLError> {
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
@@ -1448,24 +1462,23 @@ pub(crate) async fn create_evm_verifier(
|
||||
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
|
||||
num_instance,
|
||||
);
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
let (verifier_solidity, name) = if reusable {
|
||||
(generator.render_separately()?.0, "Halo2VerifierReusable") // ignore the rendered vk artifact for now and generate it in create_evm_vka
|
||||
} else {
|
||||
generator.render()?
|
||||
(generator.render()?, "Halo2Verifier")
|
||||
};
|
||||
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, name, 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_vk(
|
||||
pub(crate) async fn create_evm_vka(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
@@ -1498,14 +1511,13 @@ pub(crate) async fn create_evm_vk(
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0).await?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_data_attestation(
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
@@ -1525,13 +1537,24 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
// if input is not provided, we just instantiate dummy input data
|
||||
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
|
||||
|
||||
// The number of input and output instances we attest to for the single call data attestation
|
||||
let mut input_len = None;
|
||||
let mut output_len = None;
|
||||
|
||||
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
|
||||
if visibility.output.is_private() {
|
||||
return Err("private output data on chain is not supported on chain".into());
|
||||
}
|
||||
let mut on_chain_output_data = vec![];
|
||||
for call in source.calls {
|
||||
on_chain_output_data.push(call);
|
||||
match source.calls {
|
||||
Calls::Multiple(calls) => {
|
||||
for call in calls {
|
||||
on_chain_output_data.push(call);
|
||||
}
|
||||
}
|
||||
Calls::Single(call) => {
|
||||
output_len = Some(call.len);
|
||||
}
|
||||
}
|
||||
Some(on_chain_output_data)
|
||||
} else {
|
||||
@@ -1543,8 +1566,15 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
return Err("private input data on chain is not supported on chain".into());
|
||||
}
|
||||
let mut on_chain_input_data = vec![];
|
||||
for call in source.calls {
|
||||
on_chain_input_data.push(call);
|
||||
match source.calls {
|
||||
Calls::Multiple(calls) => {
|
||||
for call in calls {
|
||||
on_chain_input_data.push(call);
|
||||
}
|
||||
}
|
||||
Calls::Single(call) => {
|
||||
input_len = Some(call.len);
|
||||
}
|
||||
}
|
||||
Some(on_chain_input_data)
|
||||
} else {
|
||||
@@ -1571,18 +1601,28 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
None
|
||||
};
|
||||
|
||||
let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
// if either input_len or output_len is Some then we are in the single call data attestation mode
|
||||
if input_len.is_some() || output_len.is_some() {
|
||||
let output = fix_da_single_sol(input_len, output_len)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationSingle", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
} else {
|
||||
let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationMulti", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
}
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn deploy_da_evm(
|
||||
data: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1609,15 +1649,19 @@ pub(crate) async fn deploy_da_evm(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn deploy_evm(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
contract: ContractType,
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_name = match contract {
|
||||
ContractType::Verifier { reusable: false } => "Halo2Verifier",
|
||||
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
|
||||
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
|
||||
};
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
@@ -1660,7 +1704,6 @@ pub(crate) fn encode_evm_calldata(
|
||||
Ok(encoded)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn verify_evm(
|
||||
proof_path: PathBuf,
|
||||
addr_verifier: H160Flag,
|
||||
@@ -1700,7 +1743,6 @@ pub(crate) async fn verify_evm(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_aggregate_verifier(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -1708,7 +1750,7 @@ pub(crate) async fn create_evm_aggregate_verifier(
|
||||
abi_path: PathBuf,
|
||||
circuit_settings: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
reusable: bool,
|
||||
) -> Result<String, EZKLError> {
|
||||
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
|
||||
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
|
||||
@@ -1746,8 +1788,8 @@ pub(crate) async fn create_evm_aggregate_verifier(
|
||||
|
||||
generator = generator.set_acc_encoding(Some(acc_encoding));
|
||||
|
||||
let verifier_solidity = if render_vk_seperately {
|
||||
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
|
||||
let verifier_solidity = if reusable {
|
||||
generator.render_separately()?.0 // ignore the rendered vk artifact for now and generate it in create_evm_vka
|
||||
} else {
|
||||
generator.render()?
|
||||
};
|
||||
@@ -1824,7 +1866,6 @@ pub(crate) fn setup(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn setup_test_evm_witness(
|
||||
data_path: PathBuf,
|
||||
compiled_circuit_path: PathBuf,
|
||||
@@ -1860,9 +1901,7 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::pfsys::ProofType;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn test_update_account_calls(
|
||||
addr: H160Flag,
|
||||
data: PathBuf,
|
||||
@@ -1875,7 +1914,6 @@ pub(crate) async fn test_update_account_calls(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn prove(
|
||||
data_path: PathBuf,
|
||||
@@ -2073,7 +2111,6 @@ pub(crate) fn mock_aggregate(
|
||||
}
|
||||
}
|
||||
// proof aggregation
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
@@ -2085,7 +2122,6 @@ pub(crate) fn mock_aggregate(
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("Done.");
|
||||
Ok(String::new())
|
||||
}
|
||||
@@ -2180,7 +2216,6 @@ pub(crate) fn aggregate(
|
||||
}
|
||||
|
||||
// proof aggregation
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
@@ -2330,7 +2365,6 @@ pub(crate) fn aggregate(
|
||||
);
|
||||
snark.save(&proof_path)?;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("Done.");
|
||||
|
||||
Ok(snark)
|
||||
|
||||
@@ -5,7 +5,7 @@ use halo2curves::ff::PrimeField;
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
/// Converts an integer rep to a PrimeField element.
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
@@ -69,7 +69,7 @@ mod test {
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,12 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -48,7 +54,10 @@ pub enum GraphError {
|
||||
#[error("failed to ser/deser model: {0}")]
|
||||
ModelSerialize(#[from] bincode::Error),
|
||||
/// Tract error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
#[error("[tract] {0}")]
|
||||
TractError(#[from] tract_onnx::prelude::TractError),
|
||||
/// Packing exponent is too large
|
||||
@@ -85,11 +94,17 @@ pub enum GraphError {
|
||||
#[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
|
||||
MissingBatchSize,
|
||||
/// Tokio postgres error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
#[error("[tokio postgres] {0}")]
|
||||
TokioPostgresError(#[from] tokio_postgres::Error),
|
||||
/// Eth error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
#[error("[eth] {0}")]
|
||||
EthError(#[from] crate::eth::EthError),
|
||||
/// Json error
|
||||
@@ -134,4 +149,7 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::postgres::Client;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::tensor::Tensor;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -14,18 +14,17 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
type Decimals = u8;
|
||||
@@ -128,7 +127,9 @@ impl FileSourceInner {
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
@@ -155,23 +156,84 @@ impl FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
/// Call type for attested inputs on-chain
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum Calls {
|
||||
/// Vector of calls to accounts, each returning an attested data point
|
||||
Multiple(Vec<CallsToAccount>),
|
||||
/// Single call to account, returning an array of attested data points
|
||||
Single(CallToAccount),
|
||||
}
|
||||
|
||||
impl Default for Calls {
|
||||
fn default() -> Self {
|
||||
Calls::Multiple(Vec::new())
|
||||
}
|
||||
}
|
||||
/// Inner elements of inputs/outputs coming from on-chain
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Vector of calls to accounts
|
||||
pub calls: Vec<CallsToAccount>,
|
||||
/// Calls to accounts
|
||||
pub calls: Calls,
|
||||
/// RPC url
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Create a new OnChainSource
|
||||
pub fn new(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource { calls, rpc }
|
||||
/// Create a new OnChainSource with multiple calls
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new OnChainSource with a single call
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl Serialize for Calls {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match self {
|
||||
Calls::Single(data) => data.serialize(serializer),
|
||||
Calls::Multiple(data) => data.serialize(serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for Calls {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = multiple_try {
|
||||
return Ok(Calls::Multiple(t));
|
||||
}
|
||||
let single_try: Result<CallToAccount, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = single_try {
|
||||
return Ok(Calls::Single(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom(
|
||||
"failed to deserialize FileSourceInner",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Inner elements of inputs/outputs coming from postgres DB
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
@@ -189,7 +251,7 @@ pub struct PostgresSource {
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Create a new PostgresSource
|
||||
pub fn new(
|
||||
@@ -268,7 +330,7 @@ impl PostgresSource {
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Create dummy local on-chain data to test the OnChain data source
|
||||
pub async fn test_from_file_data(
|
||||
data: &FileSource,
|
||||
@@ -277,7 +339,8 @@ impl OnChainSource {
|
||||
rpc: Option<&str>,
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
|
||||
use crate::eth::{
|
||||
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
|
||||
evm_quantize_multi, read_on_chain_inputs_multi, test_on_chain_data,
|
||||
DEFAULT_ANVIL_ENDPOINT,
|
||||
};
|
||||
use log::debug;
|
||||
|
||||
@@ -296,7 +359,7 @@ impl OnChainSource {
|
||||
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
|
||||
debug!("Calls to accounts: {:?}", calls_to_accounts);
|
||||
let inputs =
|
||||
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
|
||||
read_on_chain_inputs_multi(client.clone(), client_address, &calls_to_accounts).await?;
|
||||
debug!("Inputs: {:?}", inputs);
|
||||
|
||||
let mut quantized_evm_inputs = vec![];
|
||||
@@ -304,7 +367,7 @@ impl OnChainSource {
|
||||
let mut prev = 0;
|
||||
for (idx, i) in data.iter().enumerate() {
|
||||
quantized_evm_inputs.extend(
|
||||
evm_quantize(
|
||||
evm_quantize_multi(
|
||||
client.clone(),
|
||||
vec![scales[idx]; i.len()],
|
||||
&(
|
||||
@@ -330,7 +393,7 @@ impl OnChainSource {
|
||||
// Fill the input_data field of the GraphData struct
|
||||
Ok((
|
||||
inputs,
|
||||
OnChainSource::new(calls_to_accounts.clone(), used_rpc),
|
||||
OnChainSource::new_multiple(calls_to_accounts.clone(), used_rpc),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -350,6 +413,24 @@ pub struct CallsToAccount {
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Defines a view only call to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// The call_data is a byte strings representing the ABI encoded function call to
|
||||
/// read the data from the address. This call must return a single array of integers that can be
|
||||
/// be safely cast to the int128 type in solidity.
|
||||
pub call_data: Call,
|
||||
/// The number of decimals for f32 conversion of all of the elements returned from the
|
||||
/// call.
|
||||
pub decimals: Decimals,
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
/// The number of elements returned from the call.
|
||||
pub len: usize,
|
||||
}
|
||||
/// Enum that defines source of the inputs/outputs to the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
@@ -359,7 +440,7 @@ pub enum DataSource {
|
||||
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
||||
OnChain(OnChainSource),
|
||||
/// Postgres DB
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
|
||||
@@ -419,7 +500,7 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
if let Ok(t) = second_try {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = third_try {
|
||||
@@ -433,7 +514,7 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
pub input_data: DataSource,
|
||||
@@ -445,7 +526,7 @@ impl UnwindSafe for GraphData {}
|
||||
|
||||
impl GraphData {
|
||||
// not wasm
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Convert the input data to tract data
|
||||
pub fn to_tract_data(
|
||||
&self,
|
||||
@@ -475,6 +556,34 @@ impl GraphData {
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Convert the tract data to tract data
|
||||
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
|
||||
use tract_onnx::prelude::DatumType;
|
||||
|
||||
let mut input_data = vec![];
|
||||
for tensor in tensors {
|
||||
match tensor.datum_type() {
|
||||
tract_onnx::prelude::DatumType::Bool => {
|
||||
let tensor = tensor.to_array_view::<bool>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Bool(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
_ => {
|
||||
let cast_tensor = tensor.cast_to_dt(DatumType::F64)?;
|
||||
let tensor = cast_tensor.to_array_view::<f64>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Float(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource::File(input_data),
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
|
||||
///
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
GraphData {
|
||||
@@ -530,7 +639,7 @@ impl GraphData {
|
||||
"on-chain data cannot be split into batches".to_string(),
|
||||
))
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
GraphData {
|
||||
input_data: DataSource::DB(data),
|
||||
output_data: _,
|
||||
@@ -600,6 +709,28 @@ impl ToPyObject for CallsToAccount {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
dict.set_item("call_data", &self.call_data).unwrap();
|
||||
dict.set_item("decimals", &self.decimals).unwrap();
|
||||
dict.set_item("len", &self.len).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for Calls {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
Calls::Multiple(calls) => calls.to_object(py),
|
||||
Calls::Single(call) => call.to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for DataSource {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -608,7 +739,8 @@ impl ToPyObject for DataSource {
|
||||
DataSource::OnChain(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("rpc_url", &source.rpc).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
DataSource::DB(source) => {
|
||||
@@ -636,18 +768,6 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GraphData {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphData", 4)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.serialize_field("output_data", &self.output_data)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user