Compare commits

..

57 Commits

Author SHA1 Message Date
Ethan
41d3233223 h2 sol verifier package update. 2024-10-28 20:01:00 +08:00
Ethan
cda4b2dacd merge "main" into "verifier-router" 2024-10-28 19:30:03 +08:00
Ethan
1ddbce537f ci testing 2024-10-25 22:31:38 +08:00
Ethan Cemer
5faa56f359 Merge branch 'zkonduit:main' into verifier-router 2024-10-11 02:15:14 -05:00
Ethan
97c1397fe6 verifier manager contract. 2024-10-08 20:47:58 +08:00
Ethan
e3051b79cd *bit flip fuzzing tests 2024-10-04 17:24:57 +08:00
Ethan Cemer
fe960767df Merge branch 'main' into reusable-vk-nb 2024-09-19 20:28:15 -05:00
Ethan
d4ebf8f120 skip lstm_large in overflow b/c too big to fit on chain. 2024-09-10 13:48:36 -05:00
Ethan
6788a9c726 *comprehensive test coverage for reusable verifier 2024-09-09 17:00:42 -05:00
Ethan
9fd3799ed1 *expand reusable verifier test examples 2024-09-07 17:46:32 -05:00
Ethan
c9351c0607 *reduce extcodecopy call by 1 2024-09-05 20:10:16 -05:00
Ethan Cemer
a2a8488d67 Merge branch 'main' into reusable-vk-nb 2024-09-05 15:40:34 -05:00
Ethan
9908f5c214 *add col overflow testing for reusable verifier. 2024-09-04 21:45:29 -05:00
Ethan Cemer
66ae2f7bac Merge branch 'main' into reusable-vk-nb 2024-09-02 18:02:09 -05:00
Ethan
4d18a5a903 *rmv create_evm_vk cmd
*test reusable verifier after h2 curve updates
2024-09-02 18:01:40 -05:00
Ethan
cd1860246c main lock 2024-08-29 16:43:16 -04:00
Ethan
ad59186eb6 Merge branch 'main' into reusable-vk-nb 2024-08-29 16:41:16 -04:00
Ethan
472a505063 *update separate vk contract name 2024-08-14 13:07:37 -04:00
Ethan
1588c9735e *update lock 2024-08-13 20:22:52 -04:00
Ethan
9d876d5bf9 MV lookup packed. 2024-08-09 17:34:22 -04:00
Ethan
fd29794cdd hardcode coeff_ptr 2024-08-08 13:55:10 -04:00
Ethan
a616fbb759 packed permutation evals and challenge data. 2024-08-07 16:34:39 -04:00
Ethan
f1fe01952f Merge branch 'main' into reusable-vk-nb 2024-08-05 10:39:39 -04:00
Ethan
23f71873ce *update lock 2024-08-05 10:37:06 -04:00
Ethan
31168b0a99 *fully reusable verifier 2024-08-03 01:10:26 -05:00
Ethan
35bb286d40 *coeff_sums_computation 2024-08-01 21:11:47 -05:00
Ethan
d5944d36fe *r_evals_computation 2024-08-01 15:18:17 -05:00
Ethan
3df63d540d coeff_computations. 2024-07-30 19:18:23 -05:00
Ethan
779f82e0dc *vanish computations pcs 2024-07-29 20:35:10 -05:00
Ethan
889db3a6fe *update lock. 2024-07-29 18:16:33 -05:00
Ethan
fab08bbe9d *update cargo lock 2024-07-29 14:08:16 -05:00
Ethan
72f1892d54 *update lock 2024-07-29 06:18:31 -05:00
Ethan
f3e531c3e7 *update lock 2024-07-28 21:52:29 -05:00
Ethan
00f8dd42f6 *update lock
*revert to svm 0.8.20
2024-07-28 18:06:56 -05:00
Ethan
099726b245 update lock. 2024-07-26 22:14:41 -05:00
Ethan
d5f18495de *update lock. 2024-07-26 15:10:41 -05:00
Ethan
add04f6090 Merge branch 'main' into reusable-vk-nb 2024-07-26 15:10:03 -05:00
Ethan
6b71bdc920 use latest version of solc 2024-07-23 15:01:42 -05:00
Ethan
3e5153db9f *update lock 2024-07-23 14:59:49 -05:00
Ethan
a1dd82a3c1 *update lock 2024-07-19 18:20:26 -05:00
Ethan
f6acf241c9 Merge branch 'main' into reusable-vk-nb 2024-07-19 18:18:49 -05:00
Ethan
dbe812b88d *update lock 2024-07-12 22:14:13 -05:00
Ethan
36188ab542 *comment out JS tests for reusable verifier CI tests 2024-07-12 16:00:46 -05:00
Ethan Cemer
cd9d7c3d50 Merge branch 'main' into reusable-vk-nb 2024-07-12 15:58:16 -05:00
Ethan Cemer
f5ae49e1c5 Merge branch 'main' into reusable-vk-nb 2024-07-11 22:48:11 -05:00
Ethan
f25b420429 *update lock 2024-07-11 22:47:42 -05:00
Ethan
f59aaf80c5 *update lock 2024-07-11 15:24:28 -05:00
Ethan
257e275773 *update lock 2024-07-11 00:23:19 -05:00
Ethan
2fe0eb4b27 *update lock 2024-07-06 23:05:48 -05:00
Ethan
bdad19b83c *update lock 2024-07-03 21:00:58 -05:00
Ethan
a17aad064b *update lock 2024-07-02 21:43:38 -05:00
Ethan
985205ae40 *update lock
*hardcode sample inputs for reusable verifiers nb
2024-07-02 00:49:34 -05:00
Ethan Cemer
b08dc28ed5 Merge branch 'main' into reusable-vk-nb 2024-07-01 17:08:23 -05:00
Ethan
b3997cd325 lazy static import 2024-06-28 09:36:18 -05:00
Ethan
83cb957299 *fix stuck integration tests. 2024-06-27 19:46:05 -05:00
Ethan
c92be15b81 *update lockfile 2024-06-26 16:11:06 -05:00
Ethan
6924797e48 *reusable verifier example nb 2024-06-25 16:10:24 -05:00
119 changed files with 5124 additions and 11660 deletions

View File

@@ -2,16 +2,3 @@
runner = 'wasm-bindgen-test-runner'
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
"link-arg=--max-memory=4294967296"]
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

View File

@@ -6,15 +6,22 @@ on:
description: "Test scenario tags"
jobs:
bench_poseidon:
permissions:
contents: read
bench_elgamal:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
persist-credentials: false
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Bench elgamal
run: cargo bench --verbose --bench elgamal
bench_poseidon:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -24,14 +31,10 @@ jobs:
run: cargo bench --verbose --bench poseidon
bench_einsum_accum_matmul:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -41,14 +44,10 @@ jobs:
run: cargo bench --verbose --bench accum_einsum_matmul
bench_accum_matmul_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -58,14 +57,10 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu
bench_accum_matmul_relu_overflow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -75,14 +70,10 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu_overflow
bench_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -92,14 +83,10 @@ jobs:
run: cargo bench --verbose --bench relu
bench_accum_dot:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -109,14 +96,10 @@ jobs:
run: cargo bench --verbose --bench accum_dot
bench_accum_conv:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -126,14 +109,10 @@ jobs:
run: cargo bench --verbose --bench accum_conv
bench_accum_sumpool:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -143,14 +122,10 @@ jobs:
run: cargo bench --verbose --bench accum_sumpool
bench_pairwise_add:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -160,14 +135,10 @@ jobs:
run: cargo bench --verbose --bench pairwise_add
bench_accum_sum:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -177,14 +148,10 @@ jobs:
run: cargo bench --verbose --bench accum_sum
bench_pairwise_pow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27

View File

@@ -15,18 +15,11 @@ defaults:
working-directory: .
jobs:
publish-wasm-bindings:
permissions:
contents: read
packages: write
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -49,41 +42,41 @@ jobs:
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
echo '{
"name": "@ezkljs/engine",
"version": "${{ github.ref_name }}",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}' > pkg/package.json
- name: Replace memory definition in nodejs
run: |
@@ -191,26 +184,21 @@ jobs:
in-browser-evm-ver-publish:
permissions:
contents: read
packages: write
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
env:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)

View File

@@ -6,13 +6,9 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
permissions:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18

View File

@@ -18,19 +18,12 @@ defaults:
jobs:
linux:
permissions:
contents: read
packages: write
runs-on: GPU
strategy:
matrix:
target: [x86_64]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
@@ -41,10 +34,6 @@ jobs:
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: rename ezkl.pyi to ezkl-gpu.pyi
run: mv ezkl.pyi ezkl-gpu.pyi
- uses: actions-rs/toolchain@v1
with:
@@ -54,6 +43,8 @@ jobs:
- name: Set Cargo.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
@@ -79,7 +70,7 @@ jobs:
pip install ezkl-gpu --no-index --find-links dist --force-reinstall
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: wheels
path: dist
@@ -107,14 +98,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./

View File

@@ -16,32 +16,22 @@ defaults:
jobs:
macos:
permissions:
contents: read
runs-on: macos-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64, universal2-apple-darwin]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
@@ -55,13 +45,6 @@ jobs:
components: rustfmt, clippy
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
@@ -73,14 +56,12 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: dist-macos-${{ matrix.target }}
name: wheels
path: dist
windows:
permissions:
contents: read
runs-on: windows-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -88,21 +69,11 @@ jobs:
target: [x64, x86]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -130,14 +101,12 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: dist-windows-${{ matrix.target }}
name: wheels
path: dist
linux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -145,21 +114,11 @@ jobs:
target: [x86_64]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -170,6 +129,7 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -203,14 +163,63 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: dist-linux-${{ matrix.target }}
name: wheels
path: dist
# TODO: There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# linux-cross:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# target: [aarch64, armv7]
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
# - name: Install cross-compilation tools for armv7
# if: matrix.target == 'armv7'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
# - name: Build wheels
# uses: PyO3/maturin-action@v1
# with:
# target: ${{ matrix.target }}
# manylinux: auto
# args: --release --out dist --features python-bindings
# - uses: uraimo/run-on-arch-action@v2.5.0
# name: Install built wheel
# with:
# arch: ${{ matrix.target }}
# distro: ubuntu20.04
# githubToken: ${{ github.token }}
# install: |
# apt-get update
# apt-get install -y --no-install-recommends python3 python3-pip
# pip3 install -U pip
# run: |
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
# python3 -c "import ezkl"
# - name: Upload wheels
# uses: actions/upload-artifact@v3
# with:
# name: wheels
# path: dist
musllinux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -219,21 +228,11 @@ jobs:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -271,14 +270,12 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: dist-musllinux-${{ matrix.target }}
name: wheels
path: dist
musllinux-cross:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -286,22 +283,14 @@ jobs:
platform:
- target: aarch64-unknown-linux-musl
arch: aarch64
- target: armv7-unknown-linux-musleabihf
arch: armv7
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -319,7 +308,7 @@ jobs:
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@v2.8.1
- uses: uraimo/run-on-arch-action@v2.5.0
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}
@@ -334,9 +323,9 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: dist-musllinux-${{ matrix.platform.target }}
name: wheels
path: dist
pypi-publish:
@@ -345,43 +334,44 @@ jobs:
permissions:
id-token: write
if: "startsWith(github.ref, 'refs/tags/')"
# TODO: Uncomment if linux-cross is working
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@v3
with:
pattern: dist-*
merge-multiple: true
path: dist
name: wheels
- name: List Files
run: ls -R
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@unstable/v1
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# Both publish steps will fail if there is no trusted publisher setup
# On failure the publish step will then simply continue to the next one
# publishes to PyPI
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@unstable/v1
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
permissions:
contents: read
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}
commit_ref: ${{ github.ref_name }}

View File

@@ -10,9 +10,6 @@ on:
- "*"
jobs:
create-release:
permissions:
contents: read
packages: write
name: create-release
runs-on: ubuntu-22.04
if: startsWith(github.ref, 'refs/tags/')
@@ -36,9 +33,6 @@ jobs:
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
permissions:
contents: read
packages: write
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -56,9 +50,6 @@ jobs:
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -100,10 +91,6 @@ jobs:
asset_content_type: application/octet-stream
build-release:
permissions:
contents: read
packages: write
issues: write
name: build-release
needs: ["create-release"]
runs-on: ${{ matrix.os }}
@@ -145,8 +132,6 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -196,18 +181,14 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
- name: Build release binary (no asm)
if: matrix.build != 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"

View File

@@ -19,35 +19,11 @@ env:
CARGO_TERM_COLOR: always
jobs:
fr-age-test:
needs: [build, library-tests, docs, python-tests, python-integration-tests]
permissions:
contents: read
runs-on: large-self-hosted
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: fr age Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
build:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -57,13 +33,9 @@ jobs:
run: cargo build --verbose
docs:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -73,13 +45,9 @@ jobs:
run: cargo doc --verbose
library-tests:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -103,8 +71,6 @@ jobs:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
@@ -135,13 +101,9 @@ jobs:
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
ultra-overflow-tests_og-lookup:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -163,22 +125,18 @@ jobs:
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
ultra-overflow-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -200,22 +158,18 @@ jobs:
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
run: cargo nextest run --release lookup_ultra_overflow --no-capture -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture -- --include-ignored
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
model-serialization:
permissions:
contents: read
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -229,13 +183,9 @@ jobs:
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
wasm32-tests:
permissions:
contents: read
runs-on: non-gpu
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -244,7 +194,7 @@ jobs:
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
version: 'v0.12.1'
- uses: nanasess/setup-chromedriver@v2
# with:
# chromedriver-version: "115.0.5790.102"
@@ -257,14 +207,28 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
mock-proving-tests:
permissions:
contents: read
runs-on: non-gpu
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
persist-credentials: false
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Circuit Render
run: cargo nextest run --release --verbose tests::tutorial_
mock-proving-tests:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -275,63 +239,57 @@ jobs:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
- name: Self Attention Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
- name: Multihead Attention Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
- name: public outputs
run: cargo nextest run --verbose tests::mock_public_outputs_ --test-threads 32
- name: public inputs
run: cargo nextest run --verbose tests::mock_public_inputs_ --test-threads 32
- name: fixed params
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
- name: public outputs and bounded lookup log
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and tolerance > 0
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --verbose tests::mock_kzg_input_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
run: cargo nextest run --verbose tests::mock_kzg_params_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_kzg_params_::t --test-threads 32
- name: kzg outputs
run: cargo nextest run --verbose tests::mock_kzg_output_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_kzg_output_::t --test-threads 32
- name: kzg inputs + params + outputs
run: cargo nextest run --verbose tests::mock_kzg_all_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_kzg_all_::t --test-threads 32
- name: Mock fixed inputs
run: cargo nextest run --verbose tests::mock_fixed_inputs_ --test-threads 32
run: cargo nextest run --release --verbose tests::mock_fixed_inputs_ --test-threads 32
- name: Mock fixed outputs
run: cargo nextest run --verbose tests::mock_fixed_outputs --test-threads 32
run: cargo nextest run --release --verbose tests::mock_fixed_outputs --test-threads 32
- name: Mock accuracy calibration
run: cargo nextest run --verbose tests::mock_accuracy_cal_tests::a
run: cargo nextest run --release --verbose tests::mock_accuracy_cal_tests::a
- name: hashed inputs
run: cargo nextest run --verbose tests::mock_hashed_input_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --verbose tests::mock_hashed_params_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --verbose tests::mock_hashed_output_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
run: cargo nextest run --verbose tests::mock_hashed_all_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_all_::t --test-threads 32
- name: hashed inputs + fixed params
run: cargo nextest run --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
run: cargo nextest run --release --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
- name: MNIST Gan Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
- name: Self Attention Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
- name: Multihead Attention Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
- name: public outputs
run: cargo nextest run --release --verbose tests::mock_public_outputs_ --test-threads 32
- name: public inputs
run: cargo nextest run --release --verbose tests::mock_public_inputs_ --test-threads 32
- name: fixed params
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
prove-and-verify-evm-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -342,8 +300,6 @@ jobs:
crate: cargo-nextest
locked: true
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -364,7 +320,7 @@ jobs:
NODE_ENV: development
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
@@ -378,79 +334,41 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg params)
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed inputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed params)
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed outputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
# prove-and-verify-tests-metal:
# permissions:
# contents: read
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@v0.4.0
# with:
# # Pin to version 0.12.1
# version: 'v0.12.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18
# - uses: actions/checkout@v3
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@v2
# with:
# version: 8
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
prove-and-verify-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -459,15 +377,13 @@ jobs:
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -489,40 +405,40 @@ jobs:
locked: true
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (hashed inputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
- name: IPA prove and verify tests
run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests single inner col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
- name: KZG prove and verify tests triple inner col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_triple_col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_triple_col
- name: KZG prove and verify tests quadruple inner col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_quadruple_col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_quadruple_col
- name: KZG prove and verify tests octuple inner col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
# prove-and-verify-tests-gpu:
# runs-on: GPU
@@ -530,8 +446,6 @@ jobs:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
@@ -545,31 +459,27 @@ jobs:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (kzg outputs)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public inputs)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# - name: KZG prove and verify tests (fixed params)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# - name: KZG prove and verify tests (hashed outputs)
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
prove-and-verify-mock-aggr-tests:
permissions:
contents: read
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -580,37 +490,31 @@ jobs:
crate: cargo-nextest
locked: true
- name: Mock aggr tests (KZG)
run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
# prove-and-verify-aggr-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG tests
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -621,17 +525,13 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -646,17 +546,13 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
examples:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -670,14 +566,10 @@ jobs:
run: cargo nextest run --release tests_examples
python-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
@@ -695,19 +587,15 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
accuracy-measurement-tests:
permissions:
contents: read
runs-on: non-gpu
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
@@ -723,19 +611,19 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_inputs_
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_fixed_params_
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_fixed_params_
- name: Public outputs
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_outputs_
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_outputs_
- name: Public outputs + resources
run: source .env/bin/activate; cargo nextest run --verbose tests::resources_accuracy_measurement_public_outputs_
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
python-integration-tests:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
@@ -757,8 +645,6 @@ jobs:
- 5432:5432
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.11"
@@ -780,15 +666,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: Neural bow
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
@@ -802,87 +680,74 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Reusable verifier tutorial
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
ios-integration-tests:
permissions:
contents: read
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
permissions:
contents: read
runs-on: macos-latest
needs: [ios-integration-tests]
runs-on: macos-latest
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

View File

@@ -1,33 +0,0 @@
name: Static Analysis
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
analyze:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
# Run Zizmor static analysis
- name: Install Zizmor
run: cargo install --locked zizmor
- name: Run Zizmor Analysis
run: zizmor .

View File

@@ -1,134 +0,0 @@
name: Build and Publish EZKL iOS SPM package
on:
push:
tags:
# Only support SemVer versioning tags
- 'v[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+'
jobs:
build-and-update:
permissions:
contents: read
packages: write
runs-on: macos-latest
env:
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
RELEASE_TAG: ${{ github.ref_name }}
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
run: |
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
TAG="${RELEASE_TAG}"
echo "Original TAG: $TAG"
# Remove leading 'v' if present to match the Swift Package Manager version format.
NEW_TAG=${TAG#v}
echo "Stripped TAG: $NEW_TAG"
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Check for changes
id: check_changes
run: |
cd ezkl-swift-package
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
echo "no_changes=true" >> $GITHUB_OUTPUT
else
echo "no_changes=false" >> $GITHUB_OUTPUT
fi
- name: Set up Xcode environment
if: steps.check_changes.outputs.no_changes == 'false'
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Setup Git
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
env:
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
- name: Commit and Push Changes
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
git add Sources/EzklCoreBindings Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
if ! git push origin; then
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
exit 1
fi
- name: Tag the latest commit
run: |
cd ezkl-swift-package
source $GITHUB_ENV
# Tag the latest commit on the current branch
if git rev-parse "$TAG" >/dev/null 2>&1; then
echo "Tag $TAG already exists locally. Skipping tag creation."
else
git tag "$TAG"
fi
if ! git push origin "$TAG"; then
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
exit 1
fi

View File

@@ -12,8 +12,6 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2

View File

@@ -0,0 +1,75 @@
name: Build and Publish EZKL iOS SPM package
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
jobs:
build-and-update:
runs-on: macos-latest
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Commit and Push Changes to feat/ezkl-direct-integration
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git add Sources/EzklCoreBindings
git commit -m "Automatically updated EzklCoreBindings for EZKL"
git tag ${{ github.event.inputs.tag }}
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
git push origin
git push origin --tags
env:
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

.gitignore vendored (4 changed lines)
View File

@@ -9,6 +9,7 @@ pkg
!AttestData.sol
!VerifierBase.sol
!LoadInstances.sol
!VerifierManager.sol
*.pf
*.vk
*.pk
@@ -27,6 +28,7 @@ __pycache__/
*.pyc
*.pyo
*.py[cod]
bin/
build/
develop-eggs/
dist/
@@ -48,4 +50,4 @@ timingData.json
!tests/assets/pk.key
!tests/assets/vk.key
docs/python/build
!tests/assets/vk_aggr.key
!tests/assets/vk_aggr.key

Cargo.lock generated (839 changed lines)

File diff suppressed because it is too large

View File

@@ -16,11 +16,12 @@ crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
"derive_serde",
] }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = [
"circuit-params",
] }
rand = { version = "0.8", default-features = false }
@@ -35,11 +36,12 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "vka-log", optional = true }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
@@ -73,25 +75,26 @@ tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
@@ -144,10 +147,6 @@ shellexpand = "3.1.0"
runner = 'wasm-bindgen-test-runner'
[[bench]]
name = "zero_finder"
harness = false
[[bench]]
name = "accum_dot"
harness = false
@@ -186,7 +185,7 @@ harness = false
[[bench]]
name = "sigmoid"
name = "relu"
harness = false
[[bench]]
@@ -194,12 +193,12 @@ name = "relu_lookupless"
harness = false
[[bench]]
name = "accum_matmul_sigmoid"
name = "accum_matmul_relu"
harness = false
[[bench]]
name = "accum_matmul_sigmoid_overflow"
name = "accum_matmul_relu_overflow"
harness = false
[[bin]]
@@ -212,10 +211,6 @@ required-features = ["ezkl"]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
[[bin]]
name = "py_stub_gen"
required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = [
@@ -226,7 +221,7 @@ default = [
"parallel-poly-read",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
@@ -242,14 +237,16 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:openssl",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:regex",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:portable-atomic",
"dep:clap_complete",
"dep:halo2_solidity_verifier",
"dep:semver",
@@ -272,15 +269,13 @@ icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -289,16 +284,4 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
#panic = "abort"
[profile.test-runs]
inherits = "dev"
opt-level = 3
[package.metadata.wasm-pack.profile.release]
wasm-opt = [
"-O4",
"--flexible-inline-max-function-size",
"4294967295",
]
# panic = "abort"

View File

@@ -150,13 +150,6 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
### Advanced security topics
Check out `docs/advanced_security` for more advanced information on potential threat vectors.
### no warranty
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,

View File

@@ -3,64 +3,97 @@
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"name": "owner",
"type": "address"
}
],
"name": "OwnableInvalidOwner",
"type": "error"
},
{
"inputs": [
{
"internalType": "address",
"name": "account",
"type": "address"
}
],
"name": "OwnableUnauthorizedAccount",
"type": "error"
},
{
"anonymous": false,
"inputs": [
{
"indexed": false,
"internalType": "address",
"name": "addr",
"type": "address"
}
],
"name": "DeployedVerifier",
"type": "event"
},
{
"anonymous": false,
"inputs": [
{
"indexed": true,
"internalType": "address",
"name": "previousOwner",
"type": "address"
},
{
"indexed": true,
"internalType": "address",
"name": "newOwner",
"type": "address"
}
],
"name": "OwnershipTransferred",
"type": "event"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "_callData",
"name": "bytecode",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
}
],
"name": "deployVerifier",
"outputs": [
{
"internalType": "address",
"name": "_admin",
"name": "addr",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
"type": "function"
},
{
"inputs": [],
"name": "accountCall",
"name": "owner",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"name": "",
"type": "address"
},
{
"internalType": "bytes",
"name": "callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"inputs": [
{
"internalType": "bytes",
"name": "bytecode",
"type": "bytes"
}
],
"name": "precomputeAddress",
"outputs": [
{
"internalType": "address",
@@ -73,67 +106,33 @@
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
"name": "renounceOwnership",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint8",
"internalType": "address",
"name": "newOwner",
"type": "address"
}
],
"name": "transferOwnership",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"name": "verifierAddresses",
"outputs": [
{
"internalType": "bool",

View File

@@ -1,23 +1,4 @@
[
{
"inputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"name": "check_is_valid_field_element",
"outputs": [
{
"internalType": "uint256[]",
"name": "output",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
@@ -36,41 +17,12 @@
"type": "uint256[]"
}
],
"name": "quantize_data_multi",
"name": "quantize_data",
"outputs": [
{
"internalType": "int256[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "scales",
"type": "uint256[]"
}
],
"name": "quantize_data_single",
"outputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
"type": "int64[]"
}
],
"stateMutability": "pure",

View File

@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
K,
&LookupOp::Sigmoid { scale: 1.0.into() },
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
@@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())

View File

@@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
k,
&LookupOp::Sigmoid { scale: 1.0.into() },
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
@@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())

View File

@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let mut config = Config::default();
@@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())

View File

@@ -68,14 +68,7 @@ impl Circuit<Fr> for NLCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.unwrap();
Ok(())
},

View File

@@ -1,117 +0,0 @@
use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2curves::{bn256::Fr as F, ff::Field};
use maybe_rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
slice::ParallelSlice,
};
use rand::Rng;
// Assuming these are your types
#[derive(Clone)]
#[allow(dead_code)]
enum ValType {
Constant(F),
AssignedConstant(usize, F),
Other,
}
// Helper to generate test data
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value
}
})
.collect()
}
fn bench_zero_finding(c: &mut Criterion) {
let sizes = [
1_000, // 1K
10_000, // 10K
100_000, // 100K
256 * 256 * 2, // Our specific case
1_000_000, // 1M
10_000_000, // 10M
];
let zero_probability = 0.1; // 10% zeros
let mut group = c.benchmark_group("zero_finding");
group.sample_size(10); // Adjust based on your needs
for &size in &sizes {
let data = generate_test_data(size, zero_probability);
// Benchmark sequential version
group.bench_function(format!("sequential_{}", size), |b| {
b.iter(|| {
let result = data
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark parallel version
group.bench_function(format!("parallel_{}", size), |b| {
b.iter(|| {
let result = data
.par_iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark chunked parallel version
group.bench_function(format!("chunked_parallel_{}", size), |b| {
b.iter(|| {
let num_cores = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (size / num_cores).max(100);
let result = data
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>();
black_box(result)
})
});
}
group.finish();
}
criterion_group!(benches, bench_zero_finding);
criterion_main!(benches);

View File

@@ -163,253 +163,6 @@ contract SwapProofCommitments {
} /// end checkKzgCommits
}
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
bytes callData;
uint256 decimals;
}
AccountCall public accountCall;
uint[] scales;
address public admin;
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_LEN = 0;
uint256 constant OUTPUT_LEN = 0;
uint8 public instanceOffset;
/**
* @dev Initialize the contract with account calls the EZKL model will read from.
* @param _contractAddresses - The address of the contract that EZKL reads storage from.
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
*/
constructor(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) internal {
AccountCall memory _accountCall = accountCall;
_accountCall.contractAddress = _contractAddresses;
_accountCall.callData = _callData;
_accountCall.decimals = 10 ** _decimals;
accountCall = _accountCall;
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
if (prod1 == 0) {
return prod0 / denominator;
}
require(denominator > prod1, "Math: mulDiv overflow");
uint256 remainder;
assembly {
remainder := mulmod(x, y, denominator)
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
return result;
}
}
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param x - One of the elements of the data returned from the account calls
* @param _decimals - Number of base 10 decimals to scale the data by.
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
*
*/
function quantizeData(
int x,
uint256 _decimals,
uint256 _scale
) internal pure returns (int256 quantized_data) {
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), _scale, _decimals);
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
output += 1;
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
* @param target - The address of the account to make calls to.
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
*/
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
require(
target.code.length > 0,
"Address: call to non-contract"
);
}
return returndata;
} else {
revert("Address: low-level call failed");
}
}
/**
* @dev Convert the fixed point quantized data into a field element.
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges from -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
}
/**
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that is publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_LEN + OUTPUT_LEN,
"Invalid public inputs length"
);
AccountCall memory _accountCall = accountCall;
uint[] memory _scales = scales;
bytes memory returnData = staticCall(
_accountCall.contractAddress,
_accountCall.callData
);
int256[] memory x = abi.decode(returnData, (int256[]));
uint _offset;
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
uint field_element = toFieldElement(output);
for (uint i = 0; i < x.length; i++) {
if (field_element != instances[i + instanceOffset]) {
_offset += 1;
} else {
break;
}
}
uint length = x.length - _offset;
for (uint i = 1; i < length; i++) {
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
field_element = toFieldElement(output);
require(
field_element == instances[i + instanceOffset + _offset],
"Public input does not match"
);
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}
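For reference, the quantization and field-conversion logic being removed above is compact enough to mirror off-chain. A minimal Python sketch under stated assumptions: the contract stores `decimals` already expanded to `10 ** _decimals` and each scale as `1 << scale`, so the expanded values are passed here, and Python's unbounded integers stand in for the 512-bit `mulDiv`:

```python
# Sketch of the removed quantizeData / toFieldElement logic.
# Assumption: `decimals` and `scale` arrive pre-expanded (10**d and 1<<s),
# matching what the contract stores in accountCall.decimals and scales[].

ORDER = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001

def quantize_data(x: int, decimals: int, scale: int) -> int:
    """Round-half-up fixed-point quantization: x * scale / decimals."""
    neg = x < 0
    if neg:
        x = -x
    out = x * scale // decimals                   # mulDiv(x, scale, decimals)
    if (x * scale % decimals) * 2 >= decimals:    # mirrors the mulmod check
        out += 1
    return -out if neg else out

def to_field_element(x: int) -> int:
    """Map a signed quantized value into the scalar field [0, ORDER)."""
    return (x + ORDER) % ORDER

# Example: 1.5 with 18 on-chain decimals at base-2 scale 2**7
q = quantize_data(1_500_000_000_000_000_000, 10**18, 1 << 7)
assert q == 192                         # 1.5 * 128
assert to_field_element(-q) == ORDER - 192
```

The round-half-up branch mirrors the contract's `mulmod` check, and the `(x + ORDER) % ORDER` mapping is what lets signed quantized values be compared against field-element instances.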
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
@@ -420,11 +173,11 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
contract DataAttestation is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to

View File

@@ -0,0 +1,184 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.20;
// lib/openzeppelin-contracts/contracts/utils/Context.sol
// OpenZeppelin Contracts (last updated v5.0.1) (utils/Context.sol)
/**
* @dev Provides information about the current execution context, including the
* sender of the transaction and its data. While these are generally available
* via msg.sender and msg.data, they should not be accessed in such a direct
* manner, since when dealing with meta-transactions the account sending and
* paying for execution may not be the actual sender (as far as an application
* is concerned).
*
* This contract is only required for intermediate, library-like contracts.
*/
abstract contract Context {
function _msgSender() internal view virtual returns (address) {
return msg.sender;
}
function _msgData() internal view virtual returns (bytes calldata) {
return msg.data;
}
function _contextSuffixLength() internal view virtual returns (uint256) {
return 0;
}
}
// lib/openzeppelin-contracts/contracts/access/Ownable.sol
// OpenZeppelin Contracts (last updated v5.0.0) (access/Ownable.sol)
/**
* @dev Contract module which provides a basic access control mechanism, where
* there is an account (an owner) that can be granted exclusive access to
* specific functions.
*
* The initial owner is set to the address provided by the deployer. This can
* later be changed with {transferOwnership}.
*
* This module is used through inheritance. It will make available the modifier
* `onlyOwner`, which can be applied to your functions to restrict their use to
* the owner.
*/
abstract contract Ownable is Context {
/// set the owner initially to be the anvil test account
address private _owner = 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266;
/**
* @dev The caller account is not authorized to perform an operation.
*/
error OwnableUnauthorizedAccount(address account);
/**
* @dev The owner is not a valid owner account. (eg. `address(0)`)
*/
error OwnableInvalidOwner(address owner);
event OwnershipTransferred(
address indexed previousOwner,
address indexed newOwner
);
/**
* @dev Initializes the contract setting the address provided by the deployer as the initial owner.
*/
constructor() {
_transferOwnership(msg.sender);
}
/**
* @dev Throws if called by any account other than the owner.
*/
modifier onlyOwner() {
_checkOwner();
_;
}
/**
* @dev Returns the address of the current owner.
*/
function owner() public view virtual returns (address) {
return _owner;
}
/**
* @dev Throws if the sender is not the owner.
*/
function _checkOwner() internal view virtual {
if (owner() != _msgSender()) {
revert OwnableUnauthorizedAccount(_msgSender());
}
}
/**
* @dev Leaves the contract without owner. It will not be possible to call
* `onlyOwner` functions. Can only be called by the current owner.
*
* NOTE: Renouncing ownership will leave the contract without an owner,
* thereby disabling any functionality that is only available to the owner.
*/
function renounceOwnership() public virtual onlyOwner {
_transferOwnership(address(0));
}
/**
* @dev Transfers ownership of the contract to a new account (`newOwner`).
* Can only be called by the current owner.
*/
function transferOwnership(address newOwner) public virtual onlyOwner {
if (newOwner == address(0)) {
revert OwnableInvalidOwner(address(0));
}
_transferOwnership(newOwner);
}
/**
* @dev Transfers ownership of the contract to a new account (`newOwner`).
* Internal function without access restriction.
*/
function _transferOwnership(address newOwner) internal virtual {
address oldOwner = _owner;
_owner = newOwner;
emit OwnershipTransferred(oldOwner, newOwner);
}
}
// interface for the reusable verifier.
interface Halo2VerifierReusable {
function verifyProof(
address vkArtifact,
bytes calldata proof,
uint256[] calldata instances
) external returns (bool);
}
// Manages the deployment of all EZKL reusable verifiers (ezkl version specific) and verifying key artifacts (circuit specific), and
// routes proof verifications to the correct VKA and associated reusable verifier.
// Helps to prevent the deployment of duplicate verifiers.
contract EZKLVerifierManager is Ownable {
/// @dev Mapping that checks if a given reusable verifier has been deployed
mapping(address => bool) public verifierAddresses;
event DeployedVerifier(address addr);
// 1. Compute the address of the verifier to be deployed
function precomputeAddress(
bytes memory bytecode
) public view returns (address) {
bytes32 hash = keccak256(
abi.encodePacked(
bytes1(0xff),
address(this),
uint(0),
keccak256(bytecode)
)
);
return address(uint160(uint(hash)));
}
// 2. Deploy the reusable verifier using create2
/// @param bytecode The bytecode of the reusable verifier to deploy
function deployVerifier(
bytes memory bytecode
) public returns (address addr) {
assembly {
addr := create2(
0x0, // value, hardcoded to 0
add(bytecode, 0x20),
mload(bytecode),
0x0 // salt, hardcoded to 0
)
if iszero(extcodesize(addr)) {
revert(0, 0)
}
}
verifierAddresses[addr] = true;
emit DeployedVerifier(addr);
}
}
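Since `precomputeAddress` is the standard CREATE2 formula with the salt hardcoded to zero, the expected verifier address can be reproduced off-chain before calling `deployVerifier`. A hypothetical Python sketch (the helper name and addresses are illustrative only; it assumes the `pycryptodome` package, since `hashlib` provides sha3-256 but not keccak-256):

```python
# Off-chain mirror of EZKLVerifierManager.precomputeAddress:
# keccak256(0xff ++ deployer ++ salt ++ keccak256(bytecode))[12:]
from Crypto.Hash import keccak  # pip install pycryptodome

def keccak256(data: bytes) -> bytes:
    return keccak.new(digest_bits=256, data=data).digest()

def precompute_address(deployer: str, bytecode: bytes, salt: int = 0) -> str:
    preimage = (
        b"\xff"
        + bytes.fromhex(deployer.removeprefix("0x"))
        + salt.to_bytes(32, "big")          # the contract hardcodes salt = 0
        + keccak256(bytecode)
    )
    return "0x" + keccak256(preimage)[12:].hex()  # low 20 bytes

# Usage (hypothetical manager address and bytecode):
# addr = precompute_address("0x" + "11" * 20, verifier_bytecode)
```

Because the deployer and salt are fixed, identical bytecode always maps to the same address; the `extcodesize` check in `deployVerifier` then reverts any attempt to deploy the same verifier twice.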

View File

@@ -1,41 +0,0 @@
## EZKL Security Note: Public Commitments and Low-Entropy Data
> **Disclaimer:** this is a more technical post that requires some prior knowledge of how ZK proving systems like Halo2 operate, and in particular how their APIs are constructed. For background reading we highly recommend the [Halo2 book](https://zcash.github.io/halo2/) and [Halo2 Club](https://halo2.club/).
## Overview of commitments in EZKL
A common design pattern in a zero knowledge (zk) application is thus:
- A prover has some data which is used within a circuit.
- This data, as it may be high-dimensional or somewhat private, is pre-committed to using some hash function.
- The zk-circuit which forms the core of the application then proves (paraphrasing) a statement of the form:
>"I know some data D which, when hashed, corresponds to the pre-committed value H + whatever else the circuit is proving over D".
From our own experience, we've implemented such patterns using snark-friendly hash functions like [Poseidon](https://www.poseidon-hash.info/), for which there is a relatively well vetted [implementation](https://docs.rs/halo2_gadgets/latest/halo2_gadgets/poseidon/index.html) in Halo2. Even then these hash functions can introduce lots of overhead and can be very expensive to generate proofs for if the dimensionality of the data D is large.
You can also implement such a pattern using Halo2's `Fixed` columns _if the privacy preservation of the pre-image is not necessary_. These are Halo2 columns (i.e. in reality just polynomials) that are left unblinded (unlike the blinded `Advice` columns), and whose commitments are shared with the verifier by way of the verifying key for the application's zk-circuit. These commitments are much lower cost to generate than implementing a hashing function, such as Poseidon, within a circuit.
> **Note:** Blinding is the process whereby a certain set of the final elements (i.e. rows) of a Halo2 column are set to random field elements. This is the mechanism by which Halo2 achieves its zero knowledge properties for `Advice` columns. By contrast, `Fixed` columns aren't zero-knowledge in that they are vulnerable to dictionary attacks in the same manner a hash function is. Given some set of known or popular data D, an attacker can attempt to recover the pre-image of a hash by running D through the hash function to see if the outputs match a public commitment. These attacks aren't "possible" on blinded `Advice` columns.
> **Further Note:** Without blinding, an attacker with access to M proofs, each of which contains an evaluation of the polynomial at a different point, can more easily recover a non-blinded column's pre-image. This is because each proof generates a new query and evaluation of the polynomial represented by the column, and as such with repetition a clearer picture can emerge of the column's pre-image. Thus unblinded columns should only be used for privacy preservation, in the manner of a hash, if the number of proofs generated against a fixed set of values is limited. More formally, if M independent and _unique_ queries are generated, and M is equal to the degree + 1 of the polynomial represented by the column (i.e. the unique Lagrange interpolation of the values in the column), then the column's pre-image can be recovered. As such, as the number of logrows K increases, more queries are required to recover the pre-image (as 2^K unique queries are required). This assumes that the entries in the column are not structured; if they are, the number of queries required to recover the pre-image is reduced (e.g. if all rows above a certain point are known to be nil).
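To make the interpolation claim concrete, here is a toy Python sketch under stated assumptions: a tiny prime field and made-up column values (real Halo2 fields are ~2^254, and query points come from verifier challenges). With deg + 1 leaked evaluations, the entire column is pinned down.

```python
# Toy illustration: once deg+1 distinct evaluations of an unblinded column
# polynomial have leaked (one per proof), Lagrange interpolation recovers
# every entry of the column.
P = 97                       # toy prime field; real fields are ~2^254
column = [12, 45, 7, 88]     # secret column entries = poly values on rows 0..3

def lagrange_eval(points, x):
    """Evaluate the unique polynomial through `points` at x, mod P."""
    total = 0
    for xi, yi in points:
        num, den = 1, 1
        for xj, _ in points:
            if xj != xi:
                num = num * (x - xj) % P
                den = den * (xi - xj) % P
        total = (total + yi * num * pow(den, -1, P)) % P
    return total

secret = list(enumerate(column))            # the prover's view
queries = [10, 23, 51, 77]                  # 4 = deg+1 unique query points
leaked = [(q, lagrange_eval(secret, q)) for q in queries]  # attacker's view

# Interpolating the leaked evaluations back onto rows 0..3 recovers the column:
assert [lagrange_eval(leaked, i) for i in range(4)] == column
```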
The annoyance in using `Fixed` columns comes from the fact that they require generating a new verifying key every time a new set of commitments is generated.
> **Example:** Say for instance an application leverages a zero-knowledge circuit to prove the correct execution of a neural network. Every week the neural network is finetuned or retrained on new data. If the architecture remains the same then committing to the new network parameters, along with a new proof of performance on a test set, would be an ideal setup. If we leverage `Fixed` columns to commit to the model parameters, each new commitment will require regenerating a verifying key and sharing the new key with the verifier(s). This is not ideal UX and can become expensive if the verifier is deployed on-chain.
An ideal commitment would thus have the low cost of a `Fixed` column but wouldn't require a new verifying key for each new commitment.
### Unblinded Advice Columns
A first step in designing such a commitment is to allow for optionally unblinded `Advice` columns within the Halo2 API. These won't be included in the verifying key, AND are blinded with a constant factor `1` -- such that someone who knows the pre-image to the commitment can reproduce it by running the pre-image through the corresponding polynomial commitment scheme (in ezkl's case [KZG commitments](https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html)).
This is implemented using the `polycommit` visibility parameter in the ezkl API.
## The Vulnerability of Public Commitments
Public commitments in EZKL (both Poseidon-hashed inputs and KZG commitments) can be vulnerable to brute-force attacks when input data has low entropy. A malicious actor could reveal committed data by searching through possible input values, compromising privacy in applications like anonymous credentials. This is particularly relevant when input data comes from known finite sets (e.g., names, dates).
Example Risk: In an anonymous credential system using EZKL for ID verification, an attacker could match hashed outputs against a database of common identifying information to deanonymize users.
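The risk above is the classic dictionary attack. A minimal sketch of the idea, using SHA-256 purely as a stand-in for a deterministic Poseidon/KZG-style commitment (an assumption for illustration, not EZKL's actual commitment):

```python
# Dictionary attack on a deterministic commitment over low-entropy data.
import hashlib

def commit(value: str) -> str:
    # Stand-in for Poseidon / an unblinded KZG commitment (illustrative only).
    return hashlib.sha256(value.encode()).hexdigest()

public_commitment = commit("1987-03-14")      # e.g. a committed birth date

# The attacker enumerates the (small) candidate space and re-commits each one:
candidates = (f"{y}-{m:02d}-{d:02d}"
              for y in range(1940, 2010)
              for m in range(1, 13)
              for d in range(1, 29))
match = next(c for c in candidates if commit(c) == public_commitment)
assert match == "1987-03-14"                  # pre-image recovered
```

The defense is to raise the entropy of the committed pre-image, e.g. by mixing a high-entropy salt or blinding factor into the commitment alongside the low-entropy data.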

View File

@@ -1,22 +0,0 @@
# EZKL Security Note: Quantization-Induced Model Backdoors
> Note: this only affects a situation where a party separate from an application's developer has access to the model's weights and can modify them. This is a common scenario in adversarial machine learning research, but can be less common in real-world applications. If you're building your models in house and deploying them yourself, this is less of a concern. If you're building a permissionless system where anyone can submit models, this is more of a concern.
Models processed through EZKL's quantization step can harbor backdoors that are dormant in the original full-precision model but activate during quantization. These backdoors force specific outputs when triggered, with impact varying by application.
Key Factors:
- Larger models increase attack feasibility through more parameter capacity
- Smaller quantization scales facilitate attacks by allowing greater weight modifications
- Rebase ratio of 1 enables exploitation of convolutional layer consistency
Limitations:
- Attack effectiveness depends on calibration settings and internal rescaling operations.
- Further research needed on backdoor persistence through witness/proof stages.
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model.
References:
1. [Quantization Backdoors to Deep Learning Commercial Frameworks (Ma et al., 2021)](https://arxiv.org/abs/2108.09187)
2. [Planting Undetectable Backdoors in Machine Learning Models (Goldwasser et al., 2022)](https://arxiv.org/abs/2204.06974)
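To see the mechanism numerically, consider a toy sketch (illustrative only: a one-weight "model" and a 2-fractional-bit scale, both made up). A 0.01 weight edit that is harmless in full precision snaps to a whole quantization step after rounding, which is exactly the headroom the cited attacks exploit:

```python
# Toy illustration: a weight edit that is invisible in float but flips the
# quantized decision (cf. quantization backdoors, Ma et al. 2021).
def quantize(v: float, scale: int = 4) -> float:
    """Fixed-point quantization at 2 fractional bits (step = 1/scale)."""
    return round(v * scale) / scale

x, bias = 1.0, 0.45          # bias left unquantized to keep the sketch short
w_clean, w_backdoor = 0.37, 0.38   # a 0.01 edit: harmless in full precision

def decide(w: float) -> int:
    return int(w * x - bias > 0)

# Full-precision model: both weights give the same class ...
assert decide(w_clean) == decide(w_backdoor) == 0

# ... but quantization snaps 0.37 -> 0.25 and 0.38 -> 0.50 (a whole step),
# so only the backdoored weight crosses the decision boundary:
assert decide(quantize(w_clean)) == 0      # 0.25 - 0.45 < 0
assert decide(quantize(w_backdoor)) == 1   # 0.50 - 0.45 > 0
```

This is also why the mitigation above works: `ezkl gen-witness` evaluates the already-quantized model, so the flipped behavior shows up during evaluation instead of staying hidden in the full-precision weights.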

View File

@@ -1,4 +1,4 @@
ezkl
ezkl==0.0.0
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -146,8 +146,6 @@ where
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::configure(
@@ -158,11 +156,15 @@ where
);
layer_config
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
layer_config
@@ -193,11 +195,6 @@ where
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
@@ -227,10 +224,7 @@ where
.layout(
&mut region,
&[x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();

View File

@@ -53,10 +53,6 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
let output = VarTensor::new_advice(cs, K, 1, LEN);
// tells the config layer to add an affine op to the circuit gate
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
println!("INPUT COL {:#?}", input);
let mut layer_config = PolyConfig::<F>::configure(
cs,
&[input.clone(), params.clone()],
@@ -64,12 +60,17 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
CheckMode::SAFE,
);
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();
layer_config
.configure_range_check(cs, &input, &params, (0, 1023), K)
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
@@ -103,11 +104,6 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();
config
.layer_config
.layout_range_checks(&mut layouter)
.unwrap();
let x = layouter
.assign_region(
|| "mlp_4d",
@@ -148,10 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap()
.unwrap();
@@ -191,10 +184,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
println!("6");

View File

@@ -592,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.2"
},
"orig_nbformat": 4
},

View File

@@ -648,10 +648,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.9.15"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}

File diff suppressed because one or more lines are too long

View File

@@ -77,7 +77,6 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.ignore_range_check_inputs_outputs = True\n",
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
@@ -336,9 +335,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -1,284 +1,279 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -1,462 +1,456 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
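{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an illustrative addition that assumes the `psql` client is on your PATH), we can count the rows shovel wrote to the `usdc` table:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative check: confirm shovel indexed some transfers\n",
"os.system(\"psql -h localhost -p 5432 -d shovel -c 'SELECT COUNT(*) FROM usdc;'\")"
]
},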
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat": 4,
"nbformat_minor": 0
}


@@ -453,18 +453,18 @@
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"# proofs = []\n",
"# for i in range(3):\n",
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"# proofs.append(proof_path)\n",
"proofs = []\n",
"for i in range(3):\n",
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
@@ -478,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.5"
},
"orig_nbformat": 4
},

View File

@@ -1,766 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
"\n",
"1 - NBoW\n",
"\n",
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Preparing Data\n",
"\n",
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
"\n",
"The steps to take are:\n",
"\n",
" 1. importing modules\n",
" 2. loading data\n",
" 3. tokenizing data\n",
" 4. creating data splits\n",
" 5. creating a vocabulary\n",
" 6. numericalizing data\n",
" 7. creating the data loaders\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install torchtex"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm\n",
"\n",
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
"\n",
"seed = 1234\n",
"\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
]
},
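{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (an illustrative addition), `ClassLabel.int2str` maps the integer labels back to their human-readable names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative: 0 -> 'neg', 1 -> 'pos'\n",
"label_feature = train_data.features[\"label\"]\n",
"print(label_feature.int2str(0), label_feature.int2str(1))"
]
},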
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.features\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're design to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
"\n",
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
]
},
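{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration (not in the original tutorial), the `basic_english` tokenizer lowercases its input and splits out punctuation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# should print something like ['hello', 'world', '!']\n",
"print(tokenizer(\"Hello world!\"))"
]
},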
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_example(example, tokenizer, max_length):\n",
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
" return {\"tokens\": tokens}\n",
"\n",
"\n",
"max_length = 256\n",
"\n",
"train_data = train_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"test_data = test_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"\n",
"\n",
"# create validation data \n",
"# Why have both a validation set and a test set? Your test set respresents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leak information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
"\n",
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data[\"train\"]\n",
"valid_data = train_valid_data[\"test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we have to build a vocabulary. This is look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
"\n",
"We do this as machine learning models cannot operate on strings, only numerical vaslues. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
]
},
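{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a toy illustration (an addition, not part of the original tutorial), here is the one-hot vector for the token at index 2 in a vocabulary of size V = 5:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"# one-hot vector for token index 2 with V = 5: tensor([[0, 0, 1, 0, 0]])\n",
"print(F.one_hot(torch.tensor([2]), num_classes=5))"
]
},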
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
" train_data[\"tokens\"],\n",
" min_freq=min_freq,\n",
" specials=special_tokens,\n",
")\n",
"\n",
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
"\n",
"unk_index = vocab[\"<unk>\"]\n",
"pad_index = vocab[\"<pad>\"]\n",
"\n",
"\n",
"vocab.set_default_index(unk_index)\n",
"\n",
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which containes the numericalized tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def numericalize_example(example, vocab):\n",
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
" return {\"ids\": ids}\n",
"\n",
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"\n",
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
]
},
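{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustrative check (not in the original tutorial), the `ids` field of an example should now be a tensor of vocabulary indices:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the first ten token ids of the first training example\n",
"print(train_data[0][\"ids\"][:10])"
]
},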
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
"\n",
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
"\n",
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
"\n",
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
"\n",
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
"\n",
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_collate_fn(pad_index):\n",
" def collate_fn(batch):\n",
" batch_ids = [i[\"ids\"] for i in batch]\n",
" batch_ids = nn.utils.rnn.pad_sequence(\n",
" batch_ids, padding_value=pad_index, batch_first=True\n",
" )\n",
" batch_label = [i[\"label\"] for i in batch]\n",
" batch_label = torch.stack(batch_label)\n",
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
" return batch\n",
"\n",
" return collate_fn\n",
"\n",
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
" collate_fn = get_collate_fn(pad_index)\n",
" data_loader = torch.utils.data.DataLoader(\n",
" dataset=dataset,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_fn,\n",
" shuffle=shuffle,\n",
" )\n",
" return data_loader\n",
"\n",
"\n",
"batch_size = 512\n",
"\n",
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
]
},
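{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the collation behaves as described (an illustrative check), we can pull a single batch and inspect the tensor shapes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(train_data_loader))\n",
"# ids should be [batch size, length]; label should be [batch size]\n",
"print(batch[\"ids\"].shape, batch[\"label\"].shape)"
]
},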
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class NBoW(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
" def forward(self, ids):\n",
" # ids = [batch size, seq len]\n",
" embedded = self.embedding(ids)\n",
" # embedded = [batch size, seq len, embedding dim]\n",
" pooled = embedded.mean(dim=1)\n",
" # pooled = [batch size, embedding dim]\n",
" prediction = self.fc(pooled)\n",
" # prediction = [batch size, output dim]\n",
" return prediction\n",
"\n",
"\n",
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique(\"label\"))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
"\n",
"vectors = torchtext.vocab.GloVe()\n",
"\n",
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(data_loader, model, criterion, optimizer, device):\n",
" model.train()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(data_loader, model, criterion, device):\n",
" model.eval()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" with torch.no_grad():\n",
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
" batch_size, _ = prediction.shape\n",
" predicted_classes = prediction.argmax(dim=-1)\n",
" correct_predictions = predicted_classes.eq(label).sum()\n",
" accuracy = correct_predictions / batch_size\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float(\"inf\")\n",
"\n",
"metrics = collections.defaultdict(list)\n",
"\n",
"for epoch in range(n_epochs):\n",
" train_loss, train_acc = train(\n",
" train_data_loader, model, criterion, optimizer, device\n",
" )\n",
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
" metrics[\"train_losses\"].append(train_loss)\n",
" metrics[\"train_accs\"].append(train_acc)\n",
" metrics[\"valid_losses\"].append(valid_loss)\n",
" metrics[\"valid_accs\"].append(valid_acc)\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), \"nbow.pt\")\n",
" print(f\"epoch: {epoch}\")\n",
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
"\n",
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
"\n",
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" prediction = model(tensor).squeeze(dim=0)\n",
" probability = torch.softmax(prediction, dim=-1)\n",
" predicted_class = prediction.argmax(dim=-1).item()\n",
" predicted_probability = probability[predicted_class].item()\n",
" return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def text_to_tensor(text, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" return tensor\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we do onnx stuff to get the data ready for the zk-circuit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json\n",
"\n",
"text = \"This film is terrible!\"\n",
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"model_path = \"network.onnx\"\n",
"data_path = \"input.json\"\n",
"\n",
" # Export the model\n",
"torch.onnx.export(model, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data_json = dict(input_data = [data_array])\n",
"\n",
"print(data_json)\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data_json, open(data_path, 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.logrows = 23\n",
"run_args.scale_rebase_multiplier = 10\n",
"# inputs should be auditable by all\n",
"run_args.input_visibility = \"public\"\n",
"# same with outputs\n",
"run_args.output_visibility = \"public\"\n",
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
"run_args.param_visibility = \"fixed\"\n",
"\n",
"\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(py_run_args=run_args)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file\n",
"res = await ezkl.gen_witness()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.mock()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup()\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"res = ezkl.prove(proof_path=\"proof.json\")\n",
"\n",
"print(res)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also verify it on chain by creating an onchain verifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
" !solc-select install 0.8.20\n",
" !solc-select use 0.8.20\n",
" !solc --version\n",
" import os\n",
"\n",
"# rely on local installation if the notebook is not in colab\n",
"except:\n",
" import os\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.create_evm_verifier()\n",
"assert res == True\n"
]
},
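{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the deployment can be scripted against a local node. The sketch below is an illustrative addition: it assumes an anvil node running on `http://localhost:3030` and the `deploy_evm`/`verify_evm` helpers available in recent ezkl releases, whose exact signatures may differ between versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical sketch -- assumes a local anvil node and recent ezkl EVM helpers\n",
"addr_path = \"address.json\"\n",
"res = await ezkl.deploy_evm(addr_path, \"test.sol\", \"http://localhost:3030\")\n",
"assert res == True\n",
"\n",
"with open(addr_path, \"r\") as f:\n",
"    addr = f.read()\n",
"\n",
"res = await ezkl.verify_evm(addr, \"proof.json\", \"http://localhost:3030\")\n",
"assert res == True"
]
},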
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
"\n",
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
"\n",
"Create a new file within remix and copy the verifier code over.\n",
"\n",
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
"\n",
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -152,11 +152,9 @@
"metadata": {},
"outputs": [],
"source": [
"run_args = ezkl.PyRunArgs()\n",
"# logrows\n",
"run_args.logrows = 20\n",
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
@@ -304,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -167,8 +167,6 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"hashed/private/0\"\n",
"# as the inputs are felts we turn off input range checks\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# we set it to fix the set we want to check membership for\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public -- set membership fails if it is not = 0\n",
@@ -521,4 +519,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}


@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"id": "171702d3",
"metadata": {},
"outputs": [],
@@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
@@ -364,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": null,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
@@ -399,9 +399,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}


@@ -204,7 +204,6 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -515,4 +514,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}


@@ -1,764 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# univ3-da-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source. For this setup we make a single call to a view function that returns an array of UniV3 historical TWAP price data that we will attest to on-chain. \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--fork-url\", \"https://arb1.arbitrum.io/rpc\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.decomp_legs=6\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"import requests\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
" return len(num_str) - 1 - num_str.index('.')\n",
" else:\n",
" return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def set_next_block_timestamp(anvil_url, timestamp):\n",
" # Send the JSON-RPC request to Anvil\n",
" payload = {\n",
" \"jsonrpc\": \"2.0\",\n",
" \"id\": 1,\n",
" \"method\": \"evm_setNextBlockTimestamp\",\n",
" \"params\": [timestamp]\n",
" }\n",
" response = requests.post(anvil_url, json=payload)\n",
" if response.status_code == 200:\n",
" print(f\"Next block timestamp set to: {timestamp}\")\n",
" else:\n",
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
"\n",
"def on_chain_data(tensor):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = tensor.view(-1).tolist()\n",
"\n",
" # Step 1: Prepare the calldata\n",
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: MIT\n",
" pragma solidity ^0.8.20;\n",
"\n",
" /// @title Pool state that is not stored\n",
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
" /// blockchain. The functions here may have variable gas costs.\n",
" interface IUniswapV3PoolDerivedState {\n",
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
" /// you must call it with secondsAgos = [3600, 0].\n",
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
" /// timestamp\n",
" function observe(\n",
" uint32[] calldata secondsAgos\n",
" )\n",
" external\n",
" view\n",
" returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" }\n",
"\n",
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
" /// @notice Provides functions to integrate with V3 pool oracle\n",
" contract UniTickAttestor {\n",
" /**\n",
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
" * @param pool Address of the pool that we want to observe\n",
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
" */\n",
" function consult(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public view returns (int256[] memory tickCumulatives) {\n",
" tickCumulatives = new int256[](secondsAgo.length);\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" for (uint256 i = 0; i < secondsAgo.length; i++) {\n",
" tickCumulatives[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"UniTickAttestor.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = UniTickAttestor.constructor().transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" call = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).build_transaction()\n",
" result = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).call()\n",
" \n",
" print(f'result: {result}')\n",
" calldata = call['data'][2:]\n",
"\n",
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
" print(f'time_stamp: {time_stamp}')\n",
"\n",
" # Set the next block timestamp using the fetched time_stamp\n",
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
"\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': 0,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" 'len': len(data),\n",
" }\n",
"\n",
" print(f'call_to_account: {call_to_account}')\n",
"\n",
" return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we need to regenerate the witness, prove and then verify all within the same cell. This is because we want to reduce the amount of latency between reading on-chain state and verifying it on-chain. This is because the attest input values read from the oracle are time sensitive (their values are derived from computing on block.timestamp) and can change between the time of reading and the time of verifying.\n",
"\n",
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"print(f'time_stamp: {time_stamp}')\n",
"\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,42 +0,0 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.exp(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[0.5801457762718201, 0.6019012331962585, 0.8695418238639832, 0.17170941829681396, 0.500616729259491, 0.353726327419281, 0.6726185083389282, 0.5936906337738037]]}

View File

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -1,41 +0,0 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = 10**x
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[0.9837989807128906, 0.026381194591522217, 0.3403851389884949, 0.14531707763671875, 0.24652725458145142, 0.7945117354393005, 0.4076554775238037, 0.23064672946929932]]}

View File

@@ -1,42 +0,0 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
return x // 3
circuit = MyModel()
x = torch.randint(0, 10, (1, 2, 2, 8))
out = circuit(x)
print(x)
print(out)
print(x/3)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}

View File

@@ -1,42 +0,0 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.log(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 3)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}

View File

Binary file not shown.

View File

@@ -21,9 +21,9 @@ def main():
torch_model = Circuit()
# Input to the model
shape = [3, 2, 3]
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
w = 0.1*torch.rand(1, *shape, requires_grad=True)
x = 0.1*torch.rand(1, *shape, requires_grad=True)
y = 0.1*torch.rand(1, *shape, requires_grad=True)
torch_out = torch_model(w, x, y)
# Export the model
torch.onnx.export(torch_model, # model being run

View File

@@ -1,148 +1 @@
{
"input_shapes": [
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
],
[
3,
2,
3
]
],
"input_data": [
[
0.5,
1.5,
-0.04514765739440918,
0.5936200618743896,
0.9271858930587769,
0.6688600778579712,
-0.20331168174743652,
-0.7016235589981079,
0.025863051414489746,
-0.19426143169403076,
0.9827852249145508,
0.4897397756576538,
-1.5,
-0.5,
0.9278832674026489,
0.5943725109100342,
-0.573331356048584,
0.3675816059112549
],
[
0.7803324460983276,
-0.9616303443908691,
0.6070173978805542,
-0.028337717056274414,
-0.5080242156982422,
-0.9280107021331787,
0.6150380373001099,
0.3865993022918701,
-0.43668973445892334,
0.17152702808380127,
0.5144252777099609,
-0.28881049156188965,
0.8932310342788696,
0.059034109115600586,
0.6865451335906982,
0.009820222854614258,
0.23011493682861328,
-0.9492779970169067
],
[
-0.21352827548980713,
-0.16015326976776123,
-0.38964390754699707,
0.13464701175689697,
-0.8814496994018555,
0.5037975311279297,
-0.804405927658081,
0.9858957529067993,
0.19567716121673584,
0.9777265787124634,
0.6151977777481079,
0.568595290184021,
0.10584986209869385,
-0.8975653648376465,
0.6235959529876709,
-0.547879695892334,
0.9289869070053101,
0.7567293643951416
]
],
"output_data": [
[
1.0,
0.0,
-0.0,
1.0,
1.0,
1.0,
-0.0,
-1.0,
0.0,
-0.0,
1.0,
0.0,
0.0,
1.0,
1.0,
1.0,
-1.0,
0.0
],
[
0.0,
-1.0,
0.0,
-1.0,
-1.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
-1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
-1.0
],
[
-0.0,
-0.0,
-0.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
-0.0,
1.0,
-0.0,
1.0,
1.0
]
]
}
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}

View File

Binary file not shown.

View File

@@ -1,42 +0,0 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
# reciprocal sqrt
m = 1 / torch.sqrt(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}

View File

Binary file not shown.
View File

@@ -1,52 +1,75 @@
from torch import nn
import torch
import torch.nn as nn
import sys
import json
import numpy as np
import tf2onnx
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
# scatter_nd in tf then export to onnx
x = in1 = Input((4, 1), dtype=tf.int32)
w = in2 = Input((4, ), dtype=tf.int32)
class MyLayer(Layer):
def call(self, x, w):
shape = tf.constant([8])
return tf.scatter_nd(x, w, shape)
x = MyLayer()(x, w)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 4, 1]
index_shape = [1, 4]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = np.random.randint(0, 4, shape)
# w = random int tensor
w = np.random.randint(0, 4, index_shape)
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
input_data=[d1],
)
# Serialize data into file:

View File

@@ -1,16 +1 @@
{
"input_data": [
[
0,
1,
2,
3
],
[
1,
0,
2,
1
]
]
}
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

851
ezkl.pyi
View File

@@ -1,851 +0,0 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401
import os
import pathlib
import typing
from enum import Enum, auto
class PyG1:
r"""
pyclass containing the struct used for G1, this is mostly a helper class
"""
...
class PyG1Affine:
r"""
pyclass containing the struct used for G1
"""
...
class PyRunArgs:
r"""
Python class containing the struct used for run_args
Returns
-------
PyRunArgs
"""
...
class PyCommitments(Enum):
r"""
pyclass representing an enum, denoting the type of commitment
"""
KZG = auto()
IPA = auto()
class PyInputType(Enum):
Bool = auto()
F16 = auto()
F32 = auto()
F64 = auto()
Int = auto()
TDim = auto()
class PyTestDataSource(Enum):
r"""
pyclass representing an enum
"""
File = auto()
OnChain = auto()
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
r"""
Creates an aggregated proof
Arguments
---------
aggregation_snarks: list[str]
List of paths to the various proofs
proof_path: str
Path to output the aggregated proof
vk_path: str
Path to the VK file
transcript:
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
logrows:
Logrows used for aggregation circuit
check_mode: str
Run sanity checks during calculations. Accepts `safe` or `unsafe`
split_proofs: bool
Whether the accumulated proofs are segments of a larger circuit
srs_path: str
Path to the SRS used
commitment: str
Accepts "kzg" or "ipa"
Returns
-------
bool
"""
...
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
r"""
Converts a buffer to vector of field elements
Arguments
-------
buffer: list[int]
List of integers representing a buffer
Returns
-------
list[str]
List of field elements represented as strings
"""
...
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
r"""
Calibrates the circuit settings
Arguments
---------
data: str
Path to the calibration data
model: str
Path to the onnx file
settings: str
Path to the settings file
target: str
Calibration target; e.g. `resources` or `accuracy`
lookup_safety_margin: float
The lookup safety margin to use for calibration. If the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin (larger = safer but slower)
scales: list[int]
Optional scales to specifically try for calibration
scale_rebase_multiplier: list[int]
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
max_logrows: int
Optional max logrows to use for calibration
Returns
-------
bool
"""
...
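# Illustrative usage (a sketch; paths are assumptions). The binding is
# awaitable, so call it from an async context such as a notebook:
#
#     await ezkl.calibrate_settings("cal_data.json", "network.onnx",
#                                   "settings.json", "resources")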
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Compiles the circuit for use in other steps
Arguments
---------
model: str
Path to the onnx model file
compiled_circuit: str
Path to output the compiled circuit
settings_path: str
Path to the settings files
Returns
-------
bool
"""
...
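# Illustrative usage (paths are assumptions): compile the model against the
# calibrated settings before setup and witness generation:
#
#     assert ezkl.compile_circuit("network.onnx", "network.compiled",
#                                 "settings.json")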
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM compatible data attestation verifier; you will need solc installed in your environment to run this
Arguments
---------
input_data: str
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
settings_path: str
The path to the settings file
sol_code_path: str
The path at which to create the Solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
witness_path: str
Optional path to the witness file
Returns
-------
bool
"""
...
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM compatible verifier; you will need solc installed in your environment to run this
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path at which to create the Solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
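# Illustrative sketch (paths are assumptions): render a reusable verifier.
# The trailing arguments are the optional SRS path and the reusable flag;
# with reusable set, the circuit-specific VK artifact is generated
# separately via create_evm_vka and deployed as its own contract:
#
#     await ezkl.create_evm_verifier("test.vk", "settings.json",
#                                    "verifier.sol", "verifier.abi",
#                                    None, True)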
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM compatible aggregate verifier; you will need solc installed in your environment to run this
Arguments
---------
aggregation_settings: str
path to the settings file
vk_path: str
The path to load the desired verification key file
sol_code_path: str
The path to the Solidity code
abi_path: str
The path to output the Solidity verifier ABI
logrows: int
Number of logrows used during aggregated setup
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.
This is useful for deploying verifiers that were otherwise too big to fit on chain and would have required aggregation.
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path at which to create the Solidity verifying key artifact
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity da verifier
"""
...
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity verifier
"""
...
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
r"""
Creates encoded evm calldata from a proof file
Arguments
---------
proof: str
Path to the proof file
calldata: str
Path to the calldata file to save
addr_vk: str
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
Returns
-------
list[int]
The encoded calldata
"""
...
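# Illustrative sketch (paths and address are assumptions): encode proof
# calldata for a reusable verifier whose VK artifact lives at a separate
# address:
#
#     calldata = ezkl.encode_evm_calldata("test.pf", "calldata.bytes",
#                                         "0x<vk contract address>")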
def felt_to_big_endian(felt:str) -> str:
r"""
Converts a field element hex string to big endian
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
str
field element represented as a string
"""
...
def felt_to_float(felt:str,scale:int) -> float:
r"""
Converts a field element hex string to a floating point number
Arguments
-------
felt: str
The field element represented as a string
scale: int
The scaling factor used to convert the field element into a floating point representation
Returns
-------
float
"""
...
def felt_to_int(felt:str) -> int:
r"""
Converts a field element hex string to an integer
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
int
"""
...
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
r"""
Converts a floating point element to a field element hex string
Arguments
-------
input: float
The floating point number to quantize
scale: int
The scaling factor used to quantize the float into a field element
input_type: PyInputType
The type of the input
Returns
-------
str
The field element represented as a string
"""
...
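# Illustrative round trip (the scale of 7 is an assumption): quantize a
# float to a field element and recover an approximation of it:
#
#     felt = ezkl.float_to_felt(0.5, 7, ezkl.PyInputType.F32)
#     approx = ezkl.felt_to_float(felt, 7)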
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
r"""
Generates the circuit settings
Arguments
---------
model: str
Path to the onnx file
output: str
Path to create the settings file
py_run_args: PyRunArgs
PyRunArgs object to initialize the settings
Returns
-------
bool
"""
...
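# Illustrative usage (visibility choices are assumptions), mirroring the
# notebooks above:
#
#     args = ezkl.PyRunArgs()
#     args.input_visibility = "public"
#     args.param_visibility = "private"
#     args.output_visibility = "public"
#     assert ezkl.gen_settings("network.onnx", "settings.json",
#                              py_run_args=args)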
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
r"""
Generates the Structured Reference String (SRS); use this only for testing purposes
Arguments
---------
srs_path: str
Path at which to create the SRS file
logrows: int
The number of logrows for the SRS file
"""
...
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for an aggregate circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for a model circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
circuit_settings_path: str
Path to the circuit settings file
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the forward pass operation to generate a witness
Arguments
---------
data: str
Path to the data file
model: str
Path to the compiled model file
output: str
Path to create the witness file
vk_path: str
Path to the verification key
srs_path: str
Path to the SRS file
Returns
-------
dict
Python object containing the witness values
"""
...
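# Illustrative usage (paths are assumptions); awaitable, so run from an
# async context:
#
#     witness = await ezkl.gen_witness("input.json", "network.compiled",
#                                      "witness.json")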
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
r"""
Gets a public srs
Arguments
---------
settings_path: str
Path to the settings file
logrows: int
The number of logrows for the SRS file
srs_path: str
Path at which to create the SRS file
commitment: str
Specify the commitment used ("kzg", "ipa")
Returns
-------
bool
"""
...
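# Illustrative usage (path is an assumption): fetch an SRS sized from the
# calibrated settings file:
#
#     await ezkl.get_srs("settings.json")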
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate an ipa commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structure Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate a kzg commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structure Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
r"""
Mocks the prover
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
Returns
-------
bool
"""
...
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
r"""
Mocks the aggregate prover
Arguments
---------
aggregation_snarks: list[str]
List of paths to the relevant proof files
logrows: int
Number of logrows to use for the aggregation circuit
split_proofs: bool
Indicates whether the accumulated proofs are segments of a larger proof
Returns
-------
bool
"""
...
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
r"""
Generate a poseidon hash.
Arguments
-------
message: list[str]
List of field elements represented as strings
Returns
-------
list[str]
List of field elements represented as strings
"""
...
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the prover on a set of inputs
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
pk_path: str
Path to the proving key file
proof_path: str
Path to create the proof file
proof_type: str
Accepts `single`, `for-aggr`
srs_path: str
Path to the SRS file
Returns
-------
bool
"""
...
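# Illustrative usage (paths are assumptions), using the `single` proof type
# seen in the notebooks:
#
#     res = ezkl.prove("witness.json", "network.compiled", "test.pk",
#                      "test.pf", "single")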
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
r"""
Runs the setup process
Arguments
---------
model: str
Path to the compiled model file
vk_path: str
Path to create the verification key file
pk_path: str
Path to create the proving key file
srs_path: str
Path to the SRS file
witness_path: str
Path to the witness file
disable_selector_compression: bool
Whether to disable selector compression
Returns
-------
bool
"""
...
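# Illustrative usage (paths are assumptions): derive the proving and
# verification keys for a compiled circuit:
#
#     assert ezkl.setup("network.compiled", "test.vk", "test.pk")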
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
r"""
Runs the setup process for an aggregate setup
Arguments
---------
sample_snarks: list[str]
List of paths to the various proofs
vk_path: str
Path to create the aggregated VK
pk_path: str
Path to create the aggregated PK
logrows: int
Number of logrows to use
split_proofs: bool
Whether the accumulated proofs are segments of a larger proof
srs_path: str
Path to the SRS file
disable_selector_compression: bool
Whether to disable selector compression
commitment: str
Accepts `kzg`, `ipa`
Returns
-------
bool
"""
...
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Setup test evm witness
Arguments
---------
data_path: str
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
compiled_circuit_path: str
The path to the compiled model file (generated using the compile-circuit command)
test_data: str
For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
input_source: str
Where the input data comes from
output_source: str
Where the output data comes from
rpc_url: str
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
Returns
-------
bool
"""
...
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
r"""
Swap the commitments in a proof
Arguments
-------
proof_path: str
Path to the proof file
witness_path: str
Path to the witness file
"""
...
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
r"""
Displays the table as a string in python
Arguments
---------
model: str
Path to the onnx file
Returns
---------
str
Table of the nodes in the onnx file
"""
...
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
r"""
Verifies a given proof
Arguments
---------
proof_path: str
Path to the proof file
settings_path: str
Path to the settings file
vk_path: str
Path to the verification key file
srs_path: str
Path to the SRS file
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the SRS were generated in the same ceremony)
Returns
-------
bool
"""
...
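# Illustrative usage (paths are assumptions): check a proof locally before
# touching the chain:
#
#     assert ezkl.verify("test.pf", "settings.json", "test.vk")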
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
r"""
Verifies an aggregate proof
Arguments
---------
proof_path: str
The path to the proof file
vk_path: str
The path to the verification key file
logrows: int
logrows used for aggregation circuit
commitment: str
Accepts "kzg" or "ipa"
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
r"""
Verifies an EVM compatible proof; you will need solc installed in your environment to run this
Arguments
---------
addr_verifier: str
The verifier contract's address as a hex string
proof_path: str
The path to the proof file (generated using the prove command)
rpc_url: str
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
addr_da: str
The address of the data attestation contract, if the verifier reads attested on-chain data
addr_vk: str
The address of the separate VK contract (if the verifier key is rendered as a separate contract)
Returns
-------
bool
"""
...
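# Illustrative sketch (addresses are assumptions): verify through a deployed
# data attestation contract; pass addr_vk as well when the verifier was
# rendered as a reusable contract with a separate VK artifact:
#
#     await ezkl.verify_evm("0x<verifier address>", "test.pf",
#                           "http://localhost:3030",
#                           "0x<da contract address>")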

View File

@@ -12,7 +12,6 @@ asyncio_mode = "auto"
[project]
name = "ezkl"
version = "0.0.0"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",

View File

@@ -1,11 +1,7 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};

View File

@@ -1,9 +0,0 @@
use pyo3_stub_gen::Result;
fn main() -> Result<()> {
// `stub_info` is a function defined by `define_stub_info_gatherer!` macro.
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = ezkl::bindings::python::stub_info()?;
stub.generate()?;
Ok(())
}

View File

@@ -4,10 +4,10 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
@@ -26,12 +26,7 @@ use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
@@ -40,7 +35,6 @@ type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyTestDataSource {
/// The data is loaded from a file
File,
@@ -60,7 +54,6 @@ impl From<PyTestDataSource> for TestDataSource {
/// pyclass containing the struct used for G1, this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
@@ -107,7 +100,6 @@ impl pyo3::ToPyObject for PyG1 {
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
pub struct PyG1Affine {
#[pyo3(get, set)]
///
@@ -153,7 +145,6 @@ impl pyo3::ToPyObject for PyG1Affine {
///
#[pyclass]
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
@@ -189,6 +180,9 @@ struct PyRunArgs {
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
@@ -203,12 +197,6 @@ struct PyRunArgs {
/// int: The number of legs used for decomposition
#[pyo3(get, set)]
pub decomp_legs: usize,
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
}
/// default instantiation of PyRunArgs
@@ -224,7 +212,6 @@ impl PyRunArgs {
impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
@@ -236,12 +223,12 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
}
}
}
@@ -249,7 +236,6 @@ impl From<PyRunArgs> for RunArgs {
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
@@ -261,19 +247,18 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
}
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
@@ -321,65 +306,6 @@ impl FromStr for PyCommitments {
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyInputType {
///
Bool,
///
F16,
///
F32,
///
F64,
///
Int,
///
TDim,
}
impl From<InputType> for PyInputType {
fn from(input_type: InputType) -> Self {
match input_type {
InputType::Bool => PyInputType::Bool,
InputType::F16 => PyInputType::F16,
InputType::F32 => PyInputType::F32,
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
}
}
}
impl From<PyInputType> for InputType {
fn from(py_input_type: PyInputType) -> Self {
match py_input_type {
PyInputType::Bool => InputType::Bool,
PyInputType::F16 => InputType::F16,
PyInputType::F32 => InputType::F32,
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
}
}
}
impl FromStr for PyInputType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"bool" => Ok(PyInputType::Bool),
"f16" => Ok(PyInputType::F16),
"f32" => Ok(PyInputType::F32),
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
_ => Err("Invalid value for InputType".to_string()),
}
}
}
/// Converts a field element hex string to big endian
///
/// Arguments
@@ -396,7 +322,6 @@ impl FromStr for PyInputType {
#[pyfunction(signature = (
felt,
))]
#[gen_stub_pyfunction]
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
@@ -416,7 +341,6 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
#[pyfunction(signature = (
felt,
))]
#[gen_stub_pyfunction]
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
@@ -441,7 +365,6 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
felt,
scale
))]
#[gen_stub_pyfunction]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
@@ -460,9 +383,6 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
/// scale: float
/// The scaling factor used to quantize the float into a field element
///
/// input_type: PyInputType
/// The type of the input
///
/// Returns
/// -------
/// str
@@ -470,12 +390,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
///
#[pyfunction(signature = (
input,
scale,
input_type=PyInputType::F64
scale
))]
#[gen_stub_pyfunction]
fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult<PyFelt> {
InputType::roundtrip(&input_type.into(), &mut input);
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = integer_rep_to_felt(int_rep);
@@ -497,7 +414,6 @@ fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -
#[pyfunction(signature = (
buffer
))]
#[gen_stub_pyfunction]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -570,14 +486,16 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
#[pyfunction(signature = (
message,
))]
#[gen_stub_pyfunction]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|_| PyIOError::new_err("Failed to run poseidon"))?;
let hash = output[0]
@@ -613,7 +531,6 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -672,7 +589,6 @@ fn kzg_commit(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -719,7 +635,6 @@ fn ipa_commit(
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
#[gen_stub_pyfunction]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -749,7 +664,6 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -787,7 +701,6 @@ fn gen_vk_from_pk_single(
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -817,7 +730,6 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
))]
#[gen_stub_pyfunction]
fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?;
@@ -843,7 +755,6 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
srs_path,
logrows,
))]
#[gen_stub_pyfunction]
fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
let params = ezkl_gen_srs::<KZGCommitmentScheme<Bn256>>(logrows as u32);
save_params::<KZGCommitmentScheme<Bn256>>(&srs_path, &params)?;
@@ -876,7 +787,6 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
srs_path=None,
commitment=None,
))]
#[gen_stub_pyfunction]
fn get_srs(
py: Python,
settings_path: Option<PathBuf>,
@@ -889,7 +799,7 @@ fn get_srs(
None => None,
};
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
.await
.map_err(|e| {
@@ -923,7 +833,6 @@ fn get_srs(
output=PathBuf::from(DEFAULT_SETTINGS),
py_run_args = None,
))]
#[gen_stub_pyfunction]
fn gen_settings(
model: PathBuf,
output: PathBuf,
@@ -939,45 +848,6 @@ fn gen_settings(
Ok(true)
}
/// Generates random data for the model
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the data file
///
/// seed: int
/// Random seed to use for generated data
///
/// variables
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn gen_random_data(
model: PathBuf,
output: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
) -> Result<bool, PyErr> {
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
let err_str = format!("Failed to generate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Calibrates the circuit settings
///
/// Arguments
@@ -1003,6 +873,8 @@ fn gen_random_data(
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
@@ -1017,8 +889,8 @@ fn gen_random_data(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: PathBuf,
@@ -1029,8 +901,9 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::calibrate(
model,
data,
@@ -1039,6 +912,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.await
@@ -1082,7 +956,6 @@ fn calibrate_settings(
vk_path=None,
srs_path=None,
))]
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: PathBuf,
@@ -1091,7 +964,7 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
.await
.map_err(|e| {
@@ -1120,7 +993,6 @@ fn gen_witness(
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
))]
#[gen_stub_pyfunction]
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
crate::execute::mock(model, witness).map_err(|e| {
let err_str = format!("Failed to run mock: {}", e);
@@ -1151,7 +1023,6 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
))]
#[gen_stub_pyfunction]
fn mock_aggregate(
aggregation_snarks: Vec<PathBuf>,
logrows: u32,
@@ -1199,7 +1070,6 @@ fn mock_aggregate(
witness_path = None,
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup(
model: PathBuf,
vk_path: PathBuf,
@@ -1258,7 +1128,6 @@ fn setup(
proof_type=ProofType::default(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn prove(
witness: PathBuf,
model: PathBuf,
@@ -1314,7 +1183,6 @@ fn prove(
srs_path=None,
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
#[gen_stub_pyfunction]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
@@ -1374,7 +1242,6 @@ fn verify(
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,
@@ -1425,7 +1292,6 @@ fn setup_aggregate(
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
))]
#[gen_stub_pyfunction]
fn compile_circuit(
model: PathBuf,
compiled_circuit: PathBuf,
@@ -1485,7 +1351,6 @@ fn compile_circuit(
srs_path=None,
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn aggregate(
aggregation_snarks: Vec<PathBuf>,
proof_path: PathBuf,
@@ -1551,7 +1416,6 @@ fn aggregate(
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn verify_aggr(
proof_path: PathBuf,
vk_path: PathBuf,
@@ -1599,7 +1463,6 @@ fn verify_aggr(
calldata=PathBuf::from(DEFAULT_CALLDATA),
addr_vk=None,
))]
#[gen_stub_pyfunction]
fn encode_evm_calldata<'a>(
proof: PathBuf,
calldata: PathBuf,
@@ -1652,7 +1515,6 @@ fn encode_evm_calldata<'a>(
srs_path=None,
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier(
py: Python,
vk_path: PathBuf,
@@ -1662,7 +1524,7 @@ fn create_evm_verifier(
srs_path: Option<PathBuf>,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_verifier(
vk_path,
srs_path,
@@ -1712,7 +1574,6 @@ fn create_evm_verifier(
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
#[gen_stub_pyfunction]
fn create_evm_vka(
py: Python,
vk_path: PathBuf,
@@ -1721,7 +1582,7 @@ fn create_evm_vka(
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
@@ -1760,7 +1621,6 @@ fn create_evm_vka(
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: PathBuf,
@@ -1769,7 +1629,7 @@ fn create_evm_data_attestation(
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_data_attestation(
settings_path,
sol_code_path,
@@ -1821,7 +1681,6 @@ fn create_evm_data_attestation(
output_source,
rpc_url=None,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: PathBuf,
@@ -1831,7 +1690,7 @@ fn setup_test_evm_witness(
output_source: PyTestDataSource,
rpc_url: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_witness(
data_path,
compiled_circuit_path,
@@ -1859,7 +1718,6 @@ fn setup_test_evm_witness(
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
#[gen_stub_pyfunction]
fn deploy_evm(
py: Python,
addr_path: PathBuf,
@@ -1869,7 +1727,7 @@ fn deploy_evm(
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::deploy_evm(
sol_code_path,
rpc_url,
@@ -1898,7 +1756,6 @@ fn deploy_evm(
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
#[gen_stub_pyfunction]
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
@@ -1909,7 +1766,7 @@ fn deploy_da_evm(
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::deploy_da_evm(
input_data,
settings_path,
@@ -1957,7 +1814,6 @@ fn deploy_da_evm(
addr_da = None,
addr_vk = None,
))]
#[gen_stub_pyfunction]
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
@@ -1980,7 +1836,7 @@ fn verify_evm<'a>(
None
};
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
.await
.map_err(|e| {
@@ -2030,7 +1886,6 @@ fn verify_evm<'a>(
srs_path=None,
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
py: Python,
aggregation_settings: Vec<PathBuf>,
@@ -2041,7 +1896,7 @@ fn create_evm_verifier_aggr(
srs_path: Option<PathBuf>,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_aggregate_verifier(
vk_path,
srs_path,
@@ -2061,19 +1916,15 @@ fn create_evm_verifier_aggr(
})
}
// Define a function to gather stub information.
define_stub_info_gatherer!(stub_info);
// Python Module
#[pymodule]
fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
pyo3_log::init();
m.add_class::<PyRunArgs>()?;
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_class::<PyCommitments>()?;
m.add_class::<PyInputType>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
@@ -2095,7 +1946,6 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
@@ -2113,48 +1963,3 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
Ok(())
}
impl pyo3_stub_gen::PyStubType for CalibrationTarget {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ProofType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for TranscriptType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for CheckMode {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ContractType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
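One user-visible effect of this file's diff is the float_to_felt signature: the input_type parameter exists on only one side of the hunk. A hedged round-trip sketch against the two-argument form, assuming ezkl's usual 2**scale fixed-point quantization:

import ezkl

scale = 7                                # multiplier is 2**7 = 128
felt = ezkl.float_to_felt(0.5, scale)    # field-element hex string
assert abs(ezkl.felt_to_float(felt, scale) - 0.5) < 2**-scale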

View File

@@ -141,11 +141,10 @@ pub(crate) fn gen_vk(
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(
&mut serialized_vk,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| {
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
})?;
Ok(serialized_vk)
}
@@ -166,7 +165,7 @@ pub(crate) fn gen_pk(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -198,7 +197,7 @@ pub(crate) fn verify(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -278,7 +277,7 @@ pub(crate) fn verify_aggr(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -366,7 +365,7 @@ pub(crate) fn prove(
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
@@ -488,7 +487,7 @@ pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -505,7 +504,7 @@ pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;

View File

@@ -8,7 +8,10 @@ use crate::{
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
};
use console_error_panic_hook;
use halo2_proofs::{
@@ -19,7 +22,6 @@ use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
@@ -111,15 +113,9 @@ pub fn feltToFloat(
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
mut input: f64,
input: f64,
scale: crate::Scale,
input_type: &str,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
crate::circuit::InputType::roundtrip(
&crate::circuit::InputType::from_str(input_type)
.map_err(|e| JsError::new(&format!("{}", e)))?,
&mut input,
);
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);
@@ -228,7 +224,10 @@ pub fn poseidonHash(
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(

View File

@@ -8,11 +8,13 @@ pub mod poseidon_params;
pub mod spec;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_gadgets::poseidon::{
primitives::VariableLength, primitives::*, Hash, Pow5Chip, Pow5Config,
};
use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config};
use halo2_proofs::arithmetic::Field;
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::{circuit::*, plonk::*};
// use maybe_rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
use maybe_rayon::prelude::ParallelIterator;
use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
@@ -38,17 +40,22 @@ pub struct PoseidonConfig<const WIDTH: usize, const RATE: usize> {
pub pow5_config: Pow5Config<Fp, WIDTH, RATE>,
}
type InputAssignments = Vec<AssignedCell<Fp, Fp>>;
type InputAssignments = (Vec<AssignedCell<Fp, Fp>>, AssignedCell<Fp, Fp>);
/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
#[derive(Debug, Clone)]
pub struct PoseidonChip<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> {
pub struct PoseidonChip<
S: Spec<Fp, WIDTH, RATE> + Sync,
const WIDTH: usize,
const RATE: usize,
const L: usize,
> {
config: PoseidonConfig<WIDTH, RATE>,
_marker: PhantomData<S>,
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
{
/// Creates a new PoseidonChip
pub fn configure_with_cols(
@@ -75,8 +82,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
}
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
{
/// Configuration of the PoseidonChip
pub fn configure_with_optional_instance(
@@ -93,6 +100,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
Self::configure_with_cols(
@@ -106,8 +116,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
}
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Module<Fp>
for PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
Module<Fp> for PoseidonChip<S, WIDTH, RATE, L>
{
type Config = PoseidonConfig<WIDTH, RATE>;
type InputAssignments = InputAssignments;
@@ -142,6 +152,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
let instance = meta.instance_column();
@@ -163,10 +176,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, ModuleError> {
if message.len() != 1 {
return Err(ModuleError::InputWrongLength(message.len()));
}
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
@@ -176,81 +186,95 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, _> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
let offset = message.len() / WIDTH + 1;
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"AssignedValue".to_string(),
)),
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
let zero_val = region
.assign_advice_from_constant(
|| "",
self.config.hash_inputs[0],
offset,
Fp::ZERO,
)
.unwrap();
Ok(assigned_message?)
Ok((assigned_message?, zero_val))
},
);
log::trace!(
@@ -271,13 +295,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, ModuleError> {
let input_cells = self.layout_inputs(layouter, input, constants)?;
// empty hash case
if input_cells.is_empty() {
return Ok(input[0].clone());
}
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -285,25 +303,52 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
let start_time = instant::Instant::now();
let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, VariableLength, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
let _start_time = instant::Instant::now();
let hash: AssignedCell<Fp, Fp> = hasher.hash(
layouter.namespace(|| "hash"),
input_cells
.to_vec()
.try_into()
.map_err(|_| Error::Synthesis)?,
)?;
let mut block = block.to_vec();
let remainder = block.len() % L;
if remainder != 0 {
block.extend(vec![zero_val.clone(); L - remainder]);
}
let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, ConstantLength<L>, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;
let hash = hasher.hash(
layouter.namespace(|| "hash"),
block.to_vec().try_into().map_err(|_| Error::Synthesis)?,
);
if i == 0 {
log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed());
}
hash
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
input_cells = hashes?;
}
let duration = start_time.elapsed();
log::trace!("layout (N={:?}) took: {:?}", len, duration);
let result = Tensor::from(vec![ValType::from(hash.clone())].into_iter());
let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone())));
let output = match result[0].clone() {
ValType::PrevAssigned(v) => v,
@@ -342,59 +387,69 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let len = message.len();
if len == 0 {
return Ok(vec![vec![]]);
}
let mut hash_inputs = message;
let len = hash_inputs.len();
let start_time = instant::Instant::now();
let hash = halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
VariableLength,
{ WIDTH },
{ RATE },
>::init()
.hash(message);
let mut one_iter = false;
// do the Tree dance baby
while hash_inputs.len() > 1 || !one_iter {
let hashes: Vec<Fp> = hash_inputs
.par_chunks(L)
.map(|block| {
let mut block = block.to_vec();
let remainder = block.len() % L;
if remainder != 0 {
block.extend(vec![Fp::ZERO; L - remainder].iter());
}
let block_len = block.len();
let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
ConstantLength<L>,
{ WIDTH },
{ RATE },
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}
let duration = start_time.elapsed();
log::trace!("run (N={:?}) took: {:?}", len, duration);
Ok(vec![vec![hash]])
Ok(vec![hash_inputs])
}
fn num_rows(input_len: usize) -> usize {
fn num_rows(mut input_len: usize) -> usize {
// this was determined by running the circuit and looking at the number of constraints
// in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope
// import numpy as np
// from scipy import stats
let fixed_cost: usize = 41 * L;
// x = np.array([32, 64, 96, 128, 160, 192])
// y = np.array([1298, 2594, 3890, 5186, 6482, 7778])
let mut num_rows = 0;
// slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
loop {
// the number of times the input_len is divisible by L
let num_chunks = input_len / L + 1;
num_rows += num_chunks * fixed_cost;
if num_chunks == 1 {
break;
}
input_len = num_chunks;
}
// print(f"slope: {slope}")
// print(f"intercept: {intercept}")
// print(f"R^2: {r_value**2}")
// # Predict for any x
// def predict(x):
// return slope * x + intercept
// # Test prediction
// test_x = 256
// print(f"Predicted value for x={test_x}: {predict(test_x)}")
// our output:
// slope: 40.5
// intercept: 2.0
// R^2: 1.0
// Predicted value for x=256: 10370.0
let fixed_cost: usize = 41 * input_len;
// the cost of the hash function is linear with the number of inputs
fixed_cost + 2
num_rows
}
}
@@ -421,12 +476,12 @@ mod tests {
const RATE: usize = POSEIDON_RATE;
const R: usize = 240;
struct HashCircuit<S: Spec<Fp, WIDTH, RATE>> {
struct HashCircuit<S: Spec<Fp, WIDTH, RATE>, const L: usize> {
message: ValTensor<Fp>,
_spec: PhantomData<S>,
}
impl<S: Spec<Fp, WIDTH, RATE>> Circuit<Fp> for HashCircuit<S> {
impl<S: Spec<Fp, WIDTH, RATE>, const L: usize> Circuit<Fp> for HashCircuit<S, L> {
type Config = PoseidonConfig<WIDTH, RATE>;
type FloorPlanner = ModulePlanner;
type Params = ();
@@ -442,7 +497,7 @@ mod tests {
}
fn configure(meta: &mut ConstraintSystem<Fp>) -> PoseidonConfig<WIDTH, RATE> {
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(meta, ())
PoseidonChip::<PoseidonSpec, WIDTH, RATE, L>::configure(meta, ())
}
fn synthesize(
@@ -450,7 +505,7 @@ mod tests {
config: PoseidonConfig<WIDTH, RATE>,
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> = PoseidonChip::new(config);
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(
&mut layouter,
&[self.message.clone()],
@@ -462,33 +517,18 @@ mod tests {
}
}
#[test]
fn poseidon_hash_empty() {
let message = [];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec> {
message: message.into(),
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, vec![vec![]]).unwrap();
assert_eq!(prover.verify(), Ok(()))
}
#[test]
fn poseidon_hash() {
let rng = rand::rngs::OsRng;
let message = [Fp::random(rng), Fp::random(rng)];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec> {
let circuit = HashCircuit::<PoseidonSpec, 2> {
message: message.into(),
_spec: PhantomData,
};
@@ -501,13 +541,13 @@ mod tests {
let rng = rand::rngs::OsRng;
let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 3>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec> {
let circuit = HashCircuit::<PoseidonSpec, 3> {
message: message.into(),
_spec: PhantomData,
};
@@ -523,21 +563,23 @@ mod tests {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
for i in (32..128).step_by(32) {
{
let i = 32;
// print a bunch of new lines
log::info!(
println!(
"i is {} -------------------------------------------------",
i
);
let message: Vec<Fp> = (0..i).map(|_| Fp::random(rng)).collect::<Vec<_>>();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
let output =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, 32>::run(message.clone()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 17;
let circuit = HashCircuit::<PoseidonSpec> {
let circuit = HashCircuit::<PoseidonSpec, 32> {
message: message.into(),
_spec: PhantomData,
};
@@ -554,13 +596,13 @@ mod tests {
let mut message: Vec<Fp> = (0..2048).map(|_| Fp::random(rng)).collect::<Vec<_>>();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 25>::run(message.clone()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 17;
let circuit = HashCircuit::<PoseidonSpec> {
let circuit = HashCircuit::<PoseidonSpec, 25> {
message: message.into(),
_spec: PhantomData,
};
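The heart of this file's change swaps the single VariableLength Poseidon hash for a chunked Merkle-style fold: inputs are split into blocks of L, short blocks are zero-padded, each block is hashed with ConstantLength<L>, and the loop repeats on the digests until one element remains. A hedged out-of-circuit Python sketch of that loop, with poseidon standing in for the primitive hash:

def tree_hash(inputs, L, poseidon, zero=0):
    """Fold non-empty `inputs` to a single digest by repeatedly
    hashing zero-padded chunks of size L, mirroring the circuit loop."""
    one_iter = False
    while len(inputs) > 1 or not one_iter:
        padded = inputs + [zero] * (-len(inputs) % L)
        inputs = [poseidon(padded[i:i + L])
                  for i in range(0, len(padded), L)]
        one_iter = True
    return inputs[0]

# Toy stand-in for the Poseidon permutation, for illustration only.
digest = tree_hash(list(range(10)), L=4, poseidon=lambda b: sum(b) + 1)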

View File

@@ -17,6 +17,7 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsBoolean,
}
/// Matches a [BaseOp] to an operation over inputs
@@ -33,6 +34,7 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
}
@@ -72,6 +74,7 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -87,6 +90,7 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -102,6 +106,7 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsBoolean => 0,
}
}
@@ -117,6 +122,7 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsBoolean => 0,
}
}
}
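The new IsBoolean op takes no input cells and a single-row query because its gate (wired up in the config diff below) is one polynomial identity on the output cell: out · (out − 1) = 0. A toy check of that identity over a small prime field (13 is illustrative; the circuit itself works over the bn256 scalar field):

P = 13  # toy prime standing in for the bn256 scalar field modulus

def is_boolean_gate(out):
    """Gate expression out * (out - 1) mod P: zero iff out is 0 or 1."""
    return (out * (out - 1)) % P

assert is_boolean_gate(0) == 0 and is_boolean_gate(1) == 0
assert all(is_boolean_gate(v) != 0 for v in range(2, P))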

View File

@@ -2,15 +2,16 @@ use std::str::FromStr;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
plonk::{ConstraintSystem, Constraints, Expression, Selector},
poly::Rotation,
};
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, IntoPy},
conversion::{FromPyObject, PyTryFrom},
exceptions::PyValueError,
prelude::*,
types::PyString,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -75,16 +76,6 @@ impl FromStr for CheckMode {
}
}
impl CheckMode {
/// Returns the value of the check mode
pub fn is_safe(&self) -> bool {
match self {
CheckMode::SAFE => true,
CheckMode::UNSAFE => false,
}
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
@@ -148,9 +139,10 @@ impl IntoPy<PyObject> for CheckMode {
#[cfg(feature = "python-bindings")]
/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python)
impl<'source> FromPyObject<'source> for CheckMode {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
match trystr.to_lowercase().as_str() {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"safe" => Ok(CheckMode::SAFE),
"unsafe" => Ok(CheckMode::UNSAFE),
_ => Err(PyValueError::new_err("Invalid value for CheckMode")),
@@ -169,8 +161,8 @@ impl IntoPy<PyObject> for Tolerance {
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok((val, scale)) = ob.extract::<(f32, f32)>() {
Ok(Tolerance {
val,
scale: utils::F32(scale),
@@ -215,16 +207,15 @@ impl DynamicLookups {
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub output_selectors: Vec<Selector>,
pub reference_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub outputs: Vec<VarTensor>,
pub references: Vec<VarTensor>,
}
impl Shuffles {
@@ -235,13 +226,9 @@ impl Shuffles {
Self {
input_selectors: BTreeMap::new(),
output_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
outputs: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
}
}
}
@@ -341,8 +328,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
/// shared table inputs
pub shared_table_inputs: Vec<TableColumn>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
@@ -355,7 +340,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
@@ -382,18 +366,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
if inputs[0].num_cols() != output.num_cols() {
log::warn!("input and output shapes do not match");
}
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
log::warn!("input number of inner columns do not match");
}
if inputs[0].num_inner_cols() != output.num_inner_cols() {
log::warn!("input and output number of inner columns do not match");
}
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -427,13 +406,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let constraints = match base_op {
BaseOp::IsBoolean => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
let output = expected_output[base_op.constraint_idx()].clone();
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
}
};
Constraints::with_selector(selector, constraints)
@@ -488,7 +478,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
check_mode,
_marker: PhantomData,
}
@@ -519,9 +508,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
let table = if !self.static_lookups.tables.contains_key(nl) {
let table =
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
// as all tables have the same input we see if there's another table whose input we can reuse
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
logrows,
nl,
Some(table.table_inputs.clone()),
)
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
@@ -572,9 +573,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// this is 0 if the index is the same as the column index (starting from 1)
let col_expr = sel.clone()
* (table
* table
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel));
.get_expr_at_idx(col_idx, synthetic_sel);
let multiplier =
table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -606,40 +607,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
@@ -765,8 +732,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 3],
outputs: &[VarTensor; 3],
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), CircuitError>
where
F: Field,
@@ -777,14 +744,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
}
for t in outputs.iter() {
for t in references.iter() {
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of blocks
if outputs
if references
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
@@ -792,23 +759,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"outputs inner cols".to_string(),
"references inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
for q in 0..outputs[0].num_blocks() {
let s_output = cs.complex_selector();
for q in 0..references[0].num_blocks() {
let s_reference = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("shuffle", |cs| {
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_outputq = cs.query_selector(s_output);
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
for input in inputs {
@@ -820,9 +787,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
});
}
let mut output_queries = vec![one.clone()];
for output in outputs {
output_queries.push(match output {
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
@@ -831,7 +798,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
expression
@@ -842,13 +809,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.or_insert(s_input);
}
}
self.shuffles.output_selectors.push(s_output);
self.shuffles.reference_selectors.push(s_reference);
}
// if we haven't previously initialized the input/output, do so now
if self.shuffles.outputs.is_empty() {
debug!("assigning shuffles output");
self.shuffles.outputs = outputs.to_vec();
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
@@ -880,6 +847,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table whose input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
@@ -917,9 +885,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let default_x = range_check.get_first_element(col_idx);
let col_expr = sel.clone()
* (range_check
* range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel));
.get_expr_at_idx(col_idx, synthetic_sel);
let multiplier = range_check
.selector_constructor
@@ -942,40 +910,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);
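The range_check_on_sel gate that appears on one side of the two large hunks above enforces multisel · ∏ᵢ (sel − i) = 0, i.e. the synthetic selector must equal one of the len overflowed column indices. A hedged Python evaluation of that product term:

def range_check_on_sel(sel, length, p):
    """Product over i in [0, length) of (sel - i), mod p: zero iff
    sel hits a valid column index. length == 1 short-circuits to 0."""
    if length == 1:
        return 0
    expr = 1
    for i in range(length):
        expr = (expr * (sel - i)) % p
    return expr

P = 13  # toy prime standing in for the bn256 scalar field
assert all(range_check_on_sel(i, 5, P) == 0 for i in range(5))
assert range_check_on_sel(7, 5, P) != 0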

View File

@@ -25,7 +25,7 @@ pub enum CircuitError {
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// Invalid einsum expression
///
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
@@ -94,19 +94,4 @@ pub enum CircuitError {
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
/// Visibility has not been set
#[error("visibility has not been set")]
UnsetVisibility,
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
}
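As a reminder of the pattern used throughout this enum: thiserror's #[error] attribute derives Display, and #[from] wires up the automatic conversions used by the ? operator. A minimal sketch with a hypothetical two-variant error (not the full CircuitError):

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    /// IO error; #[from] lets ? convert std::io::Error automatically
    #[error("[io] {0}")]
    Io(#[from] std::io::Error),
    /// mirrors the NegativeScale-style message formatting above
    #[error("negative scale for an op that requires positive inputs {0}")]
    NegativeScale(String),
}

fn read_settings(path: &str) -> Result<String, DemoError> {
    // the io::Error from read_to_string converts into DemoError::Io via #[from]
    Ok(std::fs::read_to_string(path)?)
}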

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::integer_rep_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -13,38 +13,14 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Ln {
scale: utils::F32,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
},
Floor {
scale: utils::F32,
legs: usize,
},
Round {
scale: utils::F32,
legs: usize,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
use_range_check_for_int: bool,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
@@ -69,17 +45,12 @@ pub enum HybridOp {
ReduceArgMin {
dim: usize,
},
Max,
Min,
Softmax {
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
Output {
tol: Tolerance,
decomp: bool,
},
RangeCheck(Tolerance),
Greater,
GreaterEqual,
Less,
@@ -108,8 +79,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
@@ -124,32 +93,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn as_string(&self) -> String {
match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
HybridOp::Max => "MAX".to_string(),
HybridOp::Min => "MIN".to_string(),
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => format!(
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
input_scale, output_scale, use_range_check_for_int
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
@@ -181,9 +139,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale, output_scale, axes
)
}
HybridOp::Output { tol, decomp } => {
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
@@ -206,34 +162,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Floor { scale, legs } => {
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Round { scale, legs } => {
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
HybridOp::SumPool {
padding,
stride,
@@ -251,20 +179,42 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
layouts::div(
use_range_check_for_int,
} => {
if input_scale.0.fract() == 0.0
&& output_scale.0.fract() == 0.0
&& *use_range_check_for_int
{
layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(denom.0 as IntegerRep),
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Recip {
input_scale: *input_scale,
output_scale: *output_scale,
},
)?
}
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
@@ -319,13 +269,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*output_scale,
axes,
)?,
HybridOp::Output { tol, decomp } => layouts::output(
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
*decomp,
)?,
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {
@@ -352,18 +301,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Softmax {
output_scale,
input_scale,
..
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)
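The Recip and Div arms above share the same dispatch shape: take the range-check path only when the scales or denominator are exact integers and the flag allows it, and otherwise fall back to a lookup table. A standalone sketch of just the Div decision (the layout calls themselves are elided):

enum DivStrategy {
    LoopDiv,     // integral denominator: constrain with loop_div plus a range check
    LookupTable, // otherwise: fall back to a Div lookup nonlinearity
}

fn pick_div_strategy(denom: f32, use_range_check_for_int: bool) -> DivStrategy {
    if denom.fract() == 0.0 && use_range_check_for_int {
        DivStrategy::LoopDiv
    } else {
        DivStrategy::LookupTable
    }
}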

File diff suppressed because it is too large

View File

@@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -14,27 +15,101 @@ use halo2curves::ff::PrimeField;
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Div { denom: utils::F32 },
IsOdd,
PowersOfTwo { scale: utils::F32 },
Ln { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Exp { scale: utils::F32, base: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
ACosh { scale: utils::F32 },
Sin { scale: utils::F32 },
ASin { scale: utils::F32 },
Sinh { scale: utils::F32 },
ASinh { scale: utils::F32 },
Tan { scale: utils::F32 },
ATan { scale: utils::F32 },
Tanh { scale: utils::F32 },
ATanh { scale: utils::F32 },
Erf { scale: utils::F32 },
Pow { scale: utils::F32, a: utils::F32 },
HardSwish { scale: utils::F32 },
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
Max {
scale: utils::F32,
a: utils::F32,
},
Min {
scale: utils::F32,
a: utils::F32,
},
Ceil {
scale: utils::F32,
},
Floor {
scale: utils::F32,
},
Round {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
Rsqrt {
scale: utils::F32,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Ln {
scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Cos {
scale: utils::F32,
},
ACos {
scale: utils::F32,
},
Cosh {
scale: utils::F32,
},
ACosh {
scale: utils::F32,
},
Sin {
scale: utils::F32,
},
ASin {
scale: utils::F32,
},
Sinh {
scale: utils::F32,
},
ASinh {
scale: utils::F32,
},
Tan {
scale: utils::F32,
},
ATan {
scale: utils::F32,
},
Tanh {
scale: utils::F32,
},
ATanh {
scale: utils::F32,
},
Erf {
scale: utils::F32,
},
KroneckerDelta,
Pow {
scale: utils::F32,
a: utils::F32,
},
HardSwish {
scale: utils::F32,
},
}
impl LookupOp {
@@ -48,14 +123,27 @@ impl LookupOp {
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
LookupOp::Floor { scale } => format!("floor_{}", scale),
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
LookupOp::IsOdd => "is_odd".to_string(),
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale, base } => format!("exp_{}_{}", scale, base),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
@@ -80,28 +168,65 @@ impl LookupOp {
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res =
match &self {
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
LookupOp::Ceil { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
}
LookupOp::PowersOfTwo { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
LookupOp::Floor { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
}
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
LookupOp::Round { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
}
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::KroneckerDelta => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
}
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
),
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Cast { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::LeakyReLU { slope: a } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
LookupOp::Sqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
}
LookupOp::Rsqrt { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
}
LookupOp::Erf { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
}
LookupOp::Exp { scale, base } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::exp(&x, scale.into(), base.into()),
),
LookupOp::Exp { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
}
LookupOp::Ln { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
}
LookupOp::Cos { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
}
@@ -158,14 +283,30 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
LookupOp::IsOdd => "IS_ODD".to_string(),
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Recip {
input_scale,
output_scale,
} => format!(
"RECIP(input_scale={}, output_scale={})",
input_scale, output_scale
),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
LookupOp::Exp { scale, base } => format!("EXP(scale={}, base={})", scale, base),
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
LookupOp::Tanh { scale } => format!("TANH(scale={})", scale),
@@ -198,7 +339,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = inputs_scale[0];
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
Ok(scale)
}
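For the out_scale arm above, assuming multiplier_to_scale is the usual log2-and-round conversion (an assumption; check graph::multiplier_to_scale for the exact definition), a Cast with scale = 4 lowers the fixed-point scale by two bits:

type Scale = i32;

// assumed to mirror graph::multiplier_to_scale: a multiplier of 2^k maps to scale k
fn multiplier_to_scale(mult: f64) -> Scale {
    mult.log2().round() as Scale
}

fn main() {
    let in_scale: Scale = 7;
    // Cast { scale: 4.0 }: in_scale + log2(1/4) = 7 - 2 = 5
    assert_eq!(in_scale + multiplier_to_scale(1.0 / 4.0), 5);
    // Recip { output_scale: 8.0, .. }: the output scale comes from the op, log2(8) = 3
    assert_eq!(multiplier_to_scale(8.0), 3);
}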

View File

@@ -105,10 +105,7 @@ impl InputType {
}
///
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone + std::fmt::Debug>(
&self,
input: &mut T,
) {
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
match self {
InputType::Bool => {
let boolean_input = input.clone().to_i64().unwrap();
@@ -121,7 +118,7 @@ impl InputType {
*input = T::from_f32(f32_input).unwrap();
}
InputType::F32 => {
let f32_input: f32 = input.clone().to_f32().unwrap();
let f32_input = input.clone().to_f32().unwrap();
*input = T::from_f32(f32_input).unwrap();
}
InputType::F64 => {
@@ -136,22 +133,6 @@ impl InputType {
}
}
impl std::str::FromStr for InputType {
type Err = CircuitError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"bool" => Ok(InputType::Bool),
"f16" => Ok(InputType::F16),
"f32" => Ok(InputType::F32),
"f64" => Ok(InputType::F64),
"int" => Ok(InputType::Int),
"tdim" => Ok(InputType::TDim),
e => Err(CircuitError::InvalidInputType(e.to_string())),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -159,8 +140,6 @@ pub struct Input {
pub scale: crate::Scale,
///
pub datum_type: InputType,
/// decomp check
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
@@ -198,7 +177,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
}
} else {
@@ -254,26 +232,20 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
#[serde(skip)]
pub pre_assigned_val: Option<ValTensor<F>>,
///
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
quantized_values,
raw_values,
pre_assigned_val: None,
decomp,
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = match self.quantized_values.visibility() {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
let visibility = self.quantized_values.visibility().unwrap();
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
}
@@ -317,12 +289,7 @@ impl<
self.quantized_values.clone().try_into()?
};
// we only need to constrain it once, even if it's used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
Ok(Some(layouts::identity(config, region, &[value])?))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -1,8 +1,5 @@
use crate::{
circuit::{
layouts,
utils::{self, F32},
},
circuit::layouts,
tensor::{self, Tensor, TensorError},
};
@@ -12,12 +9,9 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
ReLU,
Abs,
Sign,
LeakyReLU {
slope: utils::F32,
scale: i32,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
@@ -118,9 +112,9 @@ impl<
fn as_string(&self) -> String {
match &self {
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
PolyOp::Abs => "ABS".to_string(),
PolyOp::Sign => "SIGN".to_string(),
PolyOp::ReLU => "RELU".to_string(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
@@ -204,9 +198,7 @@ impl<
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
@@ -252,12 +244,6 @@ impl<
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 1 {
return Err(TensorError::DimError(
"GatherElements only accepts single inputs".to_string(),
)
.into());
}
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
@@ -275,12 +261,6 @@ impl<
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 2 {
return Err(TensorError::DimError(
"ScatterElements requires two inputs".to_string(),
)
.into());
}
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
@@ -323,9 +303,7 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -351,12 +329,6 @@ impl<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
// this corresponds to the relu operation
PolyOp::LeakyReLU {
slope: F32(0.0), ..
} => in_scales[0],
// this corresponds to the leaky relu operation with a slope which induces a change in scale
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],

View File

@@ -211,7 +211,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green(),
self.max_dynamic_input_len().to_string().green()
);
}
@@ -474,17 +474,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
Ok(())
}
/// Update the max and min forcefully
pub fn update_max_min_lookup_inputs_force(
&mut self,
min: IntegerRep,
max: IntegerRep,
) -> Result<(), CircuitError> {
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {
@@ -611,6 +600,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<(ValTensor<F>, usize), CircuitError> {
self.update_max_dynamic_input_len(values.len());
if let Some(region) = &self.region {
@@ -671,47 +661,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_unconstrained(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
false,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_constrained(
pub fn assign_with_duplication(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_constrained(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
Ok((res, len))
@@ -720,7 +685,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.row,
self.linear_coord,
values,
true,
single_inner_col,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))

View File

@@ -132,29 +132,30 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
(first_element, op_f.output[0])
}
/// calculates the column size given the number of rows and reserved blinding rows
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / col_size as IntegerRep) as usize + 1
(range_len / (col_size as IntegerRep)) as usize + 1
}
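A quick worked example for num_cols_required (the logrows and blinding-row counts here are illustrative, not defaults):

fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    // e.g. cal_col_size(17, 10) = 2^17 - 10 = 131062 usable rows per column
    let col_size = 2usize.pow(17) - 10;
    // a symmetric range [-2^17, 2^17] spans 2 * 2^17 + 1 = 262145 values
    let range_len: i128 = 2 * 131072 + 1;
    // 262145 / 131062 = 2 under integer division, plus 1 => 3 columns
    assert_eq!(num_cols_required(range_len, col_size), 3);
}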
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get largest element represented by the range
pub fn largest(&self) -> IntegerRep {
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
}
fn name(&self) -> String {
format!(
"{}_{}_{}",
self.nonlinearity.as_path(),
self.range.0,
self.largest()
self.range.1
)
}
/// Configures the table.
@@ -163,7 +164,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: &mut Vec<TableColumn>,
preexisting_inputs: Option<Vec<TableColumn>>,
) -> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
@@ -172,28 +173,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
debug!("table range: {:?}", range);
// validate enough columns are provided to store the range
if preexisting_inputs.len() < num_cols {
// add columns to match the required number of columns
let diff = num_cols - preexisting_inputs.len();
for _ in 0..diff {
preexisting_inputs.push(cs.lookup_table_column());
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
}
}
cols
});
let num_cols = table_inputs.len();
let num_cols = preexisting_inputs.len();
if num_cols > 1 {
warn!("Using {} columns for non-linearity table.", num_cols);
}
let table_outputs = preexisting_inputs
let table_outputs = table_inputs
.iter()
.map(|_| cs.lookup_table_column())
.collect::<Vec<_>>();
Table {
nonlinearity: nonlinearity.clone(),
table_inputs: preexisting_inputs.clone(),
table_inputs,
table_outputs,
is_assigned: false,
selector_constructor: SelectorConstructor::new(num_cols),
@@ -221,7 +222,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}
let smallest = self.range.0;
let largest = self.largest();
let largest = self.range.1;
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
@@ -290,7 +291,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
if !preassigned_input {
table.assign_cell(
|| format!("nl_i_col row {}", row_offset),
@@ -350,11 +350,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}
/// calculates the column size
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
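The chunking described above is plain integer arithmetic: an input x falls in chunk (x - range.0) / col_size, and get_first_element returns chunk * col_size + range.0, matching the felt conversion earlier in this hunk. A plain-integer sketch with the field conversion elided:

fn get_col_index(x: i128, range_start: i128, col_size: i128) -> i128 {
    // the range is split into chunks of size col_size; find the chunk x falls in
    (x - range_start) / col_size
}

fn get_first_element(chunk: i128, range_start: i128, col_size: i128) -> i128 {
    chunk * col_size + range_start
}

fn main() {
    let (range_start, col_size) = (-100i128, 64i128);
    let x = 30; // offset 130 from the range start, so chunk 2
    let chunk = get_col_index(x, range_start, col_size);
    assert_eq!(chunk, 2);
    assert_eq!(get_first_element(chunk, range_start, col_size), 28);
}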

View File

@@ -1040,10 +1040,6 @@ mod conv {
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
// column for constants
let _constant = VarTensor::constant_cols(cs, K, 8, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1175,7 +1171,7 @@ mod conv_col_ultra_overflow {
use super::*;
const K: usize = 6;
const K: usize = 4;
const LEN: usize = 10;
#[derive(Clone)]
@@ -1195,10 +1191,9 @@ mod conv_col_ultra_overflow {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1384,10 +1379,7 @@ mod conv_relu_col_ultra_overflow {
.layout(
&mut region,
&[output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
Box::new(PolyOp::ReLU),
)
.unwrap();
Ok(())
@@ -1781,18 +1773,13 @@ mod shuffle {
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
config
}
@@ -1813,7 +1800,6 @@ mod shuffle {
&mut region,
&self.inputs[i],
&self.references[i],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1999,7 +1985,7 @@ mod add_with_overflow_and_poseidon {
let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE);
VarTensor::constant_cols(cs, K, 2, false);
let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(cs, ());
let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::configure(cs, ());
MyCircuitConfig { base, poseidon }
}
@@ -2009,7 +1995,7 @@ mod add_with_overflow_and_poseidon {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> =
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a =
@@ -2044,9 +2030,11 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];
let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0];
let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone()).unwrap()[0][0];
// parameters
let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2068,11 +2056,13 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0] + Fr::one();
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone())
.unwrap()[0][0]
+ Fr::one();
let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0] + Fr::one();
let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone())
.unwrap()[0][0]
+ Fr::one();
// parameters
let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2357,14 +2347,7 @@ mod matmul_relu {
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
.unwrap();
Ok(())
},
@@ -2456,14 +2439,7 @@ mod relu {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
Ok(config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.unwrap())
},
)
@@ -2506,11 +2482,11 @@ mod lookup_ultra_overflow {
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
#[derive(Clone)]
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
}
impl Circuit<F> for SigmoidCircuit<F> {
impl Circuit<F> for ReLUCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -2524,7 +2500,7 @@ mod lookup_ultra_overflow {
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
.collect::<Vec<_>>();
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let mut config = BaseConfig::default();
@@ -2557,7 +2533,7 @@ mod lookup_ultra_overflow {
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.map_err(|_| Error::Synthesis)
},
@@ -2570,13 +2546,13 @@ mod lookup_ultra_overflow {
#[test]
#[ignore]
fn sigmoidcircuit() {
fn relucircuit() {
// get some logs fam
crate::logger::init_logger();
// parameters
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
let circuit = SigmoidCircuit::<F> {
let circuit = ReLUCircuit::<F> {
input: ValTensor::from(a),
};
@@ -2586,7 +2562,7 @@ mod lookup_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
SigmoidCircuit<F>,
ReLUCircuit<F>,
>(&circuit, &params, true)
.unwrap();

View File

@@ -2,7 +2,12 @@ use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
exceptions::PyValueError,
prelude::*,
types::PyString,
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::str::FromStr;
@@ -83,15 +88,16 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
/// Default commitment
pub const DEFAULT_COMMITMENT: &str = "kzg";
/// Default seed used to generate random data
pub const DEFAULT_SEED: &str = "21242";
// TODO: In prod this will be the same across all chains we deploy to using the EZKL multisig create2 deployment.
/// Default address of the verifier manager.
pub const DEFAULT_VERIFIER_MANAGER_ADDRESS: &str = "0xdc64a140aa3e981100a9beca4e685f962f0cf6c9";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
@@ -106,8 +112,8 @@ impl IntoPy<PyObject> for TranscriptType {
#[cfg(feature = "python-bindings")]
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
impl<'source> FromPyObject<'source> for TranscriptType {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"poseidon" => Ok(TranscriptType::Poseidon),
@@ -184,16 +190,20 @@ pub enum ContractType {
/// Deploys a verifier contract tailored to the circuit and not reusable
Verifier {
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain, since you need only deploy a verifying key artifact (vka) for a given circuit, which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits).
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
reusable: bool,
},
/// Deploys a verifying key artifact that the reusable verifier loads into memory at runtime. Encodes the circuit-specific data that was otherwise hardcoded onto the stack.
VerifyingKeyArtifact,
/// Manages the deployments of all reusable verifiers and verifying key artifacts. Routes all verification txs to the correct artifacts.
VerifierManager
}
impl Default for ContractType {
fn default() -> Self {
ContractType::Verifier { reusable: false }
ContractType::Verifier {
reusable: false,
}
}
}
@@ -205,9 +215,12 @@ impl std::fmt::Display for ContractType {
match self {
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_string()
}
ContractType::Verifier { reusable: false } => "verifier".to_string(),
},
ContractType::Verifier {
reusable: false,
} => "verifier".to_string(),
ContractType::VerifyingKeyArtifact => "vka".to_string(),
ContractType::VerifierManager => "manager".to_string()
}
)
}
@@ -225,11 +238,12 @@ impl From<&str> for ContractType {
"verifier" => ContractType::Verifier { reusable: false },
"verifier/reusable" => ContractType::Verifier { reusable: true },
"vka" => ContractType::VerifyingKeyArtifact,
"manager" => ContractType::VerifierManager,
_ => {
log::error!("Invalid value for ContractType");
log::warn!("Defaulting to verifier");
ContractType::default()
}
},
}
}
}
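The Display and From<&str> impls above are meant to round-trip on the canonical strings. A condensed, self-contained sketch of the mapping (a stand-in enum, not the full clap-integrated type):

#[derive(Debug, PartialEq)]
enum ContractType {
    Verifier { reusable: bool },
    VerifyingKeyArtifact,
    VerifierManager,
}

impl From<&str> for ContractType {
    fn from(s: &str) -> Self {
        match s {
            "verifier" => ContractType::Verifier { reusable: false },
            "verifier/reusable" => ContractType::Verifier { reusable: true },
            "vka" => ContractType::VerifyingKeyArtifact,
            "manager" => ContractType::VerifierManager,
            // the real impl logs an error and warns before defaulting
            _ => ContractType::Verifier { reusable: false },
        }
    }
}

fn main() {
    let ct: ContractType = "manager".into();
    assert_eq!(ct, ContractType::VerifierManager);
}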
@@ -279,8 +293,9 @@ impl IntoPy<PyObject> for CalibrationTarget {
#[cfg(feature = "python-bindings")]
/// Obtains CalibrationTarget from PyObject (Required for CalibrationTarget to be compatible with Python)
impl<'source> FromPyObject<'source> for CalibrationTarget {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"resources" => Ok(CalibrationTarget::Resources {
col_overflow: false,
@@ -297,8 +312,12 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
impl IntoPy<PyObject> for ContractType {
fn into_py(self, py: Python) -> PyObject {
match self {
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_object(py)
}
ContractType::Verifier {
reusable: false,
} => "verifier".to_object(py),
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
}
}
@@ -307,16 +326,31 @@ impl IntoPy<PyObject> for ContractType {
#[cfg(feature = "python-bindings")]
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
impl<'source> FromPyObject<'source> for ContractType {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"verifier" => Ok(ContractType::Verifier { reusable: false }),
"verifier" => Ok(ContractType::Verifier {
reusable: false,
}),
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
"vka" => Ok(ContractType::VerifyingKeyArtifact),
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO_PKG_VERSION is 0.0.0, replace it with "source - no compatibility guaranteed"
lazy_static! {
/// The version of the ezkl library
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
"source - no compatibility guaranteed"
} else {
env!("CARGO_PKG_VERSION")
};
}
/// Get the styles for the CLI
pub fn get_styles() -> clap::builder::Styles {
@@ -325,49 +359,49 @@ pub fn get_styles() -> clap::builder::Styles {
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Cyan,
))),
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.header(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Cyan,
))),
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.literal(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
)
.invalid(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.error(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.literal(clap::builder::styling::Style::new().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta),
)))
.invalid(clap::builder::styling::Style::new().bold().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
)))
.error(clap::builder::styling::Style::new().bold().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
)))
.valid(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Green,
))),
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
)
.placeholder(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
)
.placeholder(clap::builder::styling::Style::new().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White),
)))
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
pub struct Cli {
/// If provided, outputs the completion file for given shell
#[clap(long = "generate", value_parser)]
@@ -377,6 +411,7 @@ pub struct Cli {
pub command: Option<Commands>,
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -424,23 +459,9 @@ pub enum Commands {
#[clap(flatten)]
args: RunArgs,
},
/// Generate random data for a model
GenRandomData {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, eg. batch_size=1
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
variables: Vec<(String, usize)>,
/// random seed for reproducibility (optional)
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
seed: u64,
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
@@ -471,6 +492,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long, value_hint = clap::ValueHint::Other)]
max_logrows: Option<u32>,
// whether to only range-check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
only_range_check_rebase: Option<bool>,
},
/// Generates a dummy SRS
@@ -487,10 +511,10 @@ pub enum Commands {
commitment: Option<Commitments>,
},
/// Gets an SRS from a circuit settings file.
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
@@ -537,7 +561,7 @@ pub enum Commands {
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
@@ -564,7 +588,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
#[arg(
@@ -572,7 +596,7 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum,
value_hint = clap::ValueHint::Other
)]
transcript: TranscriptType,
@@ -606,7 +630,7 @@ pub enum Commands {
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to output the verification key file to
@@ -622,7 +646,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -646,7 +670,7 @@ pub enum Commands {
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that the admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
@@ -659,7 +683,7 @@ pub enum Commands {
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
@@ -669,7 +693,7 @@ pub enum Commands {
witness_path: Option<PathBuf>,
},
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
@@ -683,7 +707,7 @@ pub enum Commands {
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
@@ -691,7 +715,7 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
@@ -699,7 +723,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
@@ -712,10 +736,10 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -734,10 +758,10 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
reusable: Option<bool>,
},
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -753,7 +777,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -777,10 +801,10 @@ pub enum Commands {
witness: Option<PathBuf>,
},
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
@@ -813,7 +837,7 @@ pub enum Commands {
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
@@ -831,7 +855,7 @@ pub enum Commands {
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
@@ -841,7 +865,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
DeployEvm {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
@@ -858,11 +882,19 @@ pub enum Commands {
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
/// Deployed verifier manager contract's address
/// Used to facilitate reusable verifier and vk artifact deployment
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_verifier_manager: Option<H160Flag>,
/// Deployed reusable verifier contract's address
/// Used to facilitate reusable verifier and vk artifact deployment
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_reusable_verifier: Option<H160Flag>,
/// Contract type to be deployed
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
contract: ContractType,
},
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -887,7 +919,7 @@ pub enum Commands {
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
@@ -915,6 +947,7 @@ pub enum Commands {
},
}
impl Commands {
/// Converts the commands to a json string
pub fn as_json(&self) -> String {
@@ -925,4 +958,4 @@ impl Commands {
pub fn from_json(json: &str) -> Self {
serde_json::from_str(json).unwrap()
}
}

File diff suppressed because one or more lines are too long

View File

@@ -1,13 +1,10 @@
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
fix_da_single_sol,
};
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[allow(unused_imports)]
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::{Calls, GraphData};
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::GraphData;
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
@@ -65,8 +62,6 @@ use std::str::FromStr;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
use tract_onnx::prelude::IntoTensor;
use tract_onnx::prelude::Tensor as TractTensor;
use lazy_static::lazy_static;
@@ -118,7 +113,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
Commands::GetSrs {
srs_path,
@@ -136,17 +131,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
args,
),
Commands::GenRandomData {
model,
data,
variables,
seed,
} => gen_random_data(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
variables,
seed,
),
Commands::CalibrateSettings {
model,
settings_path,
@@ -156,6 +140,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
scales,
scale_rebase_multiplier,
max_logrows,
only_range_check_rebase,
} => calibrate(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
@@ -164,6 +149,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
max_logrows,
)
.await
@@ -424,24 +410,46 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
Commands::DeployEvm {
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
addr_verifier_manager,
addr_reusable_verifier,
contract,
} => {
// if the contract type is a reusable verifier or a VK artifact, require the linking addresses
match contract {
ContractType::Verifier { reusable: true } => {
if addr_verifier_manager.is_none() {
panic!("Must pass a verifier manager address for reusable verifier")
}
}
ContractType::VerifyingKeyArtifact => {
if addr_verifier_manager.is_none() || addr_reusable_verifier.is_none() {
panic!(
"Must pass a verifier manager address and reusable verifier address for verifying key artifact"
)
}
}
_ => {}
};
deploy_evm(
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
rpc_url,
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
optimizer_runs,
private_key,
addr_verifier_manager.map(|s| s.into()),
addr_reusable_verifier.map(|s| s.into()),
contract,
)
.await
}
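The validation in this arm encodes a simple dependency rule between contract types and addresses. A standalone sketch of that rule — the helper name and `Result` shape are ours; only the match arms mirror the hunk:

```rust
// Hypothetical helper mirroring the checks above: a reusable verifier is
// deployed through the manager, and a VK artifact must additionally be
// linked to an already-deployed reusable verifier.
fn check_required_addresses(
    contract: &ContractType,
    has_manager: bool,
    has_reusable: bool,
) -> Result<(), String> {
    match contract {
        ContractType::Verifier { reusable: true } if !has_manager => {
            Err("must pass a verifier manager address for reusable verifier".into())
        }
        ContractType::VerifyingKeyArtifact if !(has_manager && has_reusable) => {
            Err("must pass a verifier manager and reusable verifier address".into())
        }
        _ => Ok(()),
    }
}
```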
#[cfg(not(target_arch = "wasm32"))]
Commands::DeployEvmDataAttestation {
data,
settings_path,
@@ -691,15 +699,11 @@ pub(crate) async fn get_srs_cmd(
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
let mut file = std::fs::File::create(&computed_srs_path)?;
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
params.write(&mut buffer)?;
info!(
"Saved SRS to {}.",
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
);
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
@@ -841,71 +845,6 @@ pub(crate) fn gen_circuit_settings(
Ok(String::new())
}
/// Generate a circuit settings file
pub(crate) fn gen_random_data(
model_path: PathBuf,
data_path: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
) -> Result<String, EZKLError> {
let mut file = std::fs::File::open(&model_path).map_err(|e| {
crate::graph::errors::GraphError::ReadWriteFileError(
model_path.display().to_string(),
e.to_string(),
)
})?;
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
let input_facts = tract_model
.input_outlets()
.map_err(|e| EZKLError::from(e.to_string()))?
.iter()
.map(|&i| tract_model.outlet_fact(i))
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
.map_err(|e| EZKLError::from(e.to_string()))?;
/// Generates a random tensor of a given size and type.
fn random(
sizes: &[usize],
datum_type: tract_onnx::prelude::DatumType,
seed: u64,
) -> TractTensor {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.gen());
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
if let Some(value) = &fact.konst {
return value.clone().into_tensor();
}
random(
fact.shape
.as_concrete()
.expect("Expected concrete shape, found: {fact:?}"),
fact.datum_type,
seed,
)
}
let generated = input_facts
.iter()
.map(|v| tensor_for_fact(v, seed))
.collect_vec();
let data = GraphData::from_tract_data(&generated)?;
data.save(data_path)?;
Ok(String::new())
}
// not for wasm targets
pub(crate) fn init_spinner() -> ProgressBar {
let pb = indicatif::ProgressBar::new_spinner();
@@ -1050,6 +989,7 @@ pub(crate) async fn calibrate(
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, EZKLError> {
use log::error;
@@ -1085,6 +1025,12 @@ pub(crate) async fn calibrate(
(11..14).collect::<Vec<crate::Scale>>()
};
let div_rebasing = if only_range_check_rebase {
vec![false]
} else {
vec![true, false]
};
let mut found_params: Vec<GraphSettings> = vec![];
// cartesian-product grid of candidate settings (scales x rebase multiplier x div-rebasing)
@@ -1122,6 +1068,12 @@ pub(crate) async fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
let mut forward_pass_res = HashMap::new();
let pb = init_bar(range_grid.len() as u64);
@@ -1130,23 +1082,30 @@ pub(crate) async fn calibrate(
let mut num_failed = 0;
let mut num_passed = 0;
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (input_scale, param_scale, scale_rebase_multiplier);
let key = (
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
);
forward_pass_res.insert(key, vec![]);
let local_run_args = RunArgs {
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
..settings.run_args.clone()
};
@@ -1250,6 +1209,7 @@ pub(crate) async fn calibrate(
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
@@ -1357,6 +1317,7 @@ pub(crate) async fn calibrate(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
best_params.run_args.div_rebasing,
))
.ok_or("no params found")?
.iter()
@@ -1478,6 +1439,7 @@ pub(crate) async fn create_evm_verifier(
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn create_evm_vka(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
@@ -1506,9 +1468,20 @@ pub(crate) async fn create_evm_vka(
num_instance,
);
let vk_solidity = generator.render_separately()?.1;
let (reusable_verifier, vk_solidity) = generator.render_separately()?;
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
// Remove the first line of vk_solidity (license identifier). Same license identifier for all contracts in this .sol
let vk_solidity = vk_solidity
.lines()
.skip(1)
.collect::<Vec<&str>>()
.join("\n");
// We store both contracts in the same file...
// We need this so that, during the deployment transaction, the verifier
// manager links the VKA to the correct reusable_verifier.
let combined_solidity = format!("{}\n\n{}", reusable_verifier, vk_solidity);
File::create(sol_code_path.clone())?.write_all(combined_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
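The license-stripping above exists because the VKA source and the reusable verifier source each begin with their own SPDX line, and the combined file should carry exactly one. A sketch of the same transformation as a helper (the name is ours):

```rust
// Drops the first line (SPDX license identifier) of the second contract
// before concatenation, so the combined .sol file has a single identifier.
fn combine_contracts(reusable_verifier: &str, vka: &str) -> String {
    let vka_without_license = vka.lines().skip(1).collect::<Vec<_>>().join("\n");
    format!("{}\n\n{}", reusable_verifier, vka_without_license)
}
```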
@@ -1535,27 +1508,15 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");
// if input is not provided, we just instantiate dummy input data
let data =
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
let mut output_len = None;
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
return Err("private output data on chain is not supported on chain".into());
}
let mut on_chain_output_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_output_data.push(call);
}
}
Calls::Single(call) => {
output_len = Some(call.len);
}
for call in source.calls {
on_chain_output_data.push(call);
}
Some(on_chain_output_data)
} else {
@@ -1567,15 +1528,8 @@ pub(crate) async fn create_evm_data_attestation(
return Err("private input data on chain is not supported on chain".into());
}
let mut on_chain_input_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_input_data.push(call);
}
}
Calls::Single(call) => {
input_len = Some(call.len);
}
for call in source.calls {
on_chain_input_data.push(call);
}
Some(on_chain_input_data)
} else {
@@ -1602,24 +1556,13 @@ pub(crate) async fn create_evm_data_attestation(
None
};
// if either input_len or output_len is Some then we are in the single call data attestation mode
if input_len.is_some() || output_len.is_some() {
let output = fix_da_single_sol(input_len, output_len)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationSingle", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
} else {
let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationMulti", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
}
let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
Ok(String::new())
}
@@ -1656,21 +1599,51 @@ pub(crate) async fn deploy_evm(
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
verifier_manager: Option<alloy::primitives::Address>,
reusable_verifier: Option<alloy::primitives::Address>,
contract: ContractType,
) -> Result<String, EZKLError> {
use crate::eth::{deploy_reusable_verifier, deploy_vka};
let contract_name = match contract {
ContractType::Verifier { reusable: false } => "Halo2Verifier",
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
ContractType::VerifierManager => "EZKLVerifierManager",
};
let contract_address = if contract_name == "Halo2VerifierReusable" {
// Use VerifierManager to deploy the contract
deploy_reusable_verifier(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
verifier_manager.unwrap(),
)
.await?
} else if contract_name == "Halo2VerifyingArtifact" {
deploy_vka(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
verifier_manager.unwrap(),
reusable_verifier.unwrap(),
)
.await?
} else {
deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
)
.await?
};
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
)
.await?;
info!("Contract deployed at: {:#?}", contract_address);
@@ -2127,7 +2100,6 @@ pub(crate) fn mock_aggregate(
Ok(String::new())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,

View File

@@ -5,12 +5,10 @@ use halo2curves::ff::PrimeField;
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an integer rep to a PrimeField element.
/// Converts an i64 to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else if x == IntegerRep::MIN {
-F::from_u128(x.saturating_neg() as u128) - F::ONE
} else {
-F::from_u128(x.saturating_neg() as u128)
}
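The mapping above embeds a signed `IntegerRep` into the field by reducing modulo p: non-negative x maps to x, negative x to p − |x|, with a dedicated branch for `IntegerRep::MIN` (whose negation overflows, so `saturating_neg()` lands one short and the `- F::ONE` compensates). A worked example using the same pasta field as the test module below:

```rust
use halo2curves::pasta::Fp as F;

fn mapping_examples() {
    // non-negative values embed directly
    let pos: F = integer_rep_to_felt(15);
    assert_eq!(pos, F::from(15));
    // negative values land at p - |x|, i.e. the additive inverse
    let neg: F = integer_rep_to_felt(-15);
    assert_eq!(neg, -F::from(15));
}
```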
@@ -34,9 +32,6 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
/// Converts a PrimeField element to an IntegerRep.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
return IntegerRep::MIN;
}
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -56,7 +51,7 @@ mod test {
use halo2curves::pasta::Fp as F;
#[test]
fn integerreptofelt() {
fn test_conv() {
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
@@ -74,24 +69,8 @@ mod test {
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
#[test]
fn felttointegerrepmin() {
let x = IntegerRep::MIN;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
#[test]
fn felttointegerrepmax() {
let x = IntegerRep::MAX;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}

View File

@@ -11,12 +11,6 @@ pub enum GraphError {
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Non scalar power
#[error("we only support scalar powers")]
NonScalarPower,
/// Non scalar base for exponentiation
#[error("we only support scalar bases for exponentiation")]
NonScalarBase,
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
@@ -119,13 +113,13 @@ pub enum GraphError {
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
/// Ranges can only be constant
///
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
/// Trilu diagonal must be constant
///
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
/// The witness was too short
///
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
@@ -149,13 +143,4 @@ pub enum GraphError {
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
/// Only nearest neighbor interpolation is supported
#[error("only nearest neighbor interpolation is supported")]
InvalidInterpolation,
/// Node has a missing output
#[error("node {0} has a missing output")]
MissingOutput(usize),
/// Insufficient advice columns
#[error("insufficient advice columns (need at least {0})")]
InsufficientAdviceColumns(usize),
}

View File

@@ -14,6 +14,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -24,7 +25,6 @@ use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
@@ -32,95 +32,30 @@ type Decimals = u8;
type Call = String;
type RPCUrl = String;
/// Represents different types of values that can be stored in a file source
/// Used for handling various input types in zero-knowledge proofs
///
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum FileSourceInner {
/// Floating point value (64-bit)
/// Inner elements of float inputs coming from a file
Float(f64),
/// Boolean value
/// Inner elements of bool inputs coming from a file
Bool(bool),
/// Field element value for direct use in circuits
/// Inner elements of inputs coming from a witness
Field(Fp),
}
impl FileSourceInner {
/// Returns true if the value is a floating point number
///
pub fn is_float(&self) -> bool {
matches!(self, FileSourceInner::Float(_))
}
/// Returns true if the value is a boolean
///
pub fn is_bool(&self) -> bool {
matches!(self, FileSourceInner::Bool(_))
}
/// Returns true if the value is a field element
///
pub fn is_field(&self) -> bool {
matches!(self, FileSourceInner::Field(_))
}
/// Creates a new floating point value
pub fn new_float(f: f64) -> Self {
FileSourceInner::Float(f)
}
/// Creates a new field element value
pub fn new_field(f: Fp) -> Self {
FileSourceInner::Field(f)
}
/// Creates a new boolean value
pub fn new_bool(f: bool) -> Self {
FileSourceInner::Bool(f)
}
/// Adjusts the value according to the specified input type
///
/// # Arguments
/// * `input_type` - Type specification to convert the value to
pub fn as_type(&mut self, input_type: &InputType) {
match self {
FileSourceInner::Float(f) => input_type.roundtrip(f),
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
FileSourceInner::Field(_) => {}
}
}
/// Converts the value to a field element using appropriate scaling
///
/// # Arguments
/// * `scale` - Scaling factor for floating point conversion
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => {
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
}
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
} else {
Fp::zero()
}
}
FileSourceInner::Field(f) => *f,
}
}
/// Converts the value to a floating point number
pub fn to_float(&self) -> f64 {
match self {
FileSourceInner::Float(f) => *f,
FileSourceInner::Bool(f) => {
if *f {
1.0
} else {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
impl Serialize for FileSourceInner {
@@ -136,8 +71,8 @@ impl Serialize for FileSourceInner {
}
}
// Deserialization implementation for FileSourceInner
// Uses JSON deserialization to handle the different variants
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
impl<'de> Deserialize<'de> for FileSourceInner {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@@ -164,100 +99,177 @@ impl<'de> Deserialize<'de> for FileSourceInner {
}
}
/// A collection of input values from a file source
/// Organized as a vector of vectors where each inner vector represents a row/entry
/// Elements of inputs coming from a file
pub type FileSource = Vec<Vec<FileSourceInner>>;
/// Represents different types of calls for fetching on-chain data
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum Calls {
/// Multiple calls to different accounts, each returning individual values
Multiple(Vec<CallsToAccount>),
/// Single call returning an array of values
Single(CallToAccount),
}
impl Default for Calls {
fn default() -> Self {
Calls::Multiple(Vec::new())
}
}
impl FileSourceInner {
/// Create a new FileSourceInner
pub fn new_float(f: f64) -> Self {
FileSourceInner::Float(f)
}
/// Create a new FileSourceInner
pub fn new_field(f: Fp) -> Self {
FileSourceInner::Field(f)
}
/// Create a new FileSourceInner
pub fn new_bool(f: bool) -> Self {
FileSourceInner::Bool(f)
}
}
impl Serialize for Calls {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Calls::Single(data) => data.serialize(serializer),
Calls::Multiple(data) => data.serialize(serializer),
}
}
}
///
pub fn as_type(&mut self, input_type: &InputType) {
match self {
FileSourceInner::Float(f) => input_type.roundtrip(f),
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
FileSourceInner::Field(_) => {}
}
}
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
} else {
Fp::zero()
}
}
FileSourceInner::Field(f) => *f,
}
}
/// Convert to a float
pub fn to_float(&self) -> f64 {
match self {
FileSourceInner::Float(f) => *f,
FileSourceInner::Bool(f) => {
if *f {
1.0
} else {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
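`to_field` above is the fixed-point entry point for file inputs: floats go through `quantize_float` at the given scale, booleans map straight to 0/1, and field elements pass through. Assuming `quantize_float(f, 0.0, scale)` rounds f · 2^scale (the exact rounding rule lives in the quantization module, not in this diff), a small usage sketch:

```rust
fn to_field_examples() {
    // at scale 7 the multiplier is 2^7 = 128, so 0.5 quantizes to 64
    let x = FileSourceInner::Float(0.5);
    assert_eq!(x.to_field(7), integer_rep_to_felt(64));
    // booleans skip scaling entirely
    assert_eq!(FileSourceInner::Bool(true).to_field(7), Fp::one());
}
```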
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
impl<'de> Deserialize<'de> for Calls {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
if let Ok(t) = multiple_try {
return Ok(Calls::Multiple(t));
}
let single_try: Result<CallToAccount, _> = serde_json::from_str(this_json.get());
if let Ok(t) = single_try {
return Ok(Calls::Single(t));
}
Err(serde::de::Error::custom("failed to deserialize Calls"))
}
}
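Because `Calls` is deserialized by trying each variant in order against the raw JSON, the two accepted shapes are a JSON array of `CallsToAccount` objects and a single `CallToAccount` object (with its extra `len` field). A sketch with placeholder call data and a zero address:

```rust
fn calls_json_examples() {
    // an array of CallsToAccount parses as Calls::Multiple
    let multi = r#"[{"call_data":[["0xabcd",18]],"address":"0x0000000000000000000000000000000000000000"}]"#;
    assert!(matches!(
        serde_json::from_str::<Calls>(multi),
        Ok(Calls::Multiple(_))
    ));
    // a single object with a `len` field parses as Calls::Single
    let single = r#"{"call_data":"0xabcd","decimals":18,"address":"0x0000000000000000000000000000000000000000","len":4}"#;
    assert!(matches!(
        serde_json::from_str::<Calls>(single),
        Ok(Calls::Single(_))
    ));
}
```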
/// Configuration for accessing on-chain data sources
/// Inner elements of inputs/outputs coming from on-chain
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct OnChainSource {
/// Call specifications for fetching data
pub calls: Calls,
/// RPC endpoint URL for accessing the chain
/// Vector of calls to accounts
pub calls: Vec<CallsToAccount>,
/// RPC url
pub rpc: RPCUrl,
}
impl OnChainSource {
/// Creates a new OnChainSource with multiple calls
///
/// # Arguments
/// * `calls` - Vector of call specifications
/// * `rpc` - RPC endpoint URL
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Multiple(calls),
rpc,
}
}
/// Create a new OnChainSource
pub fn new(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
OnChainSource { calls, rpc }
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Inner elements of inputs/outputs coming from postgres DB
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// postgres host
pub host: RPCUrl,
/// user to connect to postgres
pub user: String,
/// password to connect to postgres
pub password: String,
/// query to execute
pub query: String,
/// dbname
pub dbname: String,
/// port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Create a new PostgresSource
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Creates a new OnChainSource with a single call
///
/// # Arguments
/// * `call` - Call specification
/// * `rpc` - RPC endpoint URL
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Single(call),
rpc,
}
}
/// Fetch data from postgres
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
let query = self.query.clone();
let dbname = self.dbname.clone();
let port = self.port.clone();
let password = self.password.clone();
let config = if password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
host, user, dbname, port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
host, user, dbname, port, password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).await? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetch data from postgres and format it as a FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
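`fetch` above builds a libpq-style key=value connection string, appending the password key only when one is supplied. The shape of the string it produces:

```rust
fn config_string_example() {
    let (host, user, dbname, port) = ("localhost", "user", "database", "5432");
    let config = format!("host={} user={} dbname={} port={}", host, user, dbname, port);
    assert_eq!(config, "host=localhost user=user dbname=database port=5432");
    // with a non-empty password, " password=..." is appended, as in the branch above
}
```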
impl OnChainSource {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Creates test data for the OnChain data source
/// Used for testing and development purposes
///
/// # Arguments
/// * `data` - Sample file data to use
/// * `scales` - Scaling factors for each input
/// * `shapes` - Shapes of the input tensors
/// * `rpc` - Optional RPC endpoint override
/// Create dummy local on-chain data to test the OnChain data source
pub async fn test_from_file_data(
data: &FileSource,
scales: Vec<crate::Scale>,
@@ -265,8 +277,7 @@ impl OnChainSource {
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
use crate::eth::{
evm_quantize_multi, read_on_chain_inputs_multi, test_on_chain_data,
DEFAULT_ANVIL_ENDPOINT,
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
};
use log::debug;
@@ -285,7 +296,7 @@ impl OnChainSource {
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
debug!("Calls to accounts: {:?}", calls_to_accounts);
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls_to_accounts).await?;
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
debug!("Inputs: {:?}", inputs);
let mut quantized_evm_inputs = vec![];
@@ -293,7 +304,7 @@ impl OnChainSource {
let mut prev = 0;
for (idx, i) in data.iter().enumerate() {
quantized_evm_inputs.extend(
evm_quantize_multi(
evm_quantize(
client.clone(),
vec![scales[idx]; i.len()],
&(
@@ -319,45 +330,35 @@ impl OnChainSource {
// Fill the input_data field of the GraphData struct
Ok((
inputs,
OnChainSource::new_multiple(calls_to_accounts.clone(), used_rpc),
OnChainSource::new(calls_to_accounts.clone(), used_rpc),
))
}
}
/// Specification for view-only calls to fetch on-chain data
/// Used for data attestation in smart contract verification
/// Defines the view only calls to accounts to fetch the on-chain input data.
/// This data will be included as part of the first elements in the publicInputs
/// for the sol evm verifier and will be verified by verifyWithDataAttestation.sol
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallsToAccount {
/// Vector of (call data, decimals) pairs
/// call_data: ABI-encoded function call
/// decimals: Number of decimal places for float conversion
/// A vector of tuples, where index 0 of tuples
/// are the byte strings representing the ABI encoded function calls to
/// read the data from the address. This call must return a single
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
/// The second index of the tuple is the number of decimals for f32 conversion.
/// We don't support dynamic types currently.
pub call_data: Vec<(Call, Decimals)>,
/// Contract address to call
/// Address of the contract to read the data from.
pub address: String,
}
/// Specification for a single view-only call returning an array
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallToAccount {
/// ABI-encoded function call data
pub call_data: Call,
/// Number of decimal places for float conversion
pub decimals: Decimals,
/// Contract address to call
pub address: String,
/// Expected length of returned array
pub len: usize,
}
/// Represents different sources of input/output data for the EZKL model
/// Enum that defines source of the inputs/outputs to the EZKL model
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
#[serde(untagged)]
pub enum DataSource {
/// Data from a JSON file containing arrays of values
/// .json File data source.
File(FileSource),
/// Data fetched from blockchain contracts
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
OnChain(OnChainSource),
/// Data from a PostgreSQL database
/// Postgres DB
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
@@ -400,7 +401,8 @@ impl From<OnChainSource> for DataSource {
}
}
// Note: Always use JSON serialization for untagged enums
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
impl<'de> Deserialize<'de> for DataSource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@@ -408,19 +410,15 @@ impl<'de> Deserialize<'de> for DataSource {
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
// Try deserializing as FileSource first
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = first_try {
return Ok(DataSource::File(t));
}
// Try deserializing as OnChainSource
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = second_try {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
@@ -433,29 +431,22 @@ impl<'de> Deserialize<'de> for DataSource {
}
}
/// Container for input and output data for graph computations
///
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
/// Input to graph as a datasource
/// Always use JSON serialization for GraphData. Seriously.
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
pub struct GraphData {
/// Input data for the model/graph
/// Can be empty if inputs come from on-chain sources
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
pub input_data: DataSource,
/// Optional output data for the model/graph
/// Can be empty if outputs come from on-chain sources
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
pub output_data: Option<DataSource>,
}
impl UnwindSafe for GraphData {}
impl GraphData {
// not wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Converts the input data to tract's tensor format
///
/// # Arguments
/// * `shapes` - Expected shapes for each input tensor
/// * `datum_types` - Expected data types for each input
/// Convert the input data to tract data
pub fn to_tract_data(
&self,
shapes: &[Vec<usize>],
@@ -484,43 +475,7 @@ impl GraphData {
Ok(inputs)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Converts tract tensor data into GraphData format
///
/// # Arguments
/// * `tensors` - Array of tract tensors to convert
///
/// # Returns
/// A new GraphData instance containing the converted tensor data
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
use tract_onnx::prelude::DatumType;
let mut input_data = vec![];
for tensor in tensors {
match tensor.datum_type() {
tract_onnx::prelude::DatumType::Bool => {
let tensor = tensor.to_array_view::<bool>()?;
let tensor = tensor.iter().map(|e| FileSourceInner::Bool(*e)).collect();
input_data.push(tensor);
}
_ => {
let cast_tensor = tensor.cast_to_dt(DatumType::F64)?;
let tensor = cast_tensor.to_array_view::<f64>()?;
let tensor = tensor.iter().map(|e| FileSourceInner::Float(*e)).collect();
input_data.push(tensor);
}
}
}
Ok(GraphData {
input_data: DataSource::File(input_data),
output_data: None,
})
}
/// Creates a new GraphData instance with given input data
///
/// # Arguments
/// * `input_data` - The input data source
pub fn new(input_data: DataSource) -> Self {
GraphData {
input_data,
@@ -528,13 +483,7 @@ impl GraphData {
}
}
/// Loads graph input data from a file
///
/// # Arguments
/// * `path` - Path to the input file
///
/// # Returns
/// A new GraphData instance containing the loaded data
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
@@ -548,35 +497,23 @@ impl GraphData {
Ok(graph_input)
}
/// Saves the graph data to a file
///
/// # Arguments
/// * `path` - Path where to save the data
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, self)?;
Ok(())
}
/// Splits the input data into multiple batches based on input shapes
///
/// # Arguments
/// * `input_shapes` - Vector of shapes for each input tensor
///
/// # Returns
/// Vector of GraphData instances, one for each batch
///
/// # Errors
/// Returns error if:
/// - Data is from on-chain source
/// - Input size is not evenly divisible by batch size
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, GraphError> {
// split input data into batches
let mut batched_inputs = vec![];
let iterable = match self {
@@ -600,12 +537,10 @@ impl GraphData {
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
for (i, shape) in input_shapes.iter().enumerate() {
// ensure the input is evenly divisible by batch_size
let input_size = shape.clone().iter().product::<usize>();
let input = &iterable[i];
// Validate input size is divisible by batch size
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
@@ -613,8 +548,6 @@ impl GraphData {
.to_string(),
));
}
// Split input into batches
let mut batches = vec![];
for batch in input.chunks(input_size) {
batches.push(batch.to_vec());
@@ -622,18 +555,18 @@ impl GraphData {
batched_inputs.push(batches);
}
// Merge batches across inputs
// now merge all the batches for each input into a vector of batches
// first assert each input has the same number of batches
let num_batches = if batched_inputs.is_empty() {
0
} else {
let num_batches = batched_inputs[0].len();
// Verify all inputs have same number of batches
for input in batched_inputs.iter() {
assert_eq!(input.len(), num_batches);
}
num_batches
};
// now merge the batches
let mut input_batches = vec![];
for i in 0..num_batches {
let mut batch = vec![];
@@ -643,12 +576,11 @@ impl GraphData {
input_batches.push(DataSource::File(batch));
}
// Ensure at least one batch exists
if input_batches.is_empty() {
input_batches.push(DataSource::File(vec![vec![]]));
}
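The batching logic above requires each input's length to be an exact multiple of its per-batch size (the product of the shape dims); otherwise `InvalidDims` is returned. A numeric sketch of the happy path:

```rust
fn batching_example() {
    // shape [3] => per-batch size 3; 6 values split into exactly 2 batches
    let input: Vec<u32> = vec![1, 2, 3, 4, 5, 6];
    let input_size: usize = [3usize].iter().product();
    assert_eq!(input.len() % input_size, 0); // the divisibility check
    let batches: Vec<Vec<u32>> = input.chunks(input_size).map(|b| b.to_vec()).collect();
    assert_eq!(batches, vec![vec![1, 2, 3], vec![4, 5, 6]]);
    // 7 values with the same shape would fail the modulo check above
}
```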
// Create GraphData instance for each batch
// create a new GraphWitness for each batch
let batches = input_batches
.into_iter()
.map(GraphData::new)
@@ -660,7 +592,6 @@ impl GraphData {
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallsToAccount {
/// Converts CallsToAccount to Python object
fn to_object(&self, py: Python) -> PyObject {
let dict = PyDict::new(py);
dict.set_item("account", &self.address).unwrap();
@@ -669,187 +600,6 @@ impl ToPyObject for CallsToAccount {
}
}
// Additional Python bindings for various types...
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
let serialized = serde_json::to_string(&source).unwrap();
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
assert_eq!(serialized, JSON);
let expect = serde_json::from_str::<DataSource>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(expect, source);
}
#[test]
fn test_graph_input_serialization_round_trip() {
// Test serialization/deserialization of graph input
let file = GraphData::new(DataSource::from(vec![vec![
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let serialized = serde_json::to_string(&file).unwrap();
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
assert_eq!(serialized, JSON);
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(graph_input3, file);
}
#[test]
fn test_python_compat() {
// Test compatibility with mclbn256 library serialization
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
assert_eq!(format!("{:?}", source), original_addr);
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
let dict = PyDict::new(py);
dict.set_item("account", &self.address).unwrap();
dict.set_item("call_data", &self.call_data).unwrap();
dict.set_item("decimals", &self.decimals).unwrap();
dict.set_item("len", &self.len).unwrap();
dict.to_object(py)
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for Calls {
fn to_object(&self, py: Python) -> PyObject {
match self {
Calls::Multiple(calls) => calls.to_object(py),
Calls::Single(call) => call.to_object(py),
}
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for DataSource {
fn to_object(&self, py: Python) -> PyObject {
@@ -858,11 +608,9 @@ impl ToPyObject for DataSource {
DataSource::OnChain(source) => {
let dict = PyDict::new(py);
dict.set_item("rpc_url", &source.rpc).unwrap();
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
.unwrap();
dict.set_item("calls_to_accounts", &source.calls).unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
@@ -887,3 +635,69 @@ impl ToPyObject for FileSourceInner {
}
}
}
impl Serialize for GraphData {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("GraphData", 2)?;
state.serialize_field("input_data", &self.input_data)?;
state.serialize_field("output_data", &self.output_data)?;
state.end()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
// this is for backwards compatibility with the old format
fn test_data_source_serialization_round_trip() {
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
let serialized = serde_json::to_string(&source).unwrap();
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
assert_eq!(serialized, JSON);
let expect = serde_json::from_str::<DataSource>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(expect, source);
}
#[test]
// this is for backwards compatibility with the old format
fn test_graph_input_serialization_round_trip() {
let file = GraphData::new(DataSource::from(vec![vec![
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let serialized = serde_json::to_string(&file).unwrap();
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
assert_eq!(serialized, JSON);
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(graph_input3, file);
}
// test for the compatibility with the serialized elements from the mclbn256 library
#[test]
fn test_python_compat() {
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
assert_eq!(format!("{:?}", source), original_addr);
}
}

View File

@@ -60,10 +60,7 @@ use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
pub use utilities::*;
@@ -136,8 +133,6 @@ pub struct GraphWitness {
pub min_lookup_inputs: IntegerRep,
/// max range check size
pub max_range_size: IntegerRep,
/// (optional) version of ezkl used
pub version: Option<String>,
}
impl GraphWitness {
@@ -166,7 +161,6 @@ impl GraphWitness {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
version: None,
}
}
@@ -280,13 +274,7 @@ impl GraphWitness {
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let witness: GraphWitness =
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
// check versions match
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
Ok(witness)
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
@@ -352,10 +340,10 @@ impl ToPyObject for GraphWitness {
if let Some(processed_inputs) = &self.processed_inputs {
//poseidon_hash
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
insert_poseidon_hash_pydict(&dict_inputs, processed_inputs_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
}
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
insert_polycommit_pydict(&dict_inputs, processed_inputs_polycommit).unwrap();
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
}
dict.set_item("processed_inputs", dict_inputs).unwrap();
@@ -363,10 +351,10 @@ impl ToPyObject for GraphWitness {
if let Some(processed_params) = &self.processed_params {
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
insert_poseidon_hash_pydict(&dict_params, processed_params_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
}
if let Some(processed_params_polycommit) = &processed_params.polycommit {
insert_polycommit_pydict(&dict_params, processed_params_polycommit).unwrap();
insert_polycommit_pydict(dict_params, processed_params_polycommit).unwrap();
}
dict.set_item("processed_params", dict_params).unwrap();
@@ -374,11 +362,10 @@ impl ToPyObject for GraphWitness {
if let Some(processed_outputs) = &self.processed_outputs {
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
insert_poseidon_hash_pydict(&dict_outputs, processed_outputs_poseidon_hash)
.unwrap();
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
}
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
insert_polycommit_pydict(&dict_outputs, processed_outputs_polycommit).unwrap();
insert_polycommit_pydict(dict_outputs, processed_outputs_polycommit).unwrap();
}
dict.set_item("processed_outputs", dict_outputs).unwrap();
@@ -389,10 +376,7 @@ impl ToPyObject for GraphWitness {
}
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(
pydict: &Bound<'_, PyDict>,
poseidon_hash: &Vec<Fp>,
) -> Result<(), PyErr> {
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
@@ -400,10 +384,7 @@ fn insert_poseidon_hash_pydict(
}
#[cfg(feature = "python-bindings")]
fn insert_polycommit_pydict(
pydict: &Bound<'_, PyDict>,
commits: &Vec<Vec<G1Affine>>,
) -> Result<(), PyErr> {
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
use crate::bindings::python::PyG1Affine;
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
.iter()
@@ -578,14 +559,10 @@ impl GraphSettings {
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})?;
crate::check_version_string_matches(&settings.version);
Ok(settings)
})
}
/// Export the ezkl configuration as json
@@ -619,6 +596,11 @@ impl GraphSettings {
}
}
/// Returns true if the circuit uses any modules
pub fn uses_modules(&self) -> bool {
self.module_sizes.max_constraints() > 0
}
/// if any visibility is encrypted or hashed
pub fn module_requires_fixed(&self) -> bool {
self.run_args.input_visibility.is_hashed()
@@ -702,9 +684,6 @@ impl GraphCircuit {
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
// check the versions match
crate::check_version_string_matches(&result.core.settings.version);
Ok(result)
}
}
@@ -761,7 +740,7 @@ pub struct TestOnChainData {
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
/// data sources for the on chain data
///
pub data_sources: TestSources,
}
@@ -949,7 +928,7 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => Err(GraphError::OnChainDataSource),
_ => unreachable!("cannot load from on-chain data"),
}
}
@@ -1022,24 +1001,11 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{
evm_quantize_multi, evm_quantize_single, read_on_chain_inputs_multi,
read_on_chain_inputs_single, setup_eth_backend,
};
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let quantized_evm_inputs = match source.calls {
input::Calls::Single(call) => {
let (inputs, decimals) =
read_on_chain_inputs_single(client.clone(), client_address, call).await?;
evm_quantize_single(client, scales, &inputs, decimals).await?
}
input::Calls::Multiple(calls) => {
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls).await?;
evm_quantize_multi(client, scales, &inputs).await?
}
};
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
// quantize the supplied data using the provided scale + QuantizeData.sol
let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
@@ -1384,7 +1350,6 @@ impl GraphCircuit {
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_size: model_results.max_range_size,
version: Some(crate::version().to_string()),
};
witness.generate_rescaled_elements(

View File

@@ -384,7 +384,8 @@ pub struct ParsedNodes {
impl ParsedNodes {
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
self.inputs.len()
let input_nodes = self.inputs.iter();
input_nodes.len()
}
/// Input types
@@ -424,7 +425,8 @@ impl ParsedNodes {
/// Returns the number of the computational graph's outputs
pub fn num_outputs(&self) -> usize {
self.outputs.len()
let output_nodes = self.outputs.iter();
output_nodes.len()
}
/// Returns shapes of the computational graph's outputs
@@ -619,23 +621,19 @@ impl Model {
/// * `scale` - The scale to use for quantization.
/// * `public_params` - Whether to make the params public.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub(crate) fn load_onnx_using_tract(
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
variables: &[(String, usize)],
run_args: &RunArgs,
) -> Result<TractResult, GraphError> {
use tract_onnx::tract_hir::internal::GenericFactoid;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
std::collections::HashMap::from_iter(run_args.variables.clone());
for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
for (i, x) in fact.clone().shape.dims().enumerate() {
@@ -657,8 +655,8 @@ impl Model {
}
let mut symbol_values = SymbolValues::default();
for (symbol, value) in variables.iter() {
let symbol = model.symbols.sym(symbol);
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
}
@@ -685,7 +683,7 @@ impl Model {
) -> Result<ParsedNodes, GraphError> {
let start_time = instant::Instant::now();
let (model, symbol_values) = Self::load_onnx_using_tract(reader, &run_args.variables)?;
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
let scales = VarScales::from_args(run_args);
let nodes = Self::nodes_from_graph(
@@ -908,7 +906,6 @@ impl Model {
n.opkind = SupportedOp::Input(Input {
scale,
datum_type: inp.datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
});
input_idx += 1;
n.out_scale = scale;
@@ -918,9 +915,20 @@ impl Model {
if scales.contains_key(&i) {
let scale_diff = n.out_scale - scales[&i];
n.opkind = if scale_diff > 0 {
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
RebaseScale::rebase(
n.opkind,
scales[&i],
n.out_scale,
1,
run_args.div_rebasing,
)
} else {
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
RebaseScale::rebase_up(
n.opkind,
scales[&i],
n.out_scale,
run_args.div_rebasing,
)
};
n.out_scale = scales[&i];
}
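The rebase branch above compares a node's output scale to the target scale: a positive difference rebases down (a division, with `div_rebasing` selecting the strategy), a negative one rebases up (a multiplication). Since scales are log2 fixed-point exponents, the factor involved is a power of two; a sketch of that arithmetic (the helper name is ours, mirroring what `scale_to_multiplier` does in the codebase):

```rust
// > 1.0 means the value must be divided down to reach the target scale,
// < 1.0 means it must be scaled up.
fn rebase_factor(out_scale: i32, target_scale: i32) -> f64 {
    2f64.powi(out_scale - target_scale)
}

// e.g. out_scale 14 -> target 7 requires dividing by 2^7 = 128.
```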
@@ -967,7 +975,7 @@ impl Model {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, &run_args.variables)?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?
@@ -1019,10 +1027,6 @@ impl Model {
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
if vars.advices.len() < 3 {
return Err(GraphError::InsufficientAdviceColumns(3));
}
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
@@ -1042,10 +1046,6 @@ impl Model {
}
if settings.requires_dynamic_lookup() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
@@ -1054,13 +1054,10 @@ impl Model {
}
if settings.requires_shuffle() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_shuffles(
meta,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}
@@ -1075,7 +1072,6 @@ impl Model {
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
#[allow(clippy::too_many_arguments)]
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1146,8 +1142,8 @@ impl Model {
.iter()
.enumerate()
.map(|(i, output)| {
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
@@ -1170,10 +1166,7 @@ impl Model {
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
Box::new(HybridOp::RangeCheck(tolerance)),
)
.map_err(|e| e.into())
})
@@ -1217,9 +1210,9 @@ impl Model {
// Then number of columns in the circuits
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
trace!("input indices: {:?}", node.inputs());
trace!("output scales: {:?}", node.out_scales());
trace!(
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
@@ -1238,13 +1231,12 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
trace!("output dims: {:?}", node.out_dims());
trace!(
debug!("output dims: {:?}", node.out_dims());
debug!(
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
@@ -1382,7 +1374,6 @@ impl Model {
results.insert(*idx, full_results);
}
}
debug!("------------ layout of {} took {:?}", idx, start.elapsed());
}
// we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...)
@@ -1450,16 +1441,13 @@ impl Model {
.into();
comparator.reshape(output.dims())?;
let mut tol = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>();
@@ -1481,7 +1469,7 @@ impl Model {
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
@@ -1551,7 +1539,6 @@ impl Model {
let mut op = crate::circuit::Constant::new(
c.quantized_values.clone(),
c.raw_values.clone(),
c.decomp,
);
op.pre_assign(consts[const_idx].clone());
n.opkind = SupportedOp::Constant(op);

View File

@@ -14,11 +14,14 @@ use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;
/// Poseidon module type
pub type ModulePoseidon = PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>;
pub type ModulePoseidon =
PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>;
/// Poseidon module config
pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;
@@ -281,6 +284,7 @@ impl GraphModules {
log::error!("Poseidon config not initialized");
return Err(Error::Synthesis);
}
// If the module is encrypted, then we need to encrypt the inputs
}
Ok(())

View File

@@ -1,19 +1,10 @@
// Import dependencies for scaling operations
use super::scale_to_multiplier;
// Import ONNX-specific utilities when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;
// Import scale management types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
// Import visibility settings for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
@@ -22,49 +13,28 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;
// Import ONNX operation conversion utilities
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;
// Import tensor error handling
use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
// Import serialization traits
use serde::Deserialize;
use serde::Serialize;
// Import data structures for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;
// Import formatting traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;
// Import table display formatting for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;
// Import ONNX-specific types and traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
self,
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};
/// Helper function to format vectors for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
if !v.is_empty() {
@@ -74,35 +44,29 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
}
}
/// Helper function to format operation kinds for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
v.as_string()
}
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
/// A wrapper for an operation that has been rescaled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled {
/// The underlying operation that needs to be rescaled
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// Vector of (index, scale) pairs defining how each input should be scaled
/// The scale of the operation's inputs.
pub scale: Vec<(usize, u128)>,
}
/// Implementation of the Op trait for Rescaled operations
impl Op<Fp> for Rescaled {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
}
/// Calculate output scale based on input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
@@ -113,7 +77,6 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -130,45 +93,34 @@ impl Op<Fp> for Rescaled {
self.inner.layout(config, region, res)
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone())
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
}
/// A wrapper for operations that require scale rebasing
/// This handles cases where operation scales need to be adjusted to a target scale
/// while preserving the numerical relationships
/// A wrapper for an operation that has been rescaled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RebaseScale {
/// The operation that needs to be rescaled
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// Operation used for rebasing, typically division
/// rebase op
pub rebase_op: HybridOp,
/// Scale that we're rebasing to
/// scale being rebased to
pub target_scale: i32,
/// Original scale of operation's inputs before rebasing
/// The original scale of the operation's inputs.
pub original_scale: i32,
/// Scaling multiplier used in rebasing
/// multiplier
pub multiplier: f64,
}
impl RebaseScale {
/// Creates a rebased version of an operation if needed
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `global_scale` - Base scale for the system
/// * `op_out_scale` - Current output scale of the operation
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
///
/// # Returns
/// Original or rebased operation depending on scale relationships
pub fn rebase(
inner: SupportedOp,
global_scale: crate::Scale,
op_out_scale: crate::Scale,
scale_rebase_multiplier: u32,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
@@ -185,6 +137,7 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
})
@@ -195,6 +148,7 @@ impl RebaseScale {
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
})
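A worked sketch (numbers hypothetical, not from this diff) of the rebase trigger in `RebaseScale::rebase` above: an op whose output scale exceeds `global_scale * scale_rebase_multiplier` gets wrapped in a division that brings it back down.

```rust
fn main() {
    // Hypothetical values: global scale 7, rebase multiplier 1, op output scale 14.
    let (global_scale, rebase_multiplier, op_out_scale) = (7i32, 1u32, 14i32);
    if op_out_scale > global_scale * rebase_multiplier as i32 {
        // Divide by 2^(14 - 7) = 128 to return the op to the global scale.
        let multiplier = f64::powf(2., (op_out_scale - global_scale) as f64);
        assert_eq!(multiplier, 128.0);
    }
}
```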
@@ -204,19 +158,12 @@ impl RebaseScale {
}
}
/// Creates a rebased operation with increased scale
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `target_scale` - Scale to rebase to
/// * `op_out_scale` - Current output scale of the operation
///
/// # Returns
/// Original or rebased operation with increased scale
pub fn rebase_up(
inner: SupportedOp,
target_scale: crate::Scale,
op_out_scale: crate::Scale,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
@@ -229,6 +176,7 @@ impl RebaseScale {
original_scale: op.original_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
@@ -239,6 +187,7 @@ impl RebaseScale {
original_scale: op_out_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
@@ -249,12 +198,10 @@ impl RebaseScale {
}
impl Op<Fp> for RebaseScale {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}, rebasing_op={}) ({})",
@@ -264,12 +211,10 @@ impl Op<Fp> for RebaseScale {
)
}
/// Calculate output scale based on input scales
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -283,40 +228,34 @@ impl Op<Fp> for RebaseScale {
self.rebase_op.layout(config, region, &[original_res])
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone())
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
}
/// Represents all supported operation types in the circuit
/// Each variant encapsulates a different type of operation with specific behavior
/// A single operation in a [crate::graph::Model].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportedOp {
/// Linear operations (polynomial-based)
/// A linear operation.
Linear(PolyOp),
/// Nonlinear operations requiring lookup tables
/// A nonlinear operation.
Nonlinear(LookupOp),
/// Mixed operations combining different approaches
/// A hybrid operation.
Hybrid(HybridOp),
/// Input values to the circuit
///
Input(Input),
/// Constant values in the circuit
///
Constant(Constant<Fp>),
/// Placeholder for unsupported operations
///
Unknown(Unknown),
/// Operations requiring rescaling of inputs
///
Rescaled(Rescaled),
/// Operations requiring scale rebasing
///
RebaseScale(RebaseScale),
}
impl SupportedOp {
/// Checks if the operation is a lookup operation
///
/// # Returns
/// * `true` if operation requires lookup table
/// * `false` otherwise
pub fn is_lookup(&self) -> bool {
match self {
SupportedOp::Nonlinear(_) => true,
@@ -324,12 +263,7 @@ impl SupportedOp {
_ => false,
}
}
/// Returns input operation if this is an input
///
/// # Returns
/// * `Some(Input)` if this is an input operation
/// * `None` otherwise
pub fn get_input(&self) -> Option<Input> {
match self {
SupportedOp::Input(op) => Some(op.clone()),
@@ -337,11 +271,7 @@ impl SupportedOp {
}
}
/// Returns reference to rebased operation if this is a rebased operation
///
/// # Returns
/// * `Some(&RebaseScale)` if this is a rebased operation
/// * `None` otherwise
pub fn get_rebased(&self) -> Option<&RebaseScale> {
match self {
SupportedOp::RebaseScale(op) => Some(op),
@@ -349,11 +279,7 @@ impl SupportedOp {
}
}
/// Returns reference to lookup operation if this is a lookup operation
///
/// # Returns
/// * `Some(&LookupOp)` if this is a lookup operation
/// * `None` otherwise
pub fn get_lookup(&self) -> Option<&LookupOp> {
match self {
SupportedOp::Nonlinear(op) => Some(op),
@@ -361,11 +287,7 @@ impl SupportedOp {
}
}
/// Returns reference to constant if this is a constant
///
/// # Returns
/// * `Some(&Constant)` if this is a constant
/// * `None` otherwise
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -373,11 +295,7 @@ impl SupportedOp {
}
}
/// Returns mutable reference to constant if this is a constant
///
/// # Returns
/// * `Some(&mut Constant)` if this is a constant
/// * `None` otherwise
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -385,19 +303,18 @@ impl SupportedOp {
}
}
/// Creates a homogeneously rescaled version of this operation if needed
/// Only available with EZKL feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn homogenous_rescale(
&self,
in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let inputs_to_scale = self.requires_homogenous_input_scales();
// creates a rescaled op if the inputs are not homogeneous
let op = self.clone_dyn();
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
}
/// Returns reference to underlying Op implementation
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
fn as_op(&self) -> &dyn Op<Fp> {
match self {
SupportedOp::Linear(op) => op,
@@ -411,10 +328,9 @@ impl SupportedOp {
}
}
/// Checks if this is an identity operation
///
/// check if this is the identity operation
/// # Returns
/// * `true` if this operation passes input through unchanged
/// * `true` if the operation is the identity operation
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
@@ -451,11 +367,9 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
return SupportedOp::Unknown(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
return SupportedOp::Rescaled(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
return SupportedOp::RebaseScale(op.clone());
};
@@ -467,7 +381,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
/// Layout this operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -477,61 +390,54 @@ impl Op<Fp> for SupportedOp {
self.as_op().layout(config, region, values)
}
/// Check if this is an input operation
fn is_input(&self) -> bool {
self.as_op().is_input()
}
/// Check if this is a constant operation
fn is_constant(&self) -> bool {
self.as_op().is_constant()
}
/// Get which inputs require homogeneous scales
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
self.as_op().requires_homogenous_input_scales()
}
/// Create a clone of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
self.as_op().clone_dyn()
}
/// Get string representation
fn as_string(&self) -> String {
self.as_op().as_string()
}
/// Convert to Any type
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Calculate output scale from input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
self.as_op().out_scale(in_scales)
}
}
/// Represents a connection to another node's output
/// First element is node index, second is output slot index
/// A node's input is a tensor from another node's output.
pub type Outlet = (usize, usize);
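A self-contained sketch of how an `Outlet` is resolved against previously computed node results, as the layout loop earlier in this diff does (`results.get(idx)...[0]`); strings stand in for tensors and the `resolve` helper is hypothetical.

```rust
use std::collections::BTreeMap;

// An Outlet is (producing node index, output slot).
type Outlet = (usize, usize);

fn resolve(outlet: Outlet, results: &BTreeMap<usize, Vec<String>>) -> Option<&String> {
    let (node_idx, slot) = outlet;
    results.get(&node_idx)?.get(slot)
}

fn main() {
    let mut results = BTreeMap::new();
    results.insert(1, vec!["output 0 of node 1".to_string()]);
    assert!(resolve((1, 0), &results).is_some());
    assert!(resolve((2, 0), &results).is_none());
}
```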
/// Represents a single computational node in the circuit graph
/// Contains all information needed to execute and connect operations
/// A single operation in a [crate::graph::Model].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
/// The operation this node performs
/// [Op] i.e what operation this node represents.
pub opkind: SupportedOp,
/// Fixed point scale factor for this node's output
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
pub out_scale: i32,
/// Connections to other nodes' outputs that serve as inputs
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one each for the input, weight, and bias),
// but in_dim is [in], out_dim is [out]
/// The indices of the node's inputs.
pub inputs: Vec<Outlet>,
/// Shape of this node's output tensor
/// Dimensions of output.
pub out_dims: Vec<usize>,
/// Unique identifier for this node
/// The node's unique identifier.
pub idx: usize,
/// Number of times this node's output is used
/// The node's num of uses
pub num_uses: usize,
}
@@ -569,19 +475,12 @@ impl PartialEq for Node {
}
impl Node {
/// Creates a new Node from an ONNX node
/// Only available when EZKL feature is enabled
///
/// # Arguments
/// * `node` - Source ONNX node
/// * `other_nodes` - Map of existing nodes in the graph
/// * `scales` - Scale factors for variables
/// * `idx` - Unique identifier for this node
/// * `symbol_values` - ONNX symbol values
/// * `run_args` - Runtime configuration arguments
///
/// # Returns
/// New Node instance or error if creation fails
/// Converts a tract [OnnxNode] into an ezkl [Node].
/// # Arguments:
/// * `node` - [OnnxNode]
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[allow(clippy::too_many_arguments)]
pub fn new(
@@ -696,7 +595,13 @@ impl Node {
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
opkind = RebaseScale::rebase(
opkind,
global_scale,
out_scale,
scales.rebase_multiplier,
run_args.div_rebasing,
);
out_scale = opkind.out_scale(in_scales)?;
@@ -719,14 +624,16 @@ impl Node {
})
}
/// Check if this node performs softmax operation
/// check if it is a softmax node
pub fn is_softmax(&self) -> bool {
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
true
} else {
false
}
}
}
/// Helper function to rescale constants that are only used once
/// Only available when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
constant: &mut Constant<Fp>,

View File

@@ -44,11 +44,11 @@ use tract_onnx::tract_hir::{
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
};
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `vec` - the vector to quantize.
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
@@ -59,7 +59,7 @@ pub fn quantize_float(
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value || *elem < -max_value {
if *elem > max_value {
return Err(TensorError::SigBitTruncationError);
}
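A simplified model of the quantization path documented above (NAN maps to 0; values past the significant-bit bound error out), assuming a 64-bit integer representation for illustration only — the real `IntegerRep` may be wider. This mirrors the symmetric bound on one side of the hunk and the edge cases exercised by the tests later in this diff.

```rust
fn quantize(elem: f64, shift: f64, scale: i32) -> Result<i64, &'static str> {
    if elem.is_nan() {
        return Ok(0); // NAN maps to 0, as documented above
    }
    let mult = f64::powf(2., scale as f64);
    // the maximum value representable without significant-bit truncation
    let max_value = ((i64::MAX as f64 - shift) / mult).round();
    if elem > max_value || elem < -max_value {
        return Err("SigBitTruncationError");
    }
    Ok((elem * mult + shift).round() as i64)
}

fn main() {
    assert_eq!(quantize(0.5, 0.0, 7).unwrap(), 64); // 0.5 * 2^7
    assert_eq!(quantize(f64::NAN, 0.0, 0).unwrap(), 0);
    assert!(quantize(f64::INFINITY, 0.0, 0).is_err());
}
```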
@@ -85,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
f64::powf(2., scale as f64)
}
/// Converts a fixed point multiplier to a scale (log base 2).
/// Converts a scale (log base 2) to a fixed point multiplier.
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}
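The two helpers above are inverses up to `log2` rounding; a quick round-trip check:

```rust
fn scale_to_multiplier(scale: i32) -> f64 {
    f64::powf(2., scale as f64)
}

fn multiplier_to_scale(mult: f64) -> i32 {
    mult.log2().round() as i32
}

fn main() {
    for scale in [0i32, 3, 7, 13] {
        assert_eq!(multiplier_to_scale(scale_to_multiplier(scale)), scale);
    }
}
```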
@@ -142,6 +142,8 @@ use tract_onnx::prelude::SymbolValues;
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
) -> Result<Tensor<f32>, GraphError> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
let dt = input.datum_type();
let dims = input.shape().to_vec();
@@ -154,7 +156,7 @@ pub fn extract_tensor_value(
match dt {
DatumType::F16 => {
let vec = input.as_slice::<tract_onnx::prelude::f16>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| (*x).into()).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| (*x).into()).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::F32 => {
@@ -163,61 +165,61 @@ pub fn extract_tensor_value(
}
DatumType::F64 => {
let vec = input.as_slice::<f64>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i64>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i32>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i16>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i8>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u8>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u16>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u32>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u64>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::Bool => {
// Generally a shape or hyperparam
let vec = input.as_slice::<bool>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as usize as f32).collect();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as usize as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::TDim => {
@@ -225,10 +227,13 @@ pub fn extract_tensor_value(
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
let cast: Result<Vec<f32>, GraphError> = vec
.iter()
.par_iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
})
.collect();
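The hunks above swap `iter()` for `par_iter()` across every `DatumType` cast; `maybe_rayon` (imported at the top of the hunk) is designed to fall back to serial iteration on targets where rayon is unavailable, such as wasm. A standalone sketch of the pattern, assuming the rayon crate is a direct dependency:

```rust
use rayon::prelude::*;

fn main() {
    // Cast an integer slice to f32 in parallel, as the extraction code does.
    let vals: Vec<i64> = (0..1_000).collect();
    let cast: Vec<f32> = vals.par_iter().map(|x| *x as f32).collect();
    assert_eq!(cast.len(), 1_000);
}
```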
@@ -274,10 +279,10 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use crate::circuit::InputType;
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
let input_scales = inputs
.iter()
.flat_map(|x| x.out_scales())
@@ -307,9 +312,6 @@ pub fn new_op_from_onnx(
let mut deleted_indices = vec![];
let node = match node.op().name().as_ref() {
"ShiftLeft" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -322,13 +324,10 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
}
}
"ShiftRight" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -341,7 +340,7 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
}
}
"MultiBroadcastTo" => {
@@ -364,10 +363,7 @@ pub fn new_op_from_onnx(
}
}
if input_ops.len() != 3 {
return Err(GraphError::InvalidDims(idx, "range".to_string()));
}
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
@@ -382,11 +378,7 @@ pub fn new_op_from_onnx(
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
!run_args.ignore_range_check_inputs_outputs,
);
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -427,10 +419,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
}
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| {
@@ -448,7 +436,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: false,
}));
inputs[1].bump_scale(0);
}
@@ -460,17 +447,8 @@ pub fn new_op_from_onnx(
"Topk" => {
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
};
// if param_visibility.is_public() {
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
if c.raw_values.len() != 1 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
}
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
c.raw_values.map(|x| x as usize)[0]
@@ -510,10 +488,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -525,7 +499,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -549,9 +522,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -562,7 +532,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -586,9 +555,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -600,7 +566,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -624,9 +589,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -638,7 +600,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -713,11 +674,7 @@ pub fn new_op_from_onnx(
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
run_args.ignore_range_check_inputs_outputs,
);
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -727,9 +684,7 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
}
assert_eq!(axes.len(), 1, "only support argmax over one axis");
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
}
@@ -739,9 +694,7 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
}
assert_eq!(axes.len(), 1, "only support argmin over one axis");
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
}
@@ -810,55 +763,93 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();
if inputs.len() == 2 {
if !const_inputs.is_empty() {
let const_idx = const_inputs[0];
let boxed_op = inputs[const_idx].opkind();
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
if c.len() == 1 {
c[0]
} else {
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
} else {
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
};
if unit == 0. {
if let Some(node) = inputs.get_mut(const_idx) {
node.decrement_use();
deleted_indices.push(const_idx);
}
SupportedOp::Linear(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
})
} else {
SupportedOp::Hybrid(HybridOp::Max)
}
if const_inputs.len() != 1 {
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
}
let const_idx = const_inputs[0];
let boxed_op = inputs[const_idx].opkind();
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
if c.len() == 1 {
c[0]
} else {
SupportedOp::Hybrid(HybridOp::Max)
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
} else {
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
};
if inputs.len() == 2 {
if let Some(node) = inputs.get_mut(const_idx) {
node.decrement_use();
deleted_indices.push(const_idx);
}
if unit == 0. {
SupportedOp::Linear(PolyOp::ReLU)
} else {
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Max {
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
}
} else {
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
}
"Min" => {
// Extract the min value
// first find the input that is a constant
// and then extract the value
let const_inputs = inputs
.iter()
.enumerate()
.filter(|(_, n)| n.is_constant())
.map(|(i, _)| i)
.collect::<Vec<_>>();
if const_inputs.len() != 1 {
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
}
let const_idx = const_inputs[0];
let boxed_op = inputs[const_idx].opkind();
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
if c.len() == 1 {
c[0]
} else {
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
} else {
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
};
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Min)
if let Some(node) = inputs.get_mut(const_idx) {
node.decrement_use();
deleted_indices.push(const_idx);
}
// get the non-constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
SupportedOp::Nonlinear(LookupOp::Min {
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
a: crate::circuit::utils::F32(unit),
})
} else {
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
}
"Recip" => {
if inputs.len() != 1 {
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
};
let in_scale = input_scales[0];
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
use_range_check_for_int: true,
})
}
@@ -873,9 +864,8 @@ pub fn new_op_from_onnx(
}
};
SupportedOp::Linear(PolyOp::LeakyReLU {
SupportedOp::Nonlinear(LookupOp::LeakyReLU {
slope: crate::circuit::utils::F32(leaky_op.alpha),
scale: scales.params,
})
}
"Scan" => {
@@ -887,79 +877,61 @@ pub fn new_op_from_onnx(
"Abs" => SupportedOp::Linear(PolyOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sqrt" => SupportedOp::Hybrid(HybridOp::Sqrt {
scale: scale_to_multiplier(input_scales[0]).into(),
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Rsqrt" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[0]).into(),
base: E.into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Ln" => {
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
})
}
}
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
scale: scale_to_multiplier(input_scales[0]).into(),
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Source" => {
let dt = node.outputs[0].fact.datum_type;
@@ -980,19 +952,13 @@ pub fn new_op_from_onnx(
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input {
scale,
datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
})
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
}
"Cast" => {
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
let dt = op.to;
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
};
assert_eq!(input_scales.len(), 1);
match dt {
DatumType::Bool
@@ -1009,9 +975,11 @@ pub fn new_op_from_onnx(
replace_const(
0,
0,
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
SupportedOp::Nonlinear(LookupOp::Cast {
scale: crate::circuit::utils::F32(scale_to_multiplier(
input_scales[0],
)
as f32),
}),
)?
} else {
@@ -1042,11 +1010,6 @@ pub fn new_op_from_onnx(
if const_idx.len() == 1 {
let const_idx = const_idx[0];
if inputs.len() <= const_idx {
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
// if not divisible by 2 then we need to add a range check
@@ -1069,21 +1032,21 @@ pub fn new_op_from_onnx(
op
}
"Iff" => SupportedOp::Linear(PolyOp::Iff),
"<" => {
"Less" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"<=" => {
"LessEqual" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
">" => {
"Greater" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
@@ -1091,7 +1054,7 @@ pub fn new_op_from_onnx(
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
">=" => {
"GreaterEqual" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
@@ -1121,11 +1084,8 @@ pub fn new_op_from_onnx(
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
}
let in_scale = input_scales[0];
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
@@ -1163,42 +1123,18 @@ pub fn new_op_from_onnx(
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
}
SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Floor" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
}
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Round" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "round".to_string()));
}
SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"RoundHalfToEven" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
}
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const
@@ -1208,92 +1144,14 @@ pub fn new_op_from_onnx(
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.len() > 1 {
return Err(GraphError::NonScalarPower);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
unimplemented!("only support scalar pow")
}
let exponent = c.raw_values[0];
if exponent.fract() == 0.0 {
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
} else {
SupportedOp::Nonlinear(LookupOp::Pow {
scale: scale_to_multiplier(input_scales[0]).into(),
a: crate::circuit::utils::F32(exponent),
})
}
} else if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
return Err(GraphError::NonScalarBase);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
let base = c.raw_values[0];
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
SupportedOp::Nonlinear(LookupOp::Pow {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
a: crate::circuit::utils::F32(c.raw_values[0]),
})
} else {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
}
"Div" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = inputs
.iter()
.enumerate()
.filter(|(_, n)| n.is_constant())
.map(|(i, _)| i)
.collect::<Vec<_>>();
if const_idx.len() > 1 || const_idx.is_empty() {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = const_idx[0];
if const_idx != 1 {
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] != 0. {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
// get the non-constant index
let denom = c.raw_values[0];
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
});
// if the input is scale 0 we rescale up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
));
}
} else {
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
unimplemented!("only support constant pow for now")
}
}
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1356,7 +1214,7 @@ pub fn new_op_from_onnx(
"And" => SupportedOp::Linear(PolyOp::And),
"Or" => SupportedOp::Linear(PolyOp::Or),
"Xor" => SupportedOp::Linear(PolyOp::Xor),
"==" => SupportedOp::Hybrid(HybridOp::Equals),
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
"Deconv" => {
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,
@@ -1433,7 +1291,7 @@ pub fn new_op_from_onnx(
if !resize_node.contains("interpolator: Nearest")
&& !resize_node.contains("nearest: Floor")
{
return Err(GraphError::InvalidInterpolation);
unimplemented!("Only nearest neighbor interpolation is supported")
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
@@ -1537,10 +1395,6 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Reshape(output_shape))
}
"Flatten" => {
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
};
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
SupportedOp::Linear(PolyOp::Flatten(new_dims))
}
@@ -1614,10 +1468,12 @@ pub fn homogenize_input_scales(
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = inputs_to_scale
.iter()
.filter(|idx| input_scales.len() > **idx)
.map(|&idx| input_scales[idx])
let relevant_input_scales = input_scales
.clone()
.into_iter()
.enumerate()
.filter(|(idx, _)| inputs_to_scale.contains(idx))
.map(|(_, scale)| scale)
.collect_vec();
if inputs_to_scale.is_empty() {
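The two selection styles in the `homogenize_input_scales` hunk above pick the same scales — those whose index appears in `inputs_to_scale` — though they yield them in different orders when `inputs_to_scale` is not ascending. A side-by-side check (equal here because the indices are ascending):

```rust
fn main() {
    let input_scales = vec![7i32, 0, 13];
    let inputs_to_scale = vec![0usize, 2];

    // Style one: walk the requested indices, guarding against out-of-range.
    let a: Vec<i32> = inputs_to_scale
        .iter()
        .filter(|idx| input_scales.len() > **idx)
        .map(|&idx| input_scales[idx])
        .collect();

    // Style two: walk the scales and keep those whose index was requested.
    let b: Vec<i32> = input_scales
        .iter()
        .enumerate()
        .filter(|(idx, _)| inputs_to_scale.contains(idx))
        .map(|(_, scale)| *scale)
        .collect();

    assert_eq!(a, b); // [7, 13]
}
```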
@@ -1658,30 +1514,10 @@ pub fn homogenize_input_scales(
}
#[cfg(test)]
/// tests for the utility module
pub mod tests {
use super::*;
// quantization tests
#[test]
fn test_quantize_tensor() {
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}
#[test]
fn test_quantize_edge_cases() {
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
}
#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();

View File

@@ -9,36 +9,38 @@ use itertools::Itertools;
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
exceptions::PyValueError, types::PyString, FromPyObject, IntoPy, PyAny, PyObject, PyResult,
PyTryFrom, Python, ToPyObject,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use self::errors::GraphError;
use super::*;
/// Defines the visibility level of values within the zero-knowledge circuit
/// Controls how values are handled during proof generation and verification
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum Visibility {
/// Value is private to the prover and not included in proof
/// Mark an item as private to the prover (not in the proof submitted for verification)
#[default]
Private,
/// Value is public and included in proof for verification
/// Mark an item as public (sent in the proof submitted for verification)
Public,
/// Value is hashed and the hash is included in proof
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
Hashed {
/// Controls how the hash is handled in proof
/// true - hash is included directly in proof (public)
/// false - hash is used as advice and passed to computational graph
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
hash_is_public: bool,
/// Specifies which outputs this hash affects
///
outlets: Vec<usize>,
},
/// Value is committed using KZG commitment scheme
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
KZGCommit,
/// Value is assigned as a constant in the circuit
/// assigned as a constant in the circuit
Fixed,
}
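A simplified re-implementation of the `"hashed/private/<outlets>"` string parsing in the `From<&'a str>` impl below; error handling is collapsed to `Option`, and the comma-separated outlet list is an assumption of this sketch rather than something shown in the hunk.

```rust
fn parse_hashed_private(s: &str) -> Option<Vec<usize>> {
    if !s.contains("hashed/private") {
        return None;
    }
    // split on the last occurrence of '/'
    let (_, outlets) = s.split_at(s.rfind('/')?);
    outlets
        .trim_start_matches('/')
        .split(',')
        .map(|o| o.parse::<usize>())
        .collect::<Result<Vec<_>, _>>()
        .ok()
}

fn main() {
    assert_eq!(parse_hashed_private("hashed/private/0"), Some(vec![0]));
    assert_eq!(parse_hashed_private("public"), None);
}
```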
@@ -65,17 +67,15 @@ impl Display for Visibility {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
/// Converts visibility to command line flags
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl<'a> From<&'a str> for Visibility {
/// Converts string representation to Visibility
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// Split on last occurrence of '/'
// split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -107,8 +107,8 @@ impl<'a> From<&'a str> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
impl IntoPy<PyObject> for Visibility {
/// Converts Visibility to Python object
fn into_py(self, py: Python) -> PyObject {
match self {
Visibility::Private => "private".to_object(py),
@@ -135,13 +135,16 @@ impl IntoPy<PyObject> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
/// Extracts Visibility from Python object
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
let strval = strval.as_str();
if strval.contains("hashed/private") {
// split on last occurrence of '/'
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -174,32 +177,29 @@ impl<'source> FromPyObject<'source> for Visibility {
}
impl Visibility {
/// Returns true if visibility is Fixed
#[allow(missing_docs)]
pub fn is_fixed(&self) -> bool {
matches!(&self, Visibility::Fixed)
}
/// Returns true if visibility is Private or hashed private
#[allow(missing_docs)]
pub fn is_private(&self) -> bool {
matches!(&self, Visibility::Private) || self.is_hashed_private()
}
/// Returns true if visibility is Public
#[allow(missing_docs)]
pub fn is_public(&self) -> bool {
matches!(&self, Visibility::Public)
}
/// Returns true if visibility involves hashing
#[allow(missing_docs)]
pub fn is_hashed(&self) -> bool {
matches!(&self, Visibility::Hashed { .. })
}
/// Returns true if visibility uses KZG commitment
#[allow(missing_docs)]
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}
/// Returns true if visibility is hashed with public hash
#[allow(missing_docs)]
pub fn is_hashed_public(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: true,
@@ -210,8 +210,7 @@ impl Visibility {
}
false
}
/// Returns true if visibility is hashed with private hash
#[allow(missing_docs)]
pub fn is_hashed_private(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: false,
@@ -223,12 +222,11 @@ impl Visibility {
false
}
/// Returns true if visibility requires additional processing
#[allow(missing_docs)]
pub fn requires_processing(&self) -> bool {
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
}
/// Returns vector of output indices that this visibility setting affects
#[allow(missing_docs)]
pub fn overwrites_inputs(&self) -> Vec<usize> {
if let Visibility::Hashed { outlets, .. } = self {
return outlets.clone();
@@ -237,14 +235,14 @@ impl Visibility {
}
}
/// Manages scaling factors for different parts of the model
/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarScales {
/// Scale factor for input values
///
pub input: crate::Scale,
/// Scale factor for parameter values
///
pub params: crate::Scale,
/// Multiplier for scale rebasing
///
pub rebase_multiplier: u32,
}
@@ -255,17 +253,17 @@ impl std::fmt::Display for VarScales {
}
impl VarScales {
/// Returns maximum scale value
///
pub fn get_max(&self) -> crate::Scale {
std::cmp::max(self.input, self.params)
}
/// Returns minimum scale value
///
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}
/// Creates VarScales from runtime arguments
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
@@ -275,17 +273,16 @@ impl VarScales {
}
}
/// Controls visibility settings for different parts of the model
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Visibility of model inputs
/// Input to the model or computational graph
pub input: Visibility,
/// Visibility of model parameters (weights, biases)
/// Parameters, such as weights and biases, in the model
pub params: Visibility,
/// Visibility of model outputs
/// Output of the model or computational graph
pub output: Visibility,
}
impl std::fmt::Display for VarVisibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
@@ -307,7 +304,8 @@ impl Default for VarVisibility {
}
impl VarVisibility {
/// Creates visibility settings from runtime arguments
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
@@ -318,17 +316,17 @@ impl VarVisibility {
}
if !output_vis.is_public()
&& !params_vis.is_public()
&& !input_vis.is_public()
&& !output_vis.is_fixed()
&& !params_vis.is_fixed()
&& !input_vis.is_fixed()
&& !output_vis.is_hashed()
&& !params_vis.is_hashed()
&& !input_vis.is_hashed()
&& !output_vis.is_polycommit()
&& !params_vis.is_polycommit()
&& !input_vis.is_polycommit()
& !params_vis.is_public()
& !input_vis.is_public()
& !output_vis.is_fixed()
& !params_vis.is_fixed()
& !input_vis.is_fixed()
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(GraphError::Visibility);
}
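One side of the hunk above chains the checks with `&` rather than `&&`. For `bool` operands, `&` is the non-short-circuiting AND: both sides are always evaluated, but the truth table is identical, so the visibility check's result is unchanged.

```rust
fn main() {
    let (a, b) = (true, false);
    // `&` evaluates both operands; `&&` short-circuits. Same result on bools.
    assert_eq!(a & b, a && b);
    assert_eq!(!a & !b, !a && !b);
}
```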
@@ -340,17 +338,17 @@ impl VarVisibility {
}
}
/// Container for circuit columns used by a model
/// A wrapper for holding all columns that will be assigned to by a model.
#[derive(Clone, Debug)]
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
/// Advice columns for circuit assignments
#[allow(missing_docs)]
pub advices: Vec<VarTensor>,
/// Optional instance column for public inputs
#[allow(missing_docs)]
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Gets reference to instance column if it exists
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
match instance {
@@ -362,14 +360,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Sets initial offset for instance values
/// Set the initial instance offset
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let Some(instance) = &mut self.instance {
instance.set_initial_instance_offset(offset);
}
}
/// Gets total length of instance data
/// Get the total instance len
pub fn get_instance_len(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_total_instance_len()
@@ -378,21 +376,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Increments instance index
/// Increment the instance offset
pub fn increment_instance_idx(&mut self) {
if let Some(instance) = &mut self.instance {
instance.increment_idx();
}
}
/// Sets instance index to specific value
/// Reset the instance offset
pub fn set_instance_idx(&mut self, val: usize) {
if let Some(instance) = &mut self.instance {
instance.set_idx(val);
}
}
/// Gets the current instance index
pub fn get_instance_idx(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_idx()
@@ -401,7 +399,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
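// A minimal sketch (hypothetical helper, not from the source) of how the
// instance-offset methods above are typically driven while laying out public
// values: set a starting offset, then advance the index per assigned tensor.
fn _instance_offset_example(vars: &mut Self) {
    vars.set_initial_instance_offset(0);
    if vars.get_instance_col().is_some() {
        vars.increment_instance_idx();
        let _current = vars.get_instance_idx();
    }
}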
/// Initializes the instance column with the specified dimensions and scale
pub fn instantiate_instance(
&mut self,
cs: &mut ConstraintSystem<F>,
@@ -422,7 +420,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
};
}
/// Creates a new ModelVars, allocating all columns that a model will assign to, based on the given settings
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
@@ -440,7 +438,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
.collect_vec();
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);


@@ -28,9 +28,6 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -102,7 +99,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -111,19 +108,7 @@ use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
/// The version of the ezkl library
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Get the version of the library
pub fn version() -> &'static str {
match VERSION {
"0.0.0" => "source - no compatibility guaranteed",
_ => VERSION,
}
}
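// A small usage sketch (hypothetical helper): version() is what the artifact
// compatibility check (check_version_string_matches, further down in this
// diff) compares against an artifact's recorded version.
fn _log_version() {
    log::info!("ezkl version: {}", version());
}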
/// Bindings management
#[cfg(any(
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown"),
@@ -168,6 +153,7 @@ pub mod srs_sha;
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
@@ -182,9 +168,11 @@ lazy_static! {
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
@@ -266,105 +254,74 @@ impl From<String> for Commitments {
}
/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
/// including scaling factors, visibility settings, and circuit parameters.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// The error tolerance for model outputs; only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// Fixed-point scale (denominator) used when quantizing inputs
/// Higher values give more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// Fixed-point scale (denominator) used when quantizing parameters
/// Higher values give more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// Scale rebase threshold multiplier
/// If the scale ever exceeds input_scale * scale_rebase_multiplier, it is rebased to input_scale
/// Advanced parameter; use with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// The (min, max) range of values in the lookup table input column
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// The log_2 number of rows in the circuit; controls circuit size and proving time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// The number of inner columns per block; affects circuit layout and efficiency
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Graph variables for parameterizing the computation, in the format "name->value", e.g. "batch_size->1"
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Visibility of input values: public, private, fixed, hashed, or polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Visibility of output values: public, private, fixed, hashed, or polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Visibility of parameters: private, fixed, hashed, or polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
/// Rebase the scale using a lookup table for division instead of a range check
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub div_rebasing: bool,
/// Whether constants with zero fractional part should be rebased to scale 0
/// Can improve efficiency for integer constants
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// Circuit check mode (safe, unsafe, etc.); controls the level of constraint verification
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// Commitment scheme used for proving; affects proof size and verification time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// The base used for number decompositions; must be a power of 2
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
/// The number of legs used for decompositions; controls decomposition granularity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub decomp_legs: usize,
/// Whether to use bounded lookup for logarithm computation
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub bounded_log_lookup: bool,
/// Skip range checks on inputs and outputs (enable if the inputs are already field elements)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
}
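// A hedged sketch (hypothetical helper) of configuring RunArgs
// programmatically instead of via the CLI flags declared above; the field
// values here are illustrative only.
fn _configure_run_args() -> RunArgs {
    let mut args = RunArgs::default();
    args.input_scale = 8; // finer fixed-point quantization for inputs
    args.logrows = 18; // 2^18 rows in the circuit
    args.lookup_range = (-65536, 65536); // wider lookup table input range
    args
}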
impl Default for RunArgs {
/// Creates a new RunArgs instance with default values
///
/// Default configuration is optimized for common use cases
/// while maintaining reasonable proving time and circuit size
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
@@ -376,143 +333,60 @@ impl Default for RunArgs {
input_visibility: Visibility::Private,
output_visibility: Visibility::Public,
param_visibility: Visibility::Private,
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
}
}
}
impl RunArgs {
/// Validates the RunArgs configuration
///
/// Performs comprehensive validation of all parameters to ensure they are within
/// acceptable ranges and follow required constraints. Returns accumulated errors
/// if any validations fail.
///
/// # Returns
/// - Ok(()) if all validations pass
/// - Err(String) with detailed error message if any validation fails
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
.to_string(),
);
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}
// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
}
// warn if any of the scales are too small
if self.input_scale < 8 || self.param_scale < 8 {
warn!("low scale values (<8) may impact precision");
}
// Lookup range validations
if self.lookup_range.0 > self.lookup_range.1 {
errors.push(format!(
"Invalid lookup range: min ({}) is greater than max ({})",
self.lookup_range.0, self.lookup_range.1
));
}
// Size validations
if self.logrows < 1 {
errors.push("logrows must be >= 1".to_string());
}
if self.num_inner_cols < 1 {
errors.push("num_inner_cols must be >= 1".to_string());
}
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
if let Some(batch_size) = batch_size {
if batch_size.1 == 0 {
errors.push("'batch_size' cannot be 0".to_string());
}
}
// Decomposition validations
if self.decomp_base == 0 {
errors.push("decomp_base cannot be 0".to_string());
}
if self.decomp_legs == 0 {
errors.push("decomp_legs cannot be 0".to_string());
}
// Performance validations
if self.logrows > MAX_PUBLIC_SRS {
warn!("logrows exceeds maximum public SRS size");
}
// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}
// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join("\n"))
}
}
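// A hedged sketch (hypothetical helper) of consuming validate()'s accumulated
// errors: since failures are joined with '\n', callers can report every
// problem at once rather than stopping at the first.
fn _report_invalid_args(args: &RunArgs) {
    if let Err(msg) = args.validate() {
        for line in msg.lines() {
            log::error!("invalid run argument: {line}");
        }
    }
}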
/// Exports the configuration as JSON
///
/// Serializes the RunArgs instance to a JSON string
///
/// # Returns
/// * `Ok(String)` containing JSON representation
/// * `Err` if serialization fails
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let res = serde_json::to_string(&self)?;
Ok(res)
}
/// Parses configuration from JSON
///
/// Deserializes a RunArgs instance from a JSON string
///
/// # Arguments
/// * `arg_json` - JSON string containing configuration
///
/// # Returns
/// * `Ok(RunArgs)` if parsing succeeds
/// * `Err` if parsing fails
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
}
// Additional helper functions for the module
/// Parses a key-value pair from a string in the format "key->value"
///
/// # Arguments
/// * `s` - Input string in the format "key->value"
///
/// # Returns
/// * `Ok((T, U))` - Parsed key and value
/// * `Err` - If parsing fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}
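// A hedged usage sketch (hypothetical helper): parse_key_val backs the
// `--variables` and `--lookup-range` flags above, splitting key and value on
// the literal "->".
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn _parse_key_val_example() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    let (name, value): (String, usize) = parse_key_val("batch_size->1")?;
    assert_eq!(name, "batch_size");
    assert_eq!(value, 1);
    Ok(())
}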
/// Verifies that a version string matches the expected artifact version
/// Logs warnings for version mismatches or unversioned artifacts
///
/// # Arguments
/// * `artifact_version` - Version string from the artifact
pub fn check_version_string_matches(artifact_version: &str) {
if artifact_version == "0.0.0"
|| artifact_version == "source - no compatibility guaranteed"
|| artifact_version.is_empty()
{
log::warn!("Artifact version is 0.0.0, skipping version check");
return;
}
let version = crate::version();
if version == "source - no compatibility guaranteed" {
log::warn!("Compiled source version is not guaranteed to match artifact version");
return;
}
if version != artifact_version {
log::warn!(
"Version mismatch: CLI version is {} but artifact version is {}",
version,
artifact_version
);
}
}
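// A minimal sketch (hypothetical helper) of the intended call site: compare an
// artifact's recorded version against the running build before trusting its
// serialized settings. The check only warns on mismatch; it never fails.
fn _check_artifact(settings_version: &str) {
    check_version_string_matches(settings_version);
}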
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;
#[test]
fn test_valid_default_args() {
let args = RunArgs::default();
assert!(args.validate().is_ok());
}
#[test]
fn test_invalid_param_visibility() {
let mut args = RunArgs::default();
args.param_visibility = Visibility::Public;
let err = args.validate().unwrap_err();
assert!(err.contains("Parameters cannot be public instances"));
}
#[test]
fn test_invalid_scale_rebase() {
let mut args = RunArgs::default();
args.scale_rebase_multiplier = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
}
#[test]
fn test_invalid_lookup_range() {
let mut args = RunArgs::default();
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
assert!(err.contains("Invalid lookup range"));
}
#[test]
fn test_invalid_logrows() {
let mut args = RunArgs::default();
args.logrows = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("logrows must be >= 1"));
}
#[test]
fn test_invalid_inner_cols() {
let mut args = RunArgs::default();
args.num_inner_cols = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("num_inner_cols must be >= 1"));
}
#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}
#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}
#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();
args.variables = vec![("batch_size".to_string(), 0)];
let err = args.validate().unwrap_err();
assert!(err.contains("'batch_size' cannot be 0"));
}
#[test]
fn test_json_serialization() {
let args = RunArgs::default();
let json = args.as_json().unwrap();
let deserialized = RunArgs::from_json(&json).unwrap();
assert_eq!(args, deserialized);
}
#[test]
fn test_multiple_validation_errors() {
let mut args = RunArgs::default();
args.logrows = 0;
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
// Should contain multiple error messages
assert!(err.matches("\n").count() >= 1);
}
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
}
