Compare commits

..

15 Commits

Author SHA1 Message Date
dante
9bbc89cc89 chore: bump h2 2024-08-19 17:44:46 -04:00
dante
28b65f2639 chore: bump halo2proofs 2024-08-19 17:38:22 -04:00
dante
9592d38a8f chore: add asm feature 2024-08-19 17:33:27 -04:00
dante
2cec49dfc3 chore: bump fft algos 2024-08-18 20:50:10 -04:00
dante
31a1681ca4 chore: update h2 curves 2024-08-18 13:36:00 -04:00
dante
134b54d32b Update Cargo.toml 2024-08-15 12:42:45 -04:00
dante
beb5f12376 chore: use mimalloc 2024-08-15 12:40:55 -04:00
dante
65be3c84bb Update Cargo.toml 2024-08-14 18:05:26 -04:00
dante
6f743c57d3 chore: parallelize prepare and commit 2024-08-13 15:06:47 -04:00
dante
ddb54c5a73 feat: precompute lookup cosets 2024-08-08 18:15:22 -04:00
dante
6e1f22a15b log lack of cache 2024-08-08 10:44:41 -04:00
dante
da97323bde feat: cache lookup tables 2024-08-08 09:12:40 -04:00
dante
55046feeb6 Update Cargo.toml 2024-08-07 23:42:19 -04:00
dante
d0d0596e58 chore: bump h2 2024-08-07 23:40:42 -04:00
dante
b78efdcbf4 fix: add required serde patches 2024-08-07 18:30:08 -04:00
46 changed files with 3418 additions and 2497 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish EZKL Engine npm package
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
@@ -79,10 +79,6 @@ jobs:
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
- name: Add serialize and deserialize methods to nodejs bundle
run: |
echo '
@@ -178,3 +174,40 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -128,7 +128,6 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -360,17 +359,3 @@ jobs:
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}

View File

@@ -307,8 +307,8 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -380,7 +380,7 @@ jobs:
cache: "pnpm"
- name: Install dependencies for js tests
run: |
pnpm install --frozen-lockfile
pnpm install --no-frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -610,24 +610,6 @@ jobs:
python-integration-tests:
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
@@ -652,8 +634,6 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
@@ -671,3 +651,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

View File

@@ -14,40 +14,6 @@ jobs:
- uses: actions/checkout@v4
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2
uses: mathieudutour/github-tag-action@v6.1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- name: Set Cargo.toml version to match github tag for docs
shell: bash
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
mv docs/python/src/conf.py docs/python/src/conf.py.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
rm docs/python/src/conf.py.orig
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
rm docs/python/requirements-docs.txt.orig
- name: Commit files and create tag
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git fetch --tags
git checkout -b release-$RELEASE_TAG
git add .
git commit -m "ci: update version string in docs"
git tag -d $RELEASE_TAG
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:
branch: release-${{ steps.tag_version.outputs.new_tag }}
force: true
tags: true

View File

@@ -1,54 +0,0 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

1
.gitignore vendored
View File

@@ -49,5 +49,4 @@ node_modules
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key

View File

@@ -1 +0,0 @@
3.12.1

View File

@@ -1,26 +0,0 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: ./docs/python/src/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: ./docs/python/requirements-docs.txt

164
Cargo.lock generated
View File

@@ -644,6 +644,15 @@ dependencies = [
"wyz",
]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest 0.10.7",
]
[[package]]
name = "blake2b_simd"
version = "1.0.2"
@@ -1185,6 +1194,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "deranged"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
dependencies = [
"powerfmt",
]
[[package]]
name = "derivative"
version = "2.2.0"
@@ -1307,6 +1325,12 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
[[package]]
name = "dyn-hash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
[[package]]
name = "ecc"
version = "0.1.0"
@@ -1785,7 +1809,7 @@ dependencies = [
"halo2_gadgets",
"halo2_proofs",
"halo2_solidity_verifier",
"halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
"halo2curves 0.7.0",
"hex",
"indicatif",
"instant",
@@ -1793,6 +1817,7 @@ dependencies = [
"lazy_static",
"log",
"maybe-rayon",
"mimalloc",
"mnist",
"num",
"openssl",
@@ -2176,10 +2201,11 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
version = "2.2.1"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
"num-traits",
]
@@ -2204,19 +2230,23 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a"
source = "git+https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03"
dependencies = [
"bincode",
"blake2b_simd",
"env_logger",
"ff",
"group",
"halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
"halo2curves 0.7.0",
"icicle",
"lazy_static",
"log",
"maybe-rayon",
"rand_chacha",
"rand_core 0.6.4",
"rustacuda",
"rustc-hash",
"serde",
"sha3 0.9.1",
"tracing",
]
@@ -2224,7 +2254,7 @@ dependencies = [
[[package]]
name = "halo2_solidity_verifier"
version = "0.1.0"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#eb04be1f7d005e5b9dd3ff41efa30aeb5e0c34a3"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#3082fda94151fc6760a3cb2be4741ddbeef04c03"
dependencies = [
"askama",
"blake2b_simd",
@@ -2300,15 +2330,18 @@ dependencies = [
[[package]]
name = "halo2curves"
version = "0.6.1"
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c#9fff22c5f72cc54fac1ef3a844e1072b08cfecdf"
version = "0.7.0"
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
dependencies = [
"blake2b_simd",
"blake2",
"digest 0.10.7",
"ff",
"group",
"halo2derive",
"hex",
"lazy_static",
"num-bigint",
"num-integer",
"num-traits",
"pairing",
"pasta_curves",
@@ -2318,11 +2351,25 @@ dependencies = [
"rayon",
"serde",
"serde_arrays",
"sha2",
"static_assertions",
"subtle",
"unroll",
]
[[package]]
name = "halo2derive"
version = "0.1.0"
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "halo2wrong"
version = "0.1.0"
@@ -2830,6 +2877,16 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libmimalloc-sys"
version = "0.1.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "libredox"
version = "0.0.1"
@@ -2993,6 +3050,15 @@ dependencies = [
"autocfg",
]
[[package]]
name = "mimalloc"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633"
dependencies = [
"libmimalloc-sys",
]
[[package]]
name = "mime"
version = "0.3.17"
@@ -3126,6 +3192,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
@@ -3714,6 +3786,12 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@@ -4058,9 +4136,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.9.0"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
@@ -4360,6 +4438,12 @@ version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc-hash"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
[[package]]
name = "rustc-hex"
version = "2.1.0"
@@ -4782,11 +4866,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "snark-verifier"
version = "0.1.1"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#574b65ea6b4d43eebac5565146519a95b435815c"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
dependencies = [
"ecc",
"halo2_proofs",
"halo2curves 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
"halo2curves 0.6.1",
"hex",
"itertools 0.10.5",
"lazy_static",
@@ -5127,11 +5211,14 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.23"
version = "0.3.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
@@ -5139,16 +5226,17 @@ dependencies = [
[[package]]
name = "time-core"
version = "0.1.1"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.10"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
dependencies = [
"num-conv",
"time-core",
]
@@ -5398,8 +5486,8 @@ dependencies = [
[[package]]
name = "tract-core"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"anyhow",
"bit-set",
@@ -5422,11 +5510,14 @@ dependencies = [
[[package]]
name = "tract-data"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"anyhow",
"half 2.2.1",
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half 2.4.1",
"itertools 0.12.1",
"lazy_static",
"maplit",
@@ -5441,8 +5532,8 @@ dependencies = [
[[package]]
name = "tract-hir"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"derive-new",
"log",
@@ -5451,20 +5542,23 @@ dependencies = [
[[package]]
name = "tract-linalg"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"byteorder",
"cc",
"derive-new",
"downcast-rs",
"dyn-clone",
"half 2.2.1",
"dyn-hash",
"half 2.4.1",
"lazy_static",
"liquid",
"liquid-core",
"log",
"num-traits",
"paste",
"rayon",
"scan_fmt",
"smallvec",
"time",
@@ -5475,8 +5569,8 @@ dependencies = [
[[package]]
name = "tract-nnef"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"byteorder",
"flate2",
@@ -5489,8 +5583,8 @@ dependencies = [
[[package]]
name = "tract-onnx"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"bytes",
"derive-new",
@@ -5506,8 +5600,8 @@ dependencies = [
[[package]]
name = "tract-onnx-opl"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
dependencies = [
"getrandom",
"log",

View File

@@ -15,9 +15,10 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
mimalloc = "0.1"
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
@@ -32,10 +33,10 @@ log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
@@ -80,7 +81,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
@@ -175,7 +176,7 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
default = ["ezkl", "mv-lookup", "precompute-coset"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -194,6 +195,8 @@ mv-lookup = [
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
@@ -204,8 +207,8 @@ no-banner = []
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
[profile.release]
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
padding: [(0, 0); 2],
stride: (1, 1),
}),
)
.unwrap();

View File

@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],
kernel_shape: vec![2, 2],
padding: [(0, 0); 2],
stride: (1, 1),
kernel_shape: (2, 2),
normalized: false,
}),
)

View File

@@ -1,2 +0,0 @@
#!/bin/sh
sphinx-build ./src build

View File

@@ -1,4 +0,0 @@
ezkl==10.3.3
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,29 +0,0 @@
import ezkl
project = 'ezkl'
release = '10.3.3'
version = release
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
]
autosummary_generate = True
autosummary_imported_members = True
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

View File

@@ -1,11 +0,0 @@
.. extension documentation master file, created by
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
ezkl python bindings
================================================
.. automodule:: ezkl
:members:
:undoc-members:

View File

@@ -203,8 +203,8 @@ where
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
padding: [(PADDING, PADDING); 2],
stride: (STRIDE, STRIDE),
};
let x = config
.layer_config

View File

@@ -7,9 +7,9 @@
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"Make sure you install postgres if needed https://postgresapp.com/. \n",
"\n"
]
},
@@ -21,81 +21,23 @@
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
"os.system(\"chmod +x e2pg\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"os.system(\"echo $RLPS_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(5)\n",
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
"\n"
]
},
@@ -137,13 +79,11 @@
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
"# import logging\n",
"# # # uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
]
},
{
@@ -236,7 +176,6 @@
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
@@ -244,9 +183,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"dbname\": \"e2pg\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -255,7 +194,7 @@
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
@@ -271,9 +210,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"dbname\": \"e2pg\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -290,6 +229,22 @@
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -298,21 +253,10 @@
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"# setup kzg params\n",
"params_path = os.path.join('kzg.params')\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
"res = ezkl.get_srs(params_path, settings_filename)"
]
},
{
@@ -362,13 +306,16 @@
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"params_path = os.path.join('kzg.params')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" pk_path,\n",
" params_path,\n",
" settings_filename,\n",
" )\n",
"\n",
"assert res == True\n",
@@ -384,14 +331,11 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
@@ -416,14 +360,73 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" params_path,\n",
" \"single\",\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
"\n",
"# verify\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_filename,\n",
" vk_path,\n",
" params_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7tAa-DFAtvS"
},
"source": [
"# Part 2 (Using the ZK Computational Graph Onchain!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Ym91kaVAIB6"
},
"source": [
"**Now How Do We Do It Onchain?????**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 339
},
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
"print(params_path)\n",
"print(settings_filename)\n",
"\n",
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" params_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
@@ -432,8 +435,51 @@
"metadata": {},
"outputs": [],
"source": [
"# kill all shovel process \n",
"os.system(\"pkill -f shovel\")"
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.system(\"killall -9 e2pg\");"
]
}
],
@@ -455,7 +501,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.13"
}
},
"nbformat": 4,

View File

@@ -20,16 +20,16 @@
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "4.0.0",
"@ethereumjs/evm": "2.0.0",
"@ethereumjs/statemanager": "2.0.0",
"@ethereumjs/tx": "5.0.0",
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "6.7.1",
"json-bigint": "1.0.0"
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",

View File

@@ -6,34 +6,34 @@ settings:
dependencies:
'@ethereumjs/common':
specifier: 4.0.0
specifier: ^4.0.0
version: 4.0.0
'@ethereumjs/evm':
specifier: 2.0.0
specifier: ^2.0.0
version: 2.0.0
'@ethereumjs/statemanager':
specifier: 2.0.0
specifier: ^2.0.0
version: 2.0.0
'@ethereumjs/tx':
specifier: 5.0.0
specifier: ^5.0.0
version: 5.0.0
'@ethereumjs/util':
specifier: 9.0.0
specifier: ^9.0.0
version: 9.0.0
'@ethereumjs/vm':
specifier: 7.0.0
specifier: ^7.0.0
version: 7.0.0
'@ethersproject/abi':
specifier: 5.7.0
specifier: ^5.7.0
version: 5.7.0
'@ezkljs/engine':
specifier: ^9.4.4
version: 9.4.4
ethers:
specifier: 6.7.1
specifier: ^6.7.1
version: 6.7.1
json-bigint:
specifier: 1.0.0
specifier: ^1.0.0
version: 1.0.0
devDependencies:

View File

@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
exit 1
fi
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=1.0,<2.0"]
requires = ["maturin>=0.14,<0.15"]
build-backend = "maturin"
[tool.pytest.ini_options]

View File

@@ -2,7 +2,7 @@ attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.5.1
maturin==1.5.0
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
@@ -11,4 +11,4 @@ typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4
numpy==1.26.4

View File

@@ -1,4 +1,6 @@
// ignore file if compiling for wasm
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(not(target_arch = "wasm32"))]
use clap::Parser;

View File

@@ -956,6 +956,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
op.layout(self, region, values)
let res = op.layout(self, region, values)?;
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
if let Some(claimed_output) = &res {
// during key generation this will be unknown vals so we use this as a flag to check
let mut is_assigned = !claimed_output.any_unknowns()?;
for val in values.iter() {
is_assigned = is_assigned && !val.any_unknowns()?;
}
if is_assigned {
op.safe_mode_check(claimed_output, values)?;
}
}
};
Ok(res)
}
}

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::i128_to_felt,
fieldutils::{felt_to_i128, i128_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
dim: usize,
},
SumPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
kernel_shape: Vec<usize>,
padding: [(usize, usize); 2],
stride: (usize, usize),
kernel_shape: (usize, usize),
normalized: bool,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
MaxPool2d {
padding: [(usize, usize); 2],
stride: (usize, usize),
pool_dims: (usize, usize),
},
ReduceMin {
axes: Vec<usize>,
@@ -85,6 +85,93 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => tensor::ops::nonlinearities::softmax_axes(
&x,
input_scale.into(),
output_scale.into(),
axes,
),
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater(&x, &y)?
}
HybridOp::GreaterEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater_equal(&x, &y)?
}
HybridOp::Less => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less(&x, &y)?
}
HybridOp::LessEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less_equal(&x, &y)?
}
HybridOp::Equals => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::equals(&x, &y)?
}
};
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
match self {
@@ -114,12 +201,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
HybridOp::MaxPool {
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
@@ -166,9 +253,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
padding,
stride,
kernel_shape,
*padding,
*stride,
*kernel_shape,
*normalized,
)?,
HybridOp::Recip {
@@ -228,17 +315,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
}
}
HybridOp::MaxPool {
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
} => layouts::max_pool(
} => layouts::max_pool2d(
config,
region,
values[..].try_into()?,
padding,
stride,
pool_dims,
*padding,
*stride,
*pool_dims,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?

File diff suppressed because it is too large Load Diff

View File

@@ -136,11 +136,61 @@ impl LookupOp {
(-range, range)
}
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Abs => "abs".into(),
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
LookupOp::Floor { scale } => format!("floor_{}", scale),
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
LookupOp::Sign => "sign".into(),
LookupOp::LessThan { a } => format!("less_than_{}", a),
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::ReLU => "relu".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
LookupOp::ACosh { scale } => format!("acosh_{}", scale),
LookupOp::Sin { scale } => format!("sin_{}", scale),
LookupOp::ASin { scale } => format!("asin_{}", scale),
LookupOp::Sinh { scale } => format!("sinh_{}", scale),
LookupOp::ASinh { scale } => format!("asinh_{}", scale),
LookupOp::Tan { scale } => format!("tan_{}", scale),
LookupOp::ATan { scale } => format!("atan_{}", scale),
LookupOp::ATanh { scale } => format!("atanh_{}", scale),
LookupOp::Tanh { scale } => format!("tanh_{}", scale),
LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
}
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
@@ -232,13 +282,6 @@ impl LookupOp {
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the name of the operation
fn as_string(&self) -> String {

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -35,6 +35,8 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
fn as_string(&self) -> String;
@@ -69,6 +71,33 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any;
/// Safe mode output checl
fn safe_mode_check(
&self,
claimed_output: &ValTensor<F>,
original_values: &[ValTensor<F>],
) -> Result<(), TensorError> {
let felt_evals = original_values
.iter()
.map(|v| {
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
evals.reshape(v.dims())?;
Ok(evals)
})
.collect::<Result<Vec<_>, _>>()?;
let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
let mut output = claimed_output
.get_felt_evals()
.map_err(|_| TensorError::FeltError)?;
output.reshape(claimed_output.dims())?;
assert_eq!(output, ref_op);
Ok(())
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
@@ -147,6 +176,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
self
}
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
})
}
fn as_string(&self) -> String {
"Input".into()
}
@@ -200,6 +235,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}
fn as_string(&self) -> String {
"Unknown".into()
@@ -269,6 +307,11 @@ impl<
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
format!("CONST (scale={})", self.quantized_values.scale().unwrap())

View File

@@ -1,5 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -31,8 +32,8 @@ pub enum PolyOp {
equation: String,
},
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
padding: [(usize, usize); 2],
stride: (usize, usize),
},
Downsample {
axis: usize,
@@ -40,9 +41,9 @@ pub enum PolyOp {
modulo: usize,
},
DeConv {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
padding: [(usize, usize); 2],
output_padding: (usize, usize),
stride: (usize, usize),
},
Add,
Sub,
@@ -57,13 +58,10 @@ pub enum PolyOp {
destination: usize,
},
Flatten(Vec<usize>),
Pad(Vec<(usize, usize)>),
Pad([(usize, usize); 2]),
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -107,28 +105,10 @@ impl<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -140,26 +120,15 @@ impl<
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
PolyOp::Pad(_) => "PAD".into(),
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -173,6 +142,146 @@ impl<
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
"multibroadcastto inputs".to_string(),
));
}
inputs[0].expand(shape)
}
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
PolyOp::Not => tensor::ops::not(&inputs[0]),
PolyOp::Downsample {
axis,
stride,
modulo,
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::MoveAxis {
source,
destination,
} => inputs[0].move_axis(*source, *destination),
PolyOp::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::Pad(p) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pad inputs".to_string()));
}
tensor::ops::pad(&inputs[0], *p)
}
PolyOp::Add => tensor::ops::add(&inputs),
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult => tensor::ops::mult(&inputs),
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
PolyOp::DeConv {
padding,
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
inputs[0].pow(*u)
}
PolyOp::Sum { axes } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
tensor::ops::sum_axes(&inputs[0], axes)
}
PolyOp::Prod { axes, .. } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("prod inputs".to_string()));
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
PolyOp::Slice { axis, start, end } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;
Ok(ForwardResult { output: res })
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
@@ -183,9 +292,6 @@ impl<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -212,7 +318,7 @@ impl<
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -264,9 +370,9 @@ impl<
config,
region,
values[..].try_into()?,
padding,
output_padding,
stride,
*padding,
*output_padding,
*stride,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -282,7 +388,7 @@ impl<
)));
}
let mut input = values[0].clone();
input.pad(p.clone(), 0)?;
input.pad(*p)?;
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -298,7 +404,6 @@ impl<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

View File

@@ -17,6 +17,8 @@ use crate::{
use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
@@ -25,6 +27,13 @@ pub const RANGE_MULTIPLIER: i128 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
lazy_static::lazy_static! {
/// an optional directory to read and write the lookup table cache
static ref LOOKUP_CACHE: Option<std::path::PathBuf> = std::env::var("LOOKUP_CACHE")
.ok()
.map(std::path::PathBuf::from);
}
#[derive(Debug, Clone)]
///
pub struct SelectorConstructor<F: PrimeField> {
@@ -111,10 +120,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
.unwrap();
let op_f = Op::<F>::f(
&self.nonlinearity,
&[Tensor::from(vec![first_element].into_iter())],
)
.unwrap();
(first_element, op_f.output[0])
}
@@ -202,8 +212,46 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
.par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(i128_to_felt(x)))?;
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
Ok((inputs, evals.output))
};
let (inputs, evals) = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.nonlinearity.as_path());
let input_path = cache_path.join("inputs");
let output_path = cache_path.join("outputs");
if cache_path.exists() {
log::info!("Loading lookup table from cache: {:?}", cache_path);
let (input_cache, output_cache) =
(Tensor::load(&input_path)?, Tensor::load(&output_path)?);
(input_cache, output_cache)
} else {
log::info!(
"Generating lookup table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;
let (inputs, evals) = gen_table()?;
inputs.save(&input_path)?;
evals.save(&output_path)?;
(inputs, evals)
}
} else {
log::info!(
"Generating lookup table {} without cache",
self.nonlinearity.as_path()
);
gen_table()?
};
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -235,7 +283,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
)?;
}
let output = evals.output[row_offset];
let output = evals[row_offset];
table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
@@ -273,6 +321,11 @@ pub struct RangeCheck<F: PrimeField> {
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// as path
pub fn as_path(&self) -> String {
format!("rangecheck_{}_{}", self.range.0, self.range.1)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -350,7 +403,31 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.as_path());
let input_path = cache_path.join("inputs");
if cache_path.exists() {
log::info!("Loading range check table from cache: {:?}", cache_path);
Tensor::load(&input_path)?
} else {
log::info!(
"Generating range check table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
inputs.save(&input_path)?;
inputs
}
} else {
log::info!("Generating range check {} without cache", self.as_path());
Tensor::from(smallest..=largest).map(|x| i128_to_felt(x))
};
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

View File

@@ -1048,8 +1048,8 @@ mod conv {
&mut region,
&self.inputs,
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -455,7 +455,7 @@ pub async fn verify_proof_with_data_attestation(
for val in flattened_instances.clone() {
let bytes = val.to_repr();
let u = U256::from_little_endian(bytes.as_slice());
let u = U256::from_little_endian(bytes.inner());
public_inputs.push(u);
}

View File

@@ -196,6 +196,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
srs_path,
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
@@ -636,7 +637,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
Ok(String::new())
}
pub(crate) fn gen_witness(
pub(crate) async fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
output: Option<PathBuf>,
@@ -659,7 +660,7 @@ pub(crate) fn gen_witness(
};
#[cfg(not(target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data)?;
let mut input = circuit.load_graph_input(&data).await?;
#[cfg(target_arch = "wasm32")]
let mut input = circuit.load_graph_input(&data)?;

View File

@@ -21,6 +21,8 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
@@ -232,15 +234,21 @@ impl PostgresSource {
)
};
let mut client = Client::connect(&config, NoTls)?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[])? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
let mut client = Client::connect(&config, NoTls).unwrap();
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).unwrap() {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
}
}
res
})
.join()
.map_err(|_| "failed to fetch data from postgres")?;
Ok(vec![res])
}

View File

@@ -918,7 +918,7 @@ impl GraphCircuit {
///
#[cfg(not(target_arch = "wasm32"))]
pub fn load_graph_input(
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
@@ -928,6 +928,7 @@ impl GraphCircuit {
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(target_arch = "wasm32")]
@@ -951,7 +952,7 @@ impl GraphCircuit {
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
async fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
@@ -964,16 +965,8 @@ impl GraphCircuit {
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
// start runtime and fetch data
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async {
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
})
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)

View File

@@ -1200,20 +1200,6 @@ impl Model {
.collect();
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1225,11 +1211,25 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!(
"input dims {:?}",
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {

View File

@@ -14,6 +14,7 @@ use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
@@ -60,6 +61,20 @@ impl Op<Fp> for Rescaled {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
if self.scale.len() != x.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}
let mut rescaled_inputs = vec![];
let inputs = &mut x.to_vec();
for (i, ri) in inputs.iter_mut().enumerate() {
let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
let res = (ri.clone() * mult_tensor)?;
rescaled_inputs.push(res);
}
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
}
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
@@ -200,6 +215,13 @@ impl Op<Fp> for RebaseScale {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
Ok(res)
}
fn as_string(&self) -> String {
format!(
@@ -367,6 +389,13 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
fn f(
&self,
inputs: &[Tensor<Fp>],
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
self.as_op().f(inputs)
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,

View File

@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,19 +734,6 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}
"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1119,7 +1106,17 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1127,10 +1124,26 @@ pub fn new_op_from_onnx(
};
let kernel_shape = &pool_spec.kernel_shape;
SupportedOp::Hybrid(HybridOp::MaxPool {
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
};
SupportedOp::Hybrid(HybridOp::MaxPool2d {
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
stride: (stride_h, stride_w),
pool_dims: (kernel_height, kernel_width),
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
@@ -1152,7 +1165,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}
@@ -1192,7 +1205,15 @@ pub fn new_op_from_onnx(
}
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
Some(s) => {
if s.len() == 1 {
(s[0], s[0])
} else if s.len() == 2 {
(s[0], s[1])
} else {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
}
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
@@ -1200,7 +1221,17 @@ pub fn new_op_from_onnx(
let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1255,20 +1286,33 @@ pub fn new_op_from_onnx(
}
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
Some(s) => (s[0], s[1]),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let output_padding: (usize, usize) =
(deconv_node.adjustments[0], deconv_node.adjustments[1]);
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1287,7 +1331,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::DeConv {
padding,
output_padding: deconv_node.adjustments.to_vec(),
output_padding,
stride,
})
}
@@ -1388,17 +1432,46 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let kernel_shape = &pool_spec.kernel_shape;
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams(
"kernel shape".to_string(),
)));
};
SupportedOp::Hybrid(HybridOp::SumPool {
padding,
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
stride: (stride_h, stride_w),
kernel_shape: (kernel_height, kernel_width),
normalized: sumpool_node.normalize,
})
}
@@ -1425,7 +1498,29 @@ pub fn new_op_from_onnx(
)));
}
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
let padding_len = pad_node.pads.len();
// we only support symmetrical padding that affects the last 2 dims (height and width params)
for (i, pad_params) in pad_node.pads.iter().enumerate() {
if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
return Err(Box::new(GraphError::MisformedParams(
"ezkl currently only supports padding height and width dimensions"
.to_string(),
)));
}
}
let padding = [
(
pad_node.pads[padding_len - 2].0,
pad_node.pads[padding_len - 1].0,
),
(
pad_node.pads[padding_len - 2].1,
pad_node.pads[padding_len - 1].1,
),
];
SupportedOp::Linear(PolyOp::Pad(padding))
}
"RmAxis" | "Reshape" | "AddAxis" => {
// Extract the slope layer hyperparams

View File

@@ -23,7 +23,7 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(buf_read_has_data_left)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!

View File

@@ -76,7 +76,7 @@ pub fn init_logger() {
prefix_token(&record.level()),
// pretty print UTC time
chrono::Utc::now()
.format("%Y-%m-%d %H:%M:%S")
.format("%Y-%m-%d %H:%M:%S:%3f")
.to_string()
.bright_magenta(),
record.metadata().target(),

View File

@@ -550,7 +550,8 @@ where
+ PrimeField
+ FromUniformBytes<64>
+ WithSmallOrderMulGroup<3>,
Scheme::Curve: Serialize + DeserializeOwned,
Scheme::Curve: Serialize + DeserializeOwned + SerdeObject,
Scheme::ParamsProver: Send + Sync,
{
let strategy = Strategy::new(params.verifier_params());
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);

View File

@@ -33,7 +33,6 @@ use tokio::runtime::Runtime;
type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
enum PyTestDataSource {
@@ -52,17 +51,14 @@ impl From<PyTestDataSource> for TestDataSource {
}
}
/// pyclass containing the struct used for G1, this is mostly a helper class
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
x: PyFelt,
#[pyo3(get, set)]
/// Field Element representing y
y: PyFelt,
/// Field Element representing y
#[pyo3(get, set)]
z: PyFelt,
}
@@ -138,59 +134,39 @@ impl pyo3::ToPyObject for PyG1Affine {
}
}
/// Python class containing the struct used for run_args
///
/// Returns
/// -------
/// PyRunArgs
///
/// pyclass containing the struct used for run_args
#[pyclass]
#[derive(Clone)]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
/// list[int]: The min and max elements in the lookup table input column
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
/// int: The log_2 number of rows
pub logrows: u32,
#[pyo3(get, set)]
/// int: The number of inner columns used for the lookup table
pub num_inner_cols: usize,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub input_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub output_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub param_visibility: Visibility,
#[pyo3(get, set)]
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
}
@@ -250,7 +226,7 @@ impl Into<PyRunArgs> for RunArgs {
#[pyclass]
#[derive(Debug, Clone)]
/// pyclass representing an enum, denoting the type of commitment
/// Pyclass marking the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
@@ -297,19 +273,7 @@ impl FromStr for PyCommitments {
}
}
/// Converts a field element hex string to big endian
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
///
/// Returns
/// -------
/// str
/// field element represented as a string
///
/// Converts a felt to big endian
#[pyfunction(signature = (
felt,
))]
@@ -319,45 +283,22 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
}
/// Converts a field element hex string to an integer
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// Returns
/// -------
/// int
///
#[pyfunction(signature = (
felt,
array,
))]
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts a field element hex string to a floating point number
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// scale: float
/// The scaling factor used to convert the field element into a floating point representation
///
/// Returns
/// -------
/// float
///
/// Converts a field eleement hex string to a floating point number
#[pyfunction(signature = (
felt,
array,
scale
))]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
@@ -365,23 +306,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
}
/// Converts a floating point element to a field element hex string
///
/// Arguments
/// -------
/// input: float
/// The field element represented as a string
///
/// scale: float
/// The scaling factor used to quantize the float into a field element
///
/// Returns
/// -------
/// str
/// The field element represented as a string
///
#[pyfunction(signature = (
input,
scale
input,
scale
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
@@ -391,20 +318,9 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
}
/// Converts a buffer to vector of field elements
///
/// Arguments
/// -------
/// buffer: list[int]
/// List of integers representing a buffer
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
buffer
))]
))]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -463,20 +379,9 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
}
/// Generate a poseidon hash.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
message,
))]
))]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -497,31 +402,12 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
}
/// Generate a kzg commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structure Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -556,31 +442,12 @@ fn kzg_commit(
}
/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structure Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -615,19 +482,10 @@ fn ipa_commit(
}
/// Swap the commitments in a proof
///
/// Arguments
/// -------
/// proof_path: str
/// Path to the proof file
///
/// witness_path: str
/// Path to the witness file
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -636,27 +494,11 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
}
/// Generates a vk from a pk for a model circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// circuit_settings_path: str
/// Path to the witness file
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -678,22 +520,10 @@ fn gen_vk_from_pk_single(
}
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -708,17 +538,6 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
}
/// Displays the table as a string in python
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// Returns
/// ---------
/// str
/// Table of the nodes in the onnx file
///
#[pyfunction(signature = (
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
@@ -734,16 +553,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
}
}
/// Generates the Structured Reference String (SRS), use this only for testing purposes
///
/// Arguments
/// ---------
/// srs_path: str
/// Path to the create the SRS file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// generates the srs
#[pyfunction(signature = (
srs_path,
logrows,
@@ -754,26 +564,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
Ok(())
}
/// Gets a public srs
///
/// Arguments
/// ---------
/// settings_path: str
/// Path to the settings file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// srs_path: str
/// Path to the create the SRS file
///
/// commitment: str
/// Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
/// bool
///
/// gets a public srs
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
@@ -806,23 +597,7 @@ fn get_srs(
Ok(true)
}
/// Generates the circuit settings
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the settings file
///
/// py_run_args: PyRunArgs
/// PyRunArgs object to initialize the settings
///
/// Returns
/// -------
/// bool
///
/// generates the circuit settings
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
@@ -843,38 +618,7 @@ fn gen_settings(
Ok(true)
}
/// Calibrates the circuit settings
///
/// Arguments
/// ---------
/// data: str
/// Path to the calibration data
///
/// model: str
/// Path to the onnx file
///
/// settings: str
/// Path to the settings file
///
/// lookup_safety_margin: int
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
///
/// scales: list[int]
/// Optional scales to specifically try for calibration
///
/// scale_rebase_multiplier: list[int]
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
///
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
/// bool
///
/// calibrates the circuit settings
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
@@ -916,30 +660,7 @@ fn calibrate_settings(
Ok(true)
}
/// Runs the forward pass operation to generate a witness
///
/// Arguments
/// ---------
/// data: str
/// Path to the data file
///
/// model: str
/// Path to the compiled model file
///
/// output: str
/// Path to create the witness file
///
/// vk_path: str
/// Path to the verification key
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// dict
/// Python object containing the witness values
///
/// runs the forward pass operation
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
@@ -954,28 +675,19 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
let output =
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
let output = Runtime::new()
.unwrap()
.block_on(crate::execute::gen_witness(
model, data, output, vk_path, srs_path,
))
.map_err(|e| {
let err_str = format!("Failed to run generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Mocks the prover
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// Returns
/// -------
/// bool
///
/// mocks the prover
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -988,23 +700,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
Ok(true)
}
/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the relevant proof files
///
/// logrows: int
/// Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
/// Indicates whether the accumulated are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
/// mocks the aggregate prover
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
@@ -1023,32 +719,7 @@ fn mock_aggregate(
Ok(true)
}
/// Runs the setup process
///
/// Arguments
/// ---------
/// model: str
/// Path to the compiled model file
///
/// vk_path: str
/// Path to create the verification key file
///
/// pk_path: str
/// Path to create the proving key file
///
/// srs_path: str
/// Path to the SRS file
///
/// witness_path: str
/// Path to the witness file
///
/// disable_selector_compression: bool
/// Whether to compress the selectors or not
///
/// Returns
/// -------
/// bool
///
/// runs the prover on a set of inputs
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -1081,32 +752,7 @@ fn setup(
Ok(true)
}
/// Runs the prover on a set of inputs
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// pk_path: str
/// Path to the proving key file
///
/// proof_path: str
/// Path to create the proof file
///
/// proof_type: str
/// Accepts `single`, `for-aggr`
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// bool
///
/// runs the prover on a set of inputs
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -1140,29 +786,7 @@ fn prove(
Python::with_gil(|py| Ok(snark.to_object(py)))
}
/// Verifies a given proof
///
/// Arguments
/// ---------
/// proof_path: str
/// Path to create the proof file
///
/// settings_path: str
/// Path to the settings file
///
/// vk_path: str
/// Path to the verification key file
///
/// srs_path: str
/// Path to the SRS file
///
/// non_reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// Returns
/// -------
/// bool
///
/// verifies a given proof
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1192,38 +816,6 @@ fn verify(
Ok(true)
}
/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
/// List of paths to the various proofs
///
/// vk_path: str
/// Path to create the aggregated VK
///
/// pk_path: str
/// Path to create the aggregated PK
///
/// logrows: int
/// Number of logrows to use
///
/// split_proofs: bool
/// Whether the accumulated are segments of a larger proof
///
/// srs_path: str
/// Path to the SRS file
///
/// disable_selector_compression: bool
/// Whether to compress selectors
///
/// commitment: str
/// Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
@@ -1262,23 +854,6 @@ fn setup_aggregate(
Ok(true)
}
/// Compiles the circuit for use in other steps
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx model file
///
/// compiled_circuit: str
/// Path to output the compiled circuit
///
/// settings_path: str
/// Path to the settings files
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
@@ -1297,41 +872,7 @@ fn compile_circuit(
Ok(true)
}
/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the various proofs
///
/// proof_path: str
/// Path to output the aggregated proof
///
/// vk_path: str
/// Path to the VK file
///
/// transcript:
/// Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows:
/// Logrows used for aggregation circuit
///
/// check_mode: str
/// Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split-proofs: bool
/// Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
/// Path to the SRS used
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
/// creates an aggregated proof
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
@@ -1374,32 +915,7 @@ fn aggregate(
Ok(true)
}
/// Verifies and aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
/// The path to the proof file
///
/// vk_path: str
/// The path to the verification key file
///
/// logrows: int
/// logrows used for aggregation circuit
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
/// verifies and aggregate proof
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -1432,32 +948,7 @@ fn verify_aggr(
Ok(true)
}
/// Creates an EVM compatible verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to the create the solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
/// creates an EVM compatible verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1490,26 +981,7 @@ fn create_evm_verifier(
Ok(true)
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
/// The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to the create the solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// Returns
/// -------
/// bool
///
// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1531,32 +1003,6 @@ fn create_evm_data_attestation(
Ok(true)
}
/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
/// The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_sources: str
/// Where the input data comes from
///
/// output_source: str
/// Where the output data comes from
///
/// rpc_url: str
/// RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data_path,
compiled_circuit_path,
@@ -1591,7 +1037,6 @@ fn setup_test_evm_witness(
Ok(true)
}
/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
@@ -1624,7 +1069,6 @@ fn deploy_evm(
Ok(true)
}
/// deploys the solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
@@ -1657,7 +1101,6 @@ fn deploy_vk_evm(
Ok(true)
}
/// deploys the solidity da verifier
#[pyfunction(signature = (
addr_path,
input_data,
@@ -1695,27 +1138,6 @@ fn deploy_da_evm(
Ok(true)
}
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
///
/// rpc_url: str
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// addr_da: str
/// does the verifier use data attestation ?
///
/// addr_vk: str
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_verifier,
proof_path=PathBuf::from(DEFAULT_PROOF),
@@ -1761,35 +1183,7 @@ fn verify_evm(
Ok(true)
}
/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
/// path to the settings file
///
/// vk_path: str
/// The path to load the desired verification key file
///
/// sol_code_path: str
/// The path to the Solidity code
///
/// abi_path: str
/// The path to output the Solidity verifier ABI
///
/// logrows: int
/// Number of logrows used during aggregated setup
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),

View File

@@ -14,6 +14,9 @@ use maybe_rayon::{
slice::ParallelSliceMut,
};
use serde::{Deserialize, Serialize};
use std::io::BufRead;
use std::io::Write;
use std::path::PathBuf;
pub use val::*;
pub use var::*;
@@ -32,6 +35,7 @@ use halo2_proofs::{
use itertools::Itertools;
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
@@ -60,9 +64,12 @@ pub enum TensorError {
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
/// File save error
#[error("save error: {0}")]
FileSaveError(String),
/// File load error
#[error("load error: {0}")]
FileLoadError(String),
}
/// The (inner) type of tensor elements.
@@ -469,6 +476,45 @@ impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<T: Clone + TensorType + PrimeField> Tensor<T> {
/// save to a file
pub fn save(&self, path: &PathBuf) -> Result<(), TensorError> {
let writer =
std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
let mut buf_writer = std::io::BufWriter::new(writer);
self.inner.iter().map(|x| x.clone()).for_each(|x| {
let x = x.to_repr();
buf_writer.write_all(x.as_ref()).unwrap();
});
Ok(())
}
/// load from a file
pub fn load(path: &PathBuf) -> Result<Self, TensorError> {
let reader =
std::fs::File::open(path).map_err(|e| TensorError::FileLoadError(e.to_string()))?;
let mut buf_reader = std::io::BufReader::new(reader);
let mut inner = Vec::new();
while let Ok(true) = buf_reader.has_data_left() {
let mut repr = T::Repr::default();
match buf_reader.read_exact(repr.as_mut()) {
Ok(_) => {
inner.push(T::from_repr(repr).unwrap());
}
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
))
}
}
}
Ok(Tensor::new(Some(&inner), &[inner.len()]).unwrap())
}
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -937,7 +983,6 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());
let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -969,8 +1014,6 @@ impl<T: Clone + TensorType> Tensor<T> {
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -983,10 +1026,7 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
}
let value = self.get(&old_coord);
output.set(&coord, value);
output.set(&coord, self.get(&old_coord));
}
Ok(output)

File diff suppressed because it is too large Load Diff

View File

@@ -316,12 +316,6 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
inner.into()
}
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -879,13 +873,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
};
Ok(())
}
/// Calls `pad_spatial_dims` on the inner [Tensor].
pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
/// Calls `pad` on the inner [Tensor].
pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = pad(v, padding, offset)?;
*v = pad(v, padding)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {