mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
15 Commits
release-v1
...
ac/patch-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9bbc89cc89 | ||
|
|
28b65f2639 | ||
|
|
9592d38a8f | ||
|
|
2cec49dfc3 | ||
|
|
31a1681ca4 | ||
|
|
134b54d32b | ||
|
|
beb5f12376 | ||
|
|
65be3c84bb | ||
|
|
6f743c57d3 | ||
|
|
ddb54c5a73 | ||
|
|
6e1f22a15b | ||
|
|
da97323bde | ||
|
|
55046feeb6 | ||
|
|
d0d0596e58 | ||
|
|
b78efdcbf4 |
@@ -1,4 +1,4 @@
|
||||
name: Build and Publish EZKL Engine npm package
|
||||
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
@@ -79,10 +79,6 @@ jobs:
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
|
||||
|
||||
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
|
||||
run: |
|
||||
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
|
||||
|
||||
- name: Add serialize and deserialize methods to nodejs bundle
|
||||
run: |
|
||||
echo '
|
||||
@@ -178,3 +174,40 @@ jobs:
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
in-browser-evm-ver-publish:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: ["publish-wasm-bindings"]
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update @ezkljs/engine version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
|
||||
run: |
|
||||
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
- name: Publish to npm
|
||||
run: |
|
||||
cd in-browser-evm-verifier
|
||||
npm install
|
||||
npm run build
|
||||
npm ci
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
15
.github/workflows/pypi.yml
vendored
15
.github/workflows/pypi.yml
vendored
@@ -128,7 +128,6 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -360,17 +359,3 @@ jobs:
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
name: Trigger ReadTheDocs Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: pypi-publish
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Trigger RTDs build
|
||||
uses: dfm/rtds-action@v1
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
|
||||
40
.github/workflows/rust.yml
vendored
40
.github/workflows/rust.yml
vendored
@@ -307,8 +307,8 @@ jobs:
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- name: Install dependencies for js tests and in-browser-evm-verifier package
|
||||
run: |
|
||||
pnpm install --frozen-lockfile
|
||||
pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
|
||||
pnpm install --no-frozen-lockfile
|
||||
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
|
||||
env:
|
||||
CI: false
|
||||
NODE_ENV: development
|
||||
@@ -354,7 +354,7 @@ jobs:
|
||||
|
||||
prove-and-verify-tests:
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
@@ -380,7 +380,7 @@ jobs:
|
||||
cache: "pnpm"
|
||||
- name: Install dependencies for js tests
|
||||
run: |
|
||||
pnpm install --frozen-lockfile
|
||||
pnpm install --no-frozen-lockfile
|
||||
env:
|
||||
CI: false
|
||||
NODE_ENV: development
|
||||
@@ -394,18 +394,14 @@ jobs:
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
|
||||
- name: KZG prove and verify tests (hashed inputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
|
||||
- name: IPA prove and verify tests
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
|
||||
- name: IPA prove and verify tests (ipa outputs)
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests single inner col
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
|
||||
- name: KZG prove and verify tests triple inner col
|
||||
@@ -416,6 +412,8 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
@@ -612,24 +610,6 @@ jobs:
|
||||
|
||||
python-integration-tests:
|
||||
runs-on: large-self-hosted
|
||||
services:
|
||||
# Label used to access the service container
|
||||
postgres:
|
||||
# Docker Hub image
|
||||
image: postgres
|
||||
env:
|
||||
POSTGRES_USER: ubuntu
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
# Set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
-v /var/run/postgresql:/var/run/postgresql
|
||||
ports:
|
||||
# Maps tcp port 5432 on service container to the host
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
@@ -654,8 +634,6 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Postgres tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
|
||||
# - name: authenticate-kaggle-cli
|
||||
@@ -673,3 +651,5 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Postgres tutorials
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
|
||||
|
||||
36
.github/workflows/tagging.yml
vendored
36
.github/workflows/tagging.yml
vendored
@@ -14,40 +14,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Bump version and push tag
|
||||
id: tag_version
|
||||
uses: mathieudutour/github-tag-action@v6.2
|
||||
uses: mathieudutour/github-tag-action@v6.1
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set Cargo.toml version to match github tag for docs
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
|
||||
run: |
|
||||
mv docs/python/src/conf.py docs/python/src/conf.py.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
|
||||
rm docs/python/src/conf.py.orig
|
||||
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
|
||||
rm docs/python/requirements-docs.txt.orig
|
||||
|
||||
- name: Commit files and create tag
|
||||
env:
|
||||
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
|
||||
run: |
|
||||
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local user.name "github-actions[bot]"
|
||||
git fetch --tags
|
||||
git checkout -b release-$RELEASE_TAG
|
||||
git add .
|
||||
git commit -m "ci: update version string in docs"
|
||||
git tag -d $RELEASE_TAG
|
||||
git tag $RELEASE_TAG
|
||||
|
||||
- name: Push changes
|
||||
uses: ad-m/github-push-action@master
|
||||
env:
|
||||
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
|
||||
with:
|
||||
branch: release-${{ steps.tag_version.outputs.new_tag }}
|
||||
force: true
|
||||
tags: true
|
||||
|
||||
54
.github/workflows/verify.yml
vendored
54
.github/workflows/verify.yml
vendored
@@ -1,54 +0,0 @@
|
||||
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "The tag to release"
|
||||
required: true
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: .
|
||||
jobs:
|
||||
in-browser-evm-ver-publish:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update @ezkljs/engine version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
|
||||
run: |
|
||||
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: 8
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
- name: Publish to npm
|
||||
run: |
|
||||
cd in-browser-evm-verifier
|
||||
pnpm install --frozen-lockfile
|
||||
pnpm run build
|
||||
pnpm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -49,5 +49,4 @@ node_modules
|
||||
timingData.json
|
||||
!tests/wasm/pk.key
|
||||
!tests/wasm/vk.key
|
||||
docs/python/build
|
||||
!tests/wasm/vk_aggr.key
|
||||
@@ -1 +0,0 @@
|
||||
3.12.1
|
||||
@@ -1,26 +0,0 @@
|
||||
# .readthedocs.yaml
|
||||
# Read the Docs configuration file
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
version: 2
|
||||
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.12"
|
||||
|
||||
# Build documentation in the "docs/" directory with Sphinx
|
||||
sphinx:
|
||||
configuration: ./docs/python/src/conf.py
|
||||
|
||||
# Optionally build your docs in additional formats such as PDF and ePub
|
||||
# formats:
|
||||
# - pdf
|
||||
# - epub
|
||||
|
||||
# Optional but recommended, declare the Python requirements required
|
||||
# to build your documentation
|
||||
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
|
||||
python:
|
||||
install:
|
||||
- requirements: ./docs/python/requirements-docs.txt
|
||||
164
Cargo.lock
generated
164
Cargo.lock
generated
@@ -644,6 +644,15 @@ dependencies = [
|
||||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blake2"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
|
||||
dependencies = [
|
||||
"digest 0.10.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blake2b_simd"
|
||||
version = "1.0.2"
|
||||
@@ -1185,6 +1194,15 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
|
||||
dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derivative"
|
||||
version = "2.2.0"
|
||||
@@ -1307,6 +1325,12 @@ version = "1.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
|
||||
|
||||
[[package]]
|
||||
name = "dyn-hash"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
|
||||
|
||||
[[package]]
|
||||
name = "ecc"
|
||||
version = "0.1.0"
|
||||
@@ -1785,7 +1809,7 @@ dependencies = [
|
||||
"halo2_gadgets",
|
||||
"halo2_proofs",
|
||||
"halo2_solidity_verifier",
|
||||
"halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
|
||||
"halo2curves 0.7.0",
|
||||
"hex",
|
||||
"indicatif",
|
||||
"instant",
|
||||
@@ -1793,6 +1817,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"mimalloc",
|
||||
"mnist",
|
||||
"num",
|
||||
"openssl",
|
||||
@@ -2176,10 +2201,11 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.2.1"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0"
|
||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"num-traits",
|
||||
]
|
||||
@@ -2204,19 +2230,23 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
|
||||
"halo2curves 0.7.0",
|
||||
"icicle",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"rand_chacha",
|
||||
"rand_core 0.6.4",
|
||||
"rustacuda",
|
||||
"rustc-hash",
|
||||
"serde",
|
||||
"sha3 0.9.1",
|
||||
"tracing",
|
||||
]
|
||||
@@ -2224,7 +2254,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_solidity_verifier"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#eb04be1f7d005e5b9dd3ff41efa30aeb5e0c34a3"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#3082fda94151fc6760a3cb2be4741ddbeef04c03"
|
||||
dependencies = [
|
||||
"askama",
|
||||
"blake2b_simd",
|
||||
@@ -2300,15 +2330,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.6.1"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c#9fff22c5f72cc54fac1ef3a844e1072b08cfecdf"
|
||||
version = "0.7.0"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"blake2",
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2derive",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"pairing",
|
||||
"pasta_curves",
|
||||
@@ -2318,11 +2351,25 @@ dependencies = [
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_arrays",
|
||||
"sha2",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2derive"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
@@ -2830,6 +2877,16 @@ version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libmimalloc-sys"
|
||||
version = "0.1.39"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.0.1"
|
||||
@@ -2993,6 +3050,15 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mimalloc"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633"
|
||||
dependencies = [
|
||||
"libmimalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -3126,6 +3192,12 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-conv"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
@@ -3714,6 +3786,12 @@ dependencies = [
|
||||
"postgres-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
@@ -4058,9 +4136,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.9.0"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd"
|
||||
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
@@ -4360,6 +4438,12 @@ version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hex"
|
||||
version = "2.1.0"
|
||||
@@ -4782,11 +4866,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
|
||||
[[package]]
|
||||
name = "snark-verifier"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#574b65ea6b4d43eebac5565146519a95b435815c"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
dependencies = [
|
||||
"ecc",
|
||||
"halo2_proofs",
|
||||
"halo2curves 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"halo2curves 0.6.1",
|
||||
"hex",
|
||||
"itertools 0.10.5",
|
||||
"lazy_static",
|
||||
@@ -5127,11 +5211,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.23"
|
||||
version = "0.3.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
|
||||
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
|
||||
dependencies = [
|
||||
"deranged",
|
||||
"itoa",
|
||||
"num-conv",
|
||||
"powerfmt",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
@@ -5139,16 +5226,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time-core"
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
|
||||
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.10"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
|
||||
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
|
||||
dependencies = [
|
||||
"num-conv",
|
||||
"time-core",
|
||||
]
|
||||
|
||||
@@ -5398,8 +5486,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-core"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bit-set",
|
||||
@@ -5422,11 +5510,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-data"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"half 2.2.1",
|
||||
"downcast-rs",
|
||||
"dyn-clone",
|
||||
"dyn-hash",
|
||||
"half 2.4.1",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"maplit",
|
||||
@@ -5441,8 +5532,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-hir"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"log",
|
||||
@@ -5451,20 +5542,23 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-linalg"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"cc",
|
||||
"derive-new",
|
||||
"downcast-rs",
|
||||
"dyn-clone",
|
||||
"half 2.2.1",
|
||||
"dyn-hash",
|
||||
"half 2.4.1",
|
||||
"lazy_static",
|
||||
"liquid",
|
||||
"liquid-core",
|
||||
"log",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"rayon",
|
||||
"scan_fmt",
|
||||
"smallvec",
|
||||
"time",
|
||||
@@ -5475,8 +5569,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-nnef"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"flate2",
|
||||
@@ -5489,8 +5583,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-onnx"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"derive-new",
|
||||
@@ -5506,8 +5600,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tract-onnx-opl"
|
||||
version = "0.21.3"
|
||||
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
|
||||
version = "0.21.6-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"log",
|
||||
|
||||
19
Cargo.toml
19
Cargo.toml
@@ -15,9 +15,10 @@ crate-type = ["cdylib", "rlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
mimalloc = "0.1"
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
@@ -32,10 +33,10 @@ log = { version = "0.4.17", default_features = false, optional = true }
|
||||
thiserror = { version = "1.0.38", default_features = false }
|
||||
hex = { version = "0.4.3", default_features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
ark-std = { version = "^0.3.0", default-features = false }
|
||||
@@ -80,7 +81,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
|
||||
"tokio-runtime",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
|
||||
|
||||
@@ -175,7 +176,7 @@ required-features = ["ezkl"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = ["ezkl", "mv-lookup"]
|
||||
default = ["ezkl", "mv-lookup", "precompute-coset"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ezkl = [
|
||||
@@ -194,6 +195,8 @@ mv-lookup = [
|
||||
"snark-verifier/mv-lookup",
|
||||
"halo2_solidity_verifier/mv-lookup",
|
||||
]
|
||||
asm = ["halo2curves/asm", "halo2_proofs/asm"]
|
||||
precompute-coset = ["halo2_proofs/precompute-coset"]
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
@@ -204,8 +207,8 @@ no-banner = []
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
|
||||
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
|
||||
|
||||
[profile.release]
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
|
||||
|
||||
@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(0, 0)],
|
||||
stride: vec![1; 2],
|
||||
padding: [(0, 0); 2],
|
||||
stride: (1, 1),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&mut region,
|
||||
&[self.image.clone()],
|
||||
Box::new(HybridOp::SumPool {
|
||||
padding: vec![(0, 0); 2],
|
||||
stride: vec![1, 1],
|
||||
kernel_shape: vec![2, 2],
|
||||
padding: [(0, 0); 2],
|
||||
stride: (1, 1),
|
||||
kernel_shape: (2, 2),
|
||||
normalized: false,
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
#!/bin/sh
|
||||
sphinx-build ./src build
|
||||
@@ -1,4 +0,0 @@
|
||||
ezkl==10.4.1
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
@@ -1,29 +0,0 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '10.4.1'
|
||||
version = release
|
||||
|
||||
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.todo',
|
||||
'sphinx.ext.inheritance_diagram',
|
||||
'sphinx.ext.autosectionlabel',
|
||||
'sphinx.ext.napoleon',
|
||||
'sphinx_rtd_theme',
|
||||
]
|
||||
|
||||
autosummary_generate = True
|
||||
autosummary_imported_members = True
|
||||
|
||||
templates_path = ['_templates']
|
||||
exclude_patterns = []
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_static_path = ['_static']
|
||||
@@ -1,11 +0,0 @@
|
||||
.. extension documentation master file, created by
|
||||
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
ezkl python bindings
|
||||
================================================
|
||||
|
||||
.. automodule:: ezkl
|
||||
:members:
|
||||
:undoc-members:
|
||||
@@ -203,8 +203,8 @@ where
|
||||
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
|
||||
|
||||
let op = PolyOp::Conv {
|
||||
padding: vec![(PADDING, PADDING); 2],
|
||||
stride: vec![STRIDE; 2],
|
||||
padding: [(PADDING, PADDING); 2],
|
||||
stride: (STRIDE, STRIDE),
|
||||
};
|
||||
let x = config
|
||||
.layer_config
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"Make sure you install postgres if needed https://postgresapp.com/. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -21,81 +21,23 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
|
||||
"os.system(\"chmod +x e2pg\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
|
||||
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"os.system(\"echo $RLPS_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(5)\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
|
||||
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
|
||||
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -137,13 +79,11 @@
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
"# import logging\n",
|
||||
"# # # uncomment for more descriptive logging \n",
|
||||
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"# logging.basicConfig(format=FORMAT)\n",
|
||||
"# logging.getLogger().setLevel(logging.DEBUG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -236,7 +176,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
@@ -244,9 +183,9 @@
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"dbname\": \"e2pg\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
@@ -255,7 +194,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -271,9 +210,9 @@
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"dbname\": \"e2pg\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
@@ -290,6 +229,22 @@
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(\n",
|
||||
" input_filename, onnx_filename, settings_filename, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -298,21 +253,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"# setup kzg params\n",
|
||||
"params_path = os.path.join('kzg.params')\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
"res = ezkl.get_srs(params_path, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -362,13 +306,16 @@
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"params_path = os.path.join('kzg.params')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" pk_path,\n",
|
||||
" params_path,\n",
|
||||
" settings_filename,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -384,14 +331,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -416,14 +360,73 @@
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" params_path,\n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
"\n",
|
||||
"# verify\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_filename,\n",
|
||||
" vk_path,\n",
|
||||
" params_path,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "W7tAa-DFAtvS"
|
||||
},
|
||||
"source": [
|
||||
"# Part 2 (Using the ZK Computational Graph Onchain!)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "8Ym91kaVAIB6"
|
||||
},
|
||||
"source": [
|
||||
"**Now How Do We Do It Onchain?????**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 339
|
||||
},
|
||||
"id": "fodkNgwS70FM",
|
||||
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# first we need to create evm verifier\n",
|
||||
"print(vk_path)\n",
|
||||
"print(params_path)\n",
|
||||
"print(settings_filename)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" params_path,\n",
|
||||
" settings_filename,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path,\n",
|
||||
" )\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -432,8 +435,51 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# kill all shovel process \n",
|
||||
"os.system(\"pkill -f shovel\")"
|
||||
"# Make sure anvil is running locally first\n",
|
||||
"# run with $ anvil -p 3030\n",
|
||||
"# we use the default anvil node here\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"with open(address_path, 'r') as file:\n",
|
||||
" addr = file.read().rstrip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read the address from addr_path\n",
|
||||
"addr = None\n",
|
||||
"with open(address_path, 'r') as f:\n",
|
||||
" addr = f.read()\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.system(\"killall -9 e2pg\");"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -455,7 +501,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -20,16 +20,16 @@
|
||||
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ethereumjs/common": "4.0.0",
|
||||
"@ethereumjs/evm": "2.0.0",
|
||||
"@ethereumjs/statemanager": "2.0.0",
|
||||
"@ethereumjs/tx": "5.0.0",
|
||||
"@ethereumjs/util": "9.0.0",
|
||||
"@ethereumjs/vm": "7.0.0",
|
||||
"@ethersproject/abi": "5.7.0",
|
||||
"@ethereumjs/common": "^4.0.0",
|
||||
"@ethereumjs/evm": "^2.0.0",
|
||||
"@ethereumjs/statemanager": "^2.0.0",
|
||||
"@ethereumjs/tx": "^5.0.0",
|
||||
"@ethereumjs/util": "^9.0.0",
|
||||
"@ethereumjs/vm": "^7.0.0",
|
||||
"@ethersproject/abi": "^5.7.0",
|
||||
"@ezkljs/engine": "^9.4.4",
|
||||
"ethers": "6.7.1",
|
||||
"json-bigint": "1.0.0"
|
||||
"ethers": "^6.7.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.8.3",
|
||||
|
||||
18
in-browser-evm-verifier/pnpm-lock.yaml
generated
18
in-browser-evm-verifier/pnpm-lock.yaml
generated
@@ -6,34 +6,34 @@ settings:
|
||||
|
||||
dependencies:
|
||||
'@ethereumjs/common':
|
||||
specifier: 4.0.0
|
||||
specifier: ^4.0.0
|
||||
version: 4.0.0
|
||||
'@ethereumjs/evm':
|
||||
specifier: 2.0.0
|
||||
specifier: ^2.0.0
|
||||
version: 2.0.0
|
||||
'@ethereumjs/statemanager':
|
||||
specifier: 2.0.0
|
||||
specifier: ^2.0.0
|
||||
version: 2.0.0
|
||||
'@ethereumjs/tx':
|
||||
specifier: 5.0.0
|
||||
specifier: ^5.0.0
|
||||
version: 5.0.0
|
||||
'@ethereumjs/util':
|
||||
specifier: 9.0.0
|
||||
specifier: ^9.0.0
|
||||
version: 9.0.0
|
||||
'@ethereumjs/vm':
|
||||
specifier: 7.0.0
|
||||
specifier: ^7.0.0
|
||||
version: 7.0.0
|
||||
'@ethersproject/abi':
|
||||
specifier: 5.7.0
|
||||
specifier: ^5.7.0
|
||||
version: 5.7.0
|
||||
'@ezkljs/engine':
|
||||
specifier: ^9.4.4
|
||||
version: 9.4.4
|
||||
ethers:
|
||||
specifier: 6.7.1
|
||||
specifier: ^6.7.1
|
||||
version: 6.7.1
|
||||
json-bigint:
|
||||
specifier: 1.0.0
|
||||
specifier: ^1.0.0
|
||||
version: 1.0.0
|
||||
|
||||
devDependencies:
|
||||
|
||||
@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
|
||||
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
|
||||
# Add the ezkl directory to the path and ensure the old PATH variables remain.
|
||||
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
|
||||
fi
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.0,<2.0"]
|
||||
requires = ["maturin>=0.14,<0.15"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -2,7 +2,7 @@ attrs==23.2.0
|
||||
exceptiongroup==1.2.0
|
||||
importlib-metadata==7.1.0
|
||||
iniconfig==2.0.0
|
||||
maturin==1.5.1
|
||||
maturin==1.5.0
|
||||
packaging==24.0
|
||||
pluggy==1.4.0
|
||||
pytest==8.1.1
|
||||
@@ -11,4 +11,4 @@ typing-extensions==4.10.0
|
||||
zipp==3.18.1
|
||||
onnx==1.15.0
|
||||
onnxruntime==1.17.1
|
||||
numpy==1.26.4
|
||||
numpy==1.26.4
|
||||
@@ -1,4 +1,6 @@
|
||||
// ignore file if compiling for wasm
|
||||
#[global_allocator]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use clap::Parser;
|
||||
|
||||
@@ -956,6 +956,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
values: &[ValTensor<F>],
|
||||
op: Box<dyn Op<F>>,
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
op.layout(self, region, values)
|
||||
let res = op.layout(self, region, values)?;
|
||||
|
||||
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
|
||||
if let Some(claimed_output) = &res {
|
||||
// during key generation this will be unknown vals so we use this as a flag to check
|
||||
let mut is_assigned = !claimed_output.any_unknowns()?;
|
||||
for val in values.iter() {
|
||||
is_assigned = is_assigned && !val.any_unknowns()?;
|
||||
}
|
||||
if is_assigned {
|
||||
op.safe_mode_check(claimed_output, values)?;
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::i128_to_felt,
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -29,15 +29,15 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
},
|
||||
SumPool {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
kernel_shape: Vec<usize>,
|
||||
padding: [(usize, usize); 2],
|
||||
stride: (usize, usize),
|
||||
kernel_shape: (usize, usize),
|
||||
normalized: bool,
|
||||
},
|
||||
MaxPool {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
pool_dims: Vec<usize>,
|
||||
MaxPool2d {
|
||||
padding: [(usize, usize); 2],
|
||||
stride: (usize, usize),
|
||||
pool_dims: (usize, usize),
|
||||
},
|
||||
ReduceMin {
|
||||
axes: Vec<usize>,
|
||||
@@ -85,6 +85,93 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = inputs[0].clone().map(|x| felt_to_i128(x));
|
||||
|
||||
let res = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
|
||||
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
|
||||
}
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
..
|
||||
} => crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
),
|
||||
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
|
||||
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather(&x, idx, *dim)?
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
|
||||
}
|
||||
|
||||
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
..
|
||||
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
normalized,
|
||||
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => tensor::ops::nonlinearities::softmax_axes(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
axes,
|
||||
),
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
|
||||
}
|
||||
HybridOp::Greater => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::greater(&x, &y)?
|
||||
}
|
||||
HybridOp::GreaterEqual => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::greater_equal(&x, &y)?
|
||||
}
|
||||
HybridOp::Less => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::less(&x, &y)?
|
||||
}
|
||||
HybridOp::LessEqual => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::less_equal(&x, &y)?
|
||||
}
|
||||
HybridOp::Equals => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::equals(&x, &y)?
|
||||
}
|
||||
};
|
||||
|
||||
// convert back to felt
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
@@ -114,12 +201,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
),
|
||||
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
|
||||
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
|
||||
HybridOp::MaxPool {
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
} => format!(
|
||||
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
|
||||
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
|
||||
padding, stride, pool_dims
|
||||
),
|
||||
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
|
||||
@@ -166,9 +253,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
*padding,
|
||||
*stride,
|
||||
*kernel_shape,
|
||||
*normalized,
|
||||
)?,
|
||||
HybridOp::Recip {
|
||||
@@ -228,17 +315,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
}
|
||||
}
|
||||
|
||||
HybridOp::MaxPool {
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
} => layouts::max_pool(
|
||||
} => layouts::max_pool2d(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
*padding,
|
||||
*stride,
|
||||
*pool_dims,
|
||||
)?,
|
||||
HybridOp::ReduceMax { axes } => {
|
||||
layouts::max_axes(config, region, values[..].try_into()?, axes)?
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -136,11 +136,61 @@ impl LookupOp {
|
||||
(-range, range)
|
||||
}
|
||||
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Abs => "abs".into(),
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Sign => "sign".into(),
|
||||
LookupOp::LessThan { a } => format!("less_than_{}", a),
|
||||
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
|
||||
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
|
||||
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::ReLU => "relu".to_string(),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
LookupOp::ACosh { scale } => format!("acosh_{}", scale),
|
||||
LookupOp::Sin { scale } => format!("sin_{}", scale),
|
||||
LookupOp::ASin { scale } => format!("asin_{}", scale),
|
||||
LookupOp::Sinh { scale } => format!("sinh_{}", scale),
|
||||
LookupOp::ASinh { scale } => format!("asinh_{}", scale),
|
||||
LookupOp::Tan { scale } => format!("tan_{}", scale),
|
||||
LookupOp::ATan { scale } => format!("atan_{}", scale),
|
||||
LookupOp::ATanh { scale } => format!("atanh_{}", scale),
|
||||
LookupOp::Tanh { scale } => format!("tanh_{}", scale),
|
||||
LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_i128(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
@@ -232,13 +282,6 @@ impl LookupOp {
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
graph::quantize_tensor,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
@@ -35,6 +35,8 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
std::fmt::Debug + Send + Sync + Any
|
||||
{
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
|
||||
/// Returns a string representation of the operation.
|
||||
fn as_string(&self) -> String;
|
||||
|
||||
@@ -69,6 +71,33 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
|
||||
/// Safe mode output checl
|
||||
fn safe_mode_check(
|
||||
&self,
|
||||
claimed_output: &ValTensor<F>,
|
||||
original_values: &[ValTensor<F>],
|
||||
) -> Result<(), TensorError> {
|
||||
let felt_evals = original_values
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
|
||||
evals.reshape(v.dims())?;
|
||||
Ok(evals)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
|
||||
|
||||
let mut output = claimed_output
|
||||
.get_felt_evals()
|
||||
.map_err(|_| TensorError::FeltError)?;
|
||||
output.reshape(claimed_output.dims())?;
|
||||
|
||||
assert_eq!(output, ref_op);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
|
||||
@@ -147,6 +176,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
self
|
||||
}
|
||||
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Ok(ForwardResult {
|
||||
output: x[0].clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
"Input".into()
|
||||
}
|
||||
@@ -200,6 +235,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Err(TensorError::WrongMethod)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
"Unknown".into()
|
||||
@@ -269,6 +307,11 @@ impl<
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let output = self.quantized_values.clone();
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!("CONST (scale={})", self.quantized_values.scale().unwrap())
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
fieldutils::felt_to_i128,
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -31,8 +32,8 @@ pub enum PolyOp {
|
||||
equation: String,
|
||||
},
|
||||
Conv {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
padding: [(usize, usize); 2],
|
||||
stride: (usize, usize),
|
||||
},
|
||||
Downsample {
|
||||
axis: usize,
|
||||
@@ -40,9 +41,9 @@ pub enum PolyOp {
|
||||
modulo: usize,
|
||||
},
|
||||
DeConv {
|
||||
padding: Vec<(usize, usize)>,
|
||||
output_padding: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
padding: [(usize, usize); 2],
|
||||
output_padding: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
},
|
||||
Add,
|
||||
Sub,
|
||||
@@ -57,13 +58,10 @@ pub enum PolyOp {
|
||||
destination: usize,
|
||||
},
|
||||
Flatten(Vec<usize>),
|
||||
Pad(Vec<(usize, usize)>),
|
||||
Pad([(usize, usize); 2]),
|
||||
Sum {
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
MeanOfSquares {
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
Prod {
|
||||
axes: Vec<usize>,
|
||||
len_prod: usize,
|
||||
@@ -107,28 +105,10 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
constant_idx.is_some()
|
||||
),
|
||||
PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices,
|
||||
} => format!(
|
||||
"GATHERND (batch_dims={}, constant_idx{})",
|
||||
batch_dims,
|
||||
indices.is_some()
|
||||
),
|
||||
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
|
||||
PolyOp::ScatterElements { dim, constant_idx } => format!(
|
||||
"SCATTERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
constant_idx.is_some()
|
||||
),
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
|
||||
}
|
||||
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
|
||||
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
PolyOp::ScatterND { .. } => "SCATTERND".into(),
|
||||
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
|
||||
@@ -140,26 +120,15 @@ impl<
|
||||
}
|
||||
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
|
||||
PolyOp::Flatten(_) => "FLATTEN".into(),
|
||||
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
|
||||
PolyOp::Pad(_) => "PAD".into(),
|
||||
PolyOp::Add => "ADD".into(),
|
||||
PolyOp::Mult => "MULT".into(),
|
||||
PolyOp::Sub => "SUB".into(),
|
||||
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Conv { stride, padding } => {
|
||||
format!("CONV (stride={:?}, padding={:?})", stride, padding)
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
} => {
|
||||
format!(
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
|
||||
stride, padding, output_padding
|
||||
)
|
||||
}
|
||||
PolyOp::Conv { .. } => "CONV".into(),
|
||||
PolyOp::DeConv { .. } => "DECONV".into(),
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
|
||||
@@ -173,6 +142,146 @@ impl<
|
||||
}
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let mut inputs = inputs.to_vec();
|
||||
let res = match &self {
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch(
|
||||
"multibroadcastto inputs".to_string(),
|
||||
));
|
||||
}
|
||||
inputs[0].expand(shape)
|
||||
}
|
||||
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
|
||||
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
|
||||
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
|
||||
PolyOp::Not => tensor::ops::not(&inputs[0]),
|
||||
PolyOp::Downsample {
|
||||
axis,
|
||||
stride,
|
||||
modulo,
|
||||
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
|
||||
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
|
||||
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
|
||||
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
|
||||
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
|
||||
PolyOp::Reshape(new_dims) => {
|
||||
let mut t = inputs[0].clone();
|
||||
t.reshape(new_dims)?;
|
||||
Ok(t)
|
||||
}
|
||||
PolyOp::MoveAxis {
|
||||
source,
|
||||
destination,
|
||||
} => inputs[0].move_axis(*source, *destination),
|
||||
PolyOp::Flatten(new_dims) => {
|
||||
let mut t = inputs[0].clone();
|
||||
t.reshape(new_dims)?;
|
||||
Ok(t)
|
||||
}
|
||||
PolyOp::Pad(p) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pad inputs".to_string()));
|
||||
}
|
||||
tensor::ops::pad(&inputs[0], *p)
|
||||
}
|
||||
PolyOp::Add => tensor::ops::add(&inputs),
|
||||
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
|
||||
PolyOp::Sub => tensor::ops::sub(&inputs),
|
||||
PolyOp::Mult => tensor::ops::mult(&inputs),
|
||||
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
|
||||
PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
|
||||
PolyOp::Pow(u) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pow inputs".to_string()));
|
||||
}
|
||||
inputs[0].pow(*u)
|
||||
}
|
||||
PolyOp::Sum { axes } => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("sum inputs".to_string()));
|
||||
}
|
||||
tensor::ops::sum_axes(&inputs[0], axes)
|
||||
}
|
||||
PolyOp::Prod { axes, .. } => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("prod inputs".to_string()));
|
||||
}
|
||||
tensor::ops::prod_axes(&inputs[0], axes)
|
||||
}
|
||||
PolyOp::Concat { axis } => {
|
||||
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
|
||||
}
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("slice inputs".to_string()));
|
||||
}
|
||||
tensor::ops::slice(&inputs[0], axis, start, end)
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_elements(&x, &y, *dim)
|
||||
}
|
||||
PolyOp::GatherND {
|
||||
indices,
|
||||
batch_dims,
|
||||
} => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = indices {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_nd(&x, &y, *batch_dims)
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter(&x, &idx, &src, *dim)
|
||||
}
|
||||
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter_nd(&x, &idx, &src)
|
||||
}
|
||||
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult { output: res })
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
@@ -183,9 +292,6 @@ impl<
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
PolyOp::MeanOfSquares { axes } => {
|
||||
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
|
||||
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
|
||||
@@ -212,7 +318,7 @@ impl<
|
||||
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
PolyOp::Conv { padding, stride } => {
|
||||
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
|
||||
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
@@ -264,9 +370,9 @@ impl<
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
*padding,
|
||||
*output_padding,
|
||||
*stride,
|
||||
)?,
|
||||
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
|
||||
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
|
||||
@@ -282,7 +388,7 @@ impl<
|
||||
)));
|
||||
}
|
||||
let mut input = values[0].clone();
|
||||
input.pad(p.clone(), 0)?;
|
||||
input.pad(*p)?;
|
||||
input
|
||||
}
|
||||
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
|
||||
@@ -298,7 +404,6 @@ impl<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
let scale = match self {
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Iff => in_scales[1],
|
||||
PolyOp::Einsum { .. } => {
|
||||
|
||||
@@ -17,6 +17,8 @@ use crate::{
|
||||
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
use super::Op;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i128, i128);
|
||||
|
||||
@@ -25,6 +27,13 @@ pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
/// an optional directory to read and write the lookup table cache
|
||||
static ref LOOKUP_CACHE: Option<std::path::PathBuf> = std::env::var("LOOKUP_CACHE")
|
||||
.ok()
|
||||
.map(std::path::PathBuf::from);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
///
|
||||
pub struct SelectorConstructor<F: PrimeField> {
|
||||
@@ -111,10 +120,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
let chunk = chunk as i128;
|
||||
// we index from 1 to prevent soundness issues
|
||||
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
|
||||
let op_f = self
|
||||
.nonlinearity
|
||||
.f(&[Tensor::from(vec![first_element].into_iter())])
|
||||
.unwrap();
|
||||
let op_f = Op::<F>::f(
|
||||
&self.nonlinearity,
|
||||
&[Tensor::from(vec![first_element].into_iter())],
|
||||
)
|
||||
.unwrap();
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
@@ -202,8 +212,46 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
let evals = self.nonlinearity.f(&[inputs.clone()])?;
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
.par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(i128_to_felt(x)))?;
|
||||
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
|
||||
Ok((inputs, evals.output))
|
||||
};
|
||||
|
||||
let (inputs, evals) = if let Some(cache) = &*LOOKUP_CACHE {
|
||||
let cache_path = cache.join(self.nonlinearity.as_path());
|
||||
let input_path = cache_path.join("inputs");
|
||||
let output_path = cache_path.join("outputs");
|
||||
if cache_path.exists() {
|
||||
log::info!("Loading lookup table from cache: {:?}", cache_path);
|
||||
let (input_cache, output_cache) =
|
||||
(Tensor::load(&input_path)?, Tensor::load(&output_path)?);
|
||||
(input_cache, output_cache)
|
||||
} else {
|
||||
log::info!(
|
||||
"Generating lookup table and saving to cache: {:?}",
|
||||
cache_path
|
||||
);
|
||||
|
||||
// mkdir -p cache_path
|
||||
std::fs::create_dir_all(&cache_path)?;
|
||||
|
||||
let (inputs, evals) = gen_table()?;
|
||||
inputs.save(&input_path)?;
|
||||
evals.save(&output_path)?;
|
||||
|
||||
(inputs, evals)
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"Generating lookup table {} without cache",
|
||||
self.nonlinearity.as_path()
|
||||
);
|
||||
|
||||
gen_table()?
|
||||
};
|
||||
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
self.is_assigned = true;
|
||||
@@ -235,7 +283,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
)?;
|
||||
}
|
||||
|
||||
let output = evals.output[row_offset];
|
||||
let output = evals[row_offset];
|
||||
|
||||
table.assign_cell(
|
||||
|| format!("nl_o_col row {}", row_offset),
|
||||
@@ -273,6 +321,11 @@ pub struct RangeCheck<F: PrimeField> {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
format!("rangecheck_{}_{}", self.range.0, self.range.1)
|
||||
}
|
||||
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> F {
|
||||
let chunk = chunk as i128;
|
||||
@@ -350,7 +403,31 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
let inputs: Tensor<F> = if let Some(cache) = &*LOOKUP_CACHE {
|
||||
let cache_path = cache.join(self.as_path());
|
||||
let input_path = cache_path.join("inputs");
|
||||
if cache_path.exists() {
|
||||
log::info!("Loading range check table from cache: {:?}", cache_path);
|
||||
Tensor::load(&input_path)?
|
||||
} else {
|
||||
log::info!(
|
||||
"Generating range check table and saving to cache: {:?}",
|
||||
cache_path
|
||||
);
|
||||
|
||||
// mkdir -p cache_path
|
||||
std::fs::create_dir_all(&cache_path)?;
|
||||
|
||||
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
inputs.save(&input_path)?;
|
||||
inputs
|
||||
}
|
||||
} else {
|
||||
log::info!("Generating range check {} without cache", self.as_path());
|
||||
|
||||
Tensor::from(smallest..=largest).map(|x| i128_to_felt(x))
|
||||
};
|
||||
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
@@ -1048,8 +1048,8 @@ mod conv {
|
||||
&mut region,
|
||||
&self.inputs,
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
padding: [(1, 1); 2],
|
||||
stride: (2, 2),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone()],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
padding: [(1, 1); 2],
|
||||
stride: (2, 2),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone()],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
padding: [(1, 1); 2],
|
||||
stride: (2, 2),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis);
|
||||
|
||||
@@ -455,7 +455,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
|
||||
for val in flattened_instances.clone() {
|
||||
let bytes = val.to_repr();
|
||||
let u = U256::from_little_endian(bytes.as_slice());
|
||||
let u = U256::from_little_endian(bytes.inner());
|
||||
public_inputs.push(u);
|
||||
}
|
||||
|
||||
|
||||
@@ -196,6 +196,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
vk_path,
|
||||
srs_path,
|
||||
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
|
||||
.await
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(model, witness),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -636,7 +637,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn gen_witness(
|
||||
pub(crate) async fn gen_witness(
|
||||
compiled_circuit_path: PathBuf,
|
||||
data: PathBuf,
|
||||
output: Option<PathBuf>,
|
||||
@@ -659,7 +660,7 @@ pub(crate) fn gen_witness(
|
||||
};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -1228,14 +1229,22 @@ pub(crate) fn calibrate(
|
||||
);
|
||||
|
||||
if matches!(target, CalibrationTarget::Resources { col_overflow: true }) {
|
||||
let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
|
||||
let module_log_row = best_params.module_constraint_logrows_with_blinding();
|
||||
let instance_logrows = best_params.log2_total_instances_with_blinding();
|
||||
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
|
||||
|
||||
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
|
||||
reduction = std::cmp::max(reduction, instance_logrows);
|
||||
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
|
||||
let lookup_log_rows = ((best_params.run_args.lookup_range.1
|
||||
- best_params.run_args.lookup_range.0) as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
+ 1;
|
||||
let mut reduction = std::cmp::max(
|
||||
(best_params
|
||||
.model_instance_shapes
|
||||
.iter()
|
||||
.map(|x| x.iter().product::<usize>())
|
||||
.sum::<usize>() as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
+ 1,
|
||||
lookup_log_rows,
|
||||
);
|
||||
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);
|
||||
|
||||
info!(
|
||||
|
||||
@@ -21,6 +21,8 @@ use std::io::BufWriter;
|
||||
use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::thread;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
@@ -232,15 +234,21 @@ impl PostgresSource {
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config, NoTls)?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[])? {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
|
||||
let mut client = Client::connect(&config, NoTls).unwrap();
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[]).unwrap() {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
res
|
||||
})
|
||||
.join()
|
||||
.map_err(|_| "failed to fetch data from postgres")?;
|
||||
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
|
||||
@@ -483,22 +483,7 @@ pub struct GraphSettings {
|
||||
}
|
||||
|
||||
impl GraphSettings {
|
||||
/// Calc the number of rows required for lookup tables
|
||||
pub fn lookup_log_rows(&self) -> u32 {
|
||||
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// Calc the number of rows required for lookup tables
|
||||
pub fn lookup_log_rows_with_blinding(&self) -> u32 {
|
||||
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32
|
||||
+ RESERVED_BLINDING_ROWS as f32)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn model_constraint_logrows_with_blinding(&self) -> u32 {
|
||||
fn model_constraint_logrows(&self) -> u32 {
|
||||
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
@@ -510,31 +495,14 @@ impl GraphSettings {
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the dynamic lookup and shuffle
|
||||
pub fn dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
|
||||
(self.total_dynamic_col_size as f64
|
||||
+ self.total_shuffle_col_size as f64
|
||||
+ RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
|
||||
self.total_dynamic_col_size + self.total_shuffle_col_size
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the module constraints
|
||||
pub fn module_constraint_logrows(&self) -> u32 {
|
||||
fn module_constraint_logrows(&self) -> u32 {
|
||||
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the number of rows required for the module constraints
|
||||
pub fn module_constraint_logrows_with_blinding(&self) -> u32 {
|
||||
(self.module_sizes.max_constraints() as f64 + RESERVED_BLINDING_ROWS as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
fn constants_logrows(&self) -> u32 {
|
||||
(self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
|
||||
.log2()
|
||||
@@ -561,14 +529,6 @@ impl GraphSettings {
|
||||
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
|
||||
}
|
||||
|
||||
/// calculate the log2 of the total number of instances
|
||||
pub fn log2_total_instances_with_blinding(&self) -> u32 {
|
||||
let sum = self.total_instances().iter().sum::<usize>() + RESERVED_BLINDING_ROWS;
|
||||
|
||||
// max between 1 and the log2 of the sums
|
||||
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
|
||||
}
|
||||
|
||||
/// save params to file
|
||||
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
|
||||
// buf writer
|
||||
@@ -958,7 +918,7 @@ impl GraphCircuit {
|
||||
|
||||
///
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn load_graph_input(
|
||||
pub async fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
@@ -968,6 +928,7 @@ impl GraphCircuit {
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
@@ -991,7 +952,7 @@ impl GraphCircuit {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Process the data source for the model
|
||||
fn process_data_source(
|
||||
async fn process_data_source(
|
||||
&mut self,
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
@@ -1004,16 +965,8 @@ impl GraphCircuit {
|
||||
for (i, shape) in shapes.iter().enumerate() {
|
||||
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
|
||||
}
|
||||
|
||||
// start runtime and fetch data
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
|
||||
.await
|
||||
})
|
||||
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
|
||||
.await
|
||||
}
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
@@ -1173,7 +1126,7 @@ impl GraphCircuit {
|
||||
);
|
||||
|
||||
// These are upper limits, going above these is wasteful, but they are not hard limits
|
||||
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
|
||||
let model_constraint_logrows = self.settings().model_constraint_logrows();
|
||||
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
|
||||
let constants_logrows = self.settings().constants_logrows();
|
||||
max_logrows = std::cmp::min(
|
||||
|
||||
@@ -1200,20 +1200,6 @@ impl Model {
|
||||
.collect();
|
||||
|
||||
for (idx, node) in self.graph.nodes.iter() {
|
||||
debug!("laying out {}: {}", idx, node.as_str(),);
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
region.debug_report();
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
|
||||
node.inputs()
|
||||
.iter()
|
||||
@@ -1225,11 +1211,25 @@ impl Model {
|
||||
// we re-assign inputs, always from the 0 outlet
|
||||
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
|
||||
};
|
||||
debug!("output dims: {:?}", node.out_dims());
|
||||
|
||||
debug!("laying out {}: {}", idx, node.as_str(),);
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
region.debug_report();
|
||||
debug!("dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input dims {:?}",
|
||||
"input_dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tensor::TensorError;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -60,6 +61,20 @@ impl Op<Fp> for Rescaled {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
if self.scale.len() != x.len() {
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
|
||||
}
|
||||
|
||||
let mut rescaled_inputs = vec![];
|
||||
let inputs = &mut x.to_vec();
|
||||
for (i, ri) in inputs.iter_mut().enumerate() {
|
||||
let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
|
||||
let res = (ri.clone() * mult_tensor)?;
|
||||
rescaled_inputs.push(res);
|
||||
}
|
||||
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED INPUT ({})", self.inner.as_string())
|
||||
@@ -200,6 +215,13 @@ impl Op<Fp> for RebaseScale {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
let mut res = Op::<Fp>::f(&*self.inner, x)?;
|
||||
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
|
||||
res.output = rebase_res.output;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
@@ -367,6 +389,13 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
}
|
||||
|
||||
impl Op<Fp> for SupportedOp {
|
||||
fn f(
|
||||
&self,
|
||||
inputs: &[Tensor<Fp>],
|
||||
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
|
||||
self.as_op().f(inputs)
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
|
||||
@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -734,19 +734,6 @@ pub fn new_op_from_onnx(
|
||||
|
||||
SupportedOp::Linear(PolyOp::Sum { axes })
|
||||
}
|
||||
"Reduce<MeanOfSquares>" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"mean of squares".to_string(),
|
||||
)));
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
|
||||
SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
|
||||
}
|
||||
|
||||
"Max" => {
|
||||
// Extract the max value
|
||||
// first find the input that is a constant
|
||||
@@ -1119,7 +1106,17 @@ pub fn new_op_from_onnx(
|
||||
.ok_or(GraphError::MissingParams("stride".to_string()))?;
|
||||
let padding = match &pool_spec.padding {
|
||||
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
if b.len() == 2 && a.len() == 2 {
|
||||
[(b[0], b[1]), (a[0], a[1])]
|
||||
} else if b.len() == 1 && a.len() == 1 {
|
||||
[(b[0], b[0]), (a[0], a[0])]
|
||||
} else if b.len() == 1 && a.len() == 2 {
|
||||
[(b[0], b[0]), (a[0], a[1])]
|
||||
} else if b.len() == 2 && a.len() == 1 {
|
||||
[(b[0], b[1]), (a[0], a[0])]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
@@ -1127,10 +1124,26 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let kernel_shape = &pool_spec.kernel_shape;
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::MaxPool {
|
||||
let (stride_h, stride_w) = if stride.len() == 1 {
|
||||
(1, stride[0])
|
||||
} else if stride.len() == 2 {
|
||||
(stride[0], stride[1])
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
|
||||
};
|
||||
|
||||
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
|
||||
(1, kernel_shape[0])
|
||||
} else if kernel_shape.len() == 2 {
|
||||
(kernel_shape[0], kernel_shape[1])
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
|
||||
};
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride: stride.to_vec(),
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
stride: (stride_h, stride_w),
|
||||
pool_dims: (kernel_height, kernel_width),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
|
||||
@@ -1152,7 +1165,7 @@ pub fn new_op_from_onnx(
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
}
|
||||
@@ -1192,7 +1205,15 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
|
||||
let stride = match conv_node.pool_spec.strides.clone() {
|
||||
Some(s) => s.to_vec(),
|
||||
Some(s) => {
|
||||
if s.len() == 1 {
|
||||
(s[0], s[0])
|
||||
} else if s.len() == 2 {
|
||||
(s[0], s[1])
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
|
||||
}
|
||||
@@ -1200,7 +1221,17 @@ pub fn new_op_from_onnx(
|
||||
|
||||
let padding = match &conv_node.pool_spec.padding {
|
||||
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
if b.len() == 2 && a.len() == 2 {
|
||||
[(b[0], b[1]), (a[0], a[1])]
|
||||
} else if b.len() == 1 && a.len() == 1 {
|
||||
[(b[0], b[0]), (a[0], a[0])]
|
||||
} else if b.len() == 1 && a.len() == 2 {
|
||||
[(b[0], b[0]), (a[0], a[1])]
|
||||
} else if b.len() == 2 && a.len() == 1 {
|
||||
[(b[0], b[1]), (a[0], a[0])]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
@@ -1255,20 +1286,33 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
|
||||
let stride = match deconv_node.pool_spec.strides.clone() {
|
||||
Some(s) => s.to_vec(),
|
||||
Some(s) => (s[0], s[1]),
|
||||
None => {
|
||||
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
|
||||
}
|
||||
};
|
||||
let padding = match &deconv_node.pool_spec.padding {
|
||||
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
if b.len() == 2 && a.len() == 2 {
|
||||
[(b[0], b[1]), (a[0], a[1])]
|
||||
} else if b.len() == 1 && a.len() == 1 {
|
||||
[(b[0], b[0]), (a[0], a[0])]
|
||||
} else if b.len() == 1 && a.len() == 2 {
|
||||
[(b[0], b[0]), (a[0], a[1])]
|
||||
} else if b.len() == 2 && a.len() == 1 {
|
||||
[(b[0], b[1]), (a[0], a[0])]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
};
|
||||
|
||||
let output_padding: (usize, usize) =
|
||||
(deconv_node.adjustments[0], deconv_node.adjustments[1]);
|
||||
|
||||
// if bias exists then rescale it to the input + kernel scale
|
||||
if input_scales.len() == 3 {
|
||||
let bias_scale = input_scales[2];
|
||||
@@ -1287,7 +1331,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
SupportedOp::Linear(PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding: deconv_node.adjustments.to_vec(),
|
||||
output_padding,
|
||||
stride,
|
||||
})
|
||||
}
|
||||
@@ -1388,17 +1432,46 @@ pub fn new_op_from_onnx(
|
||||
.ok_or(GraphError::MissingParams("stride".to_string()))?;
|
||||
let padding = match &pool_spec.padding {
|
||||
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
|
||||
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
|
||||
if b.len() == 2 && a.len() == 2 {
|
||||
[(b[0], b[1]), (a[0], a[1])]
|
||||
} else if b.len() == 1 && a.len() == 1 {
|
||||
[(b[0], b[0]), (a[0], a[0])]
|
||||
} else if b.len() == 1 && a.len() == 2 {
|
||||
[(b[0], b[0]), (a[0], a[1])]
|
||||
} else if b.len() == 2 && a.len() == 1 {
|
||||
[(b[0], b[1]), (a[0], a[0])]
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
|
||||
}
|
||||
};
|
||||
let kernel_shape = &pool_spec.kernel_shape;
|
||||
|
||||
let (stride_h, stride_w) = if stride.len() == 1 {
|
||||
(1, stride[0])
|
||||
} else if stride.len() == 2 {
|
||||
(stride[0], stride[1])
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
|
||||
};
|
||||
|
||||
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
|
||||
(1, kernel_shape[0])
|
||||
} else if kernel_shape.len() == 2 {
|
||||
(kernel_shape[0], kernel_shape[1])
|
||||
} else {
|
||||
return Err(Box::new(GraphError::MissingParams(
|
||||
"kernel shape".to_string(),
|
||||
)));
|
||||
};
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::SumPool {
|
||||
padding,
|
||||
stride: stride.to_vec(),
|
||||
kernel_shape: pool_spec.kernel_shape.to_vec(),
|
||||
stride: (stride_h, stride_w),
|
||||
kernel_shape: (kernel_height, kernel_width),
|
||||
normalized: sumpool_node.normalize,
|
||||
})
|
||||
}
|
||||
@@ -1425,7 +1498,29 @@ pub fn new_op_from_onnx(
|
||||
)));
|
||||
}
|
||||
|
||||
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
|
||||
let padding_len = pad_node.pads.len();
|
||||
|
||||
// we only support symmetrical padding that affects the last 2 dims (height and width params)
|
||||
for (i, pad_params) in pad_node.pads.iter().enumerate() {
|
||||
if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
|
||||
return Err(Box::new(GraphError::MisformedParams(
|
||||
"ezkl currently only supports padding height and width dimensions"
|
||||
.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let padding = [
|
||||
(
|
||||
pad_node.pads[padding_len - 2].0,
|
||||
pad_node.pads[padding_len - 1].0,
|
||||
),
|
||||
(
|
||||
pad_node.pads[padding_len - 2].1,
|
||||
pad_node.pads[padding_len - 1].1,
|
||||
),
|
||||
];
|
||||
SupportedOp::Linear(PolyOp::Pad(padding))
|
||||
}
|
||||
"RmAxis" | "Reshape" | "AddAxis" => {
|
||||
// Extract the slope layer hyperparams
|
||||
|
||||
14
src/lib.rs
14
src/lib.rs
@@ -23,7 +23,7 @@
|
||||
)]
|
||||
// we allow this for our dynamic range based indexing scheme
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
|
||||
#![feature(buf_read_has_data_left)]
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
@@ -200,13 +200,13 @@ pub struct RunArgs {
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Flags whether inputs are public, private, hashed, fixed, kzgcommit
|
||||
/// Flags whether inputs are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
pub input_visibility: Visibility,
|
||||
/// Flags whether outputs are public, private, fixed, hashed, kzgcommit
|
||||
/// Flags whether outputs are public, private, hashed
|
||||
#[arg(long, default_value = "public")]
|
||||
pub output_visibility: Visibility,
|
||||
/// Flags whether params are fixed, private, hashed, kzgcommit
|
||||
/// Flags whether params are public, private, hashed
|
||||
#[arg(long, default_value = "private")]
|
||||
pub param_visibility: Visibility,
|
||||
#[arg(long, default_value = "false")]
|
||||
@@ -248,12 +248,6 @@ impl Default for RunArgs {
|
||||
impl RunArgs {
|
||||
///
|
||||
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if self.param_visibility == Visibility::Public {
|
||||
return Err(
|
||||
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
return Err("scale_rebase_multiplier must be >= 1".into());
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ pub fn init_logger() {
|
||||
prefix_token(&record.level()),
|
||||
// pretty print UTC time
|
||||
chrono::Utc::now()
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.format("%Y-%m-%d %H:%M:%S:%3f")
|
||||
.to_string()
|
||||
.bright_magenta(),
|
||||
record.metadata().target(),
|
||||
|
||||
@@ -550,7 +550,8 @@ where
|
||||
+ PrimeField
|
||||
+ FromUniformBytes<64>
|
||||
+ WithSmallOrderMulGroup<3>,
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
Scheme::Curve: Serialize + DeserializeOwned + SerdeObject,
|
||||
Scheme::ParamsProver: Send + Sync,
|
||||
{
|
||||
let strategy = Strategy::new(params.verifier_params());
|
||||
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
|
||||
717
src/python.rs
717
src/python.rs
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,9 @@ use maybe_rayon::{
|
||||
slice::ParallelSliceMut,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::BufRead;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
pub use val::*;
|
||||
pub use var::*;
|
||||
|
||||
@@ -32,6 +35,7 @@ use halo2_proofs::{
|
||||
use itertools::Itertools;
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::io::Read;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
@@ -60,9 +64,12 @@ pub enum TensorError {
|
||||
/// Unsupported operation
|
||||
#[error("Unsupported operation on a tensor type")]
|
||||
Unsupported,
|
||||
/// Overflow
|
||||
#[error("Unsigned integer overflow or underflow error in op: {0}")]
|
||||
Overflow(String),
|
||||
/// File save error
|
||||
#[error("save error: {0}")]
|
||||
FileSaveError(String),
|
||||
/// File load error
|
||||
#[error("load error: {0}")]
|
||||
FileLoadError(String),
|
||||
}
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
@@ -469,6 +476,45 @@ impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
/// save to a file
|
||||
pub fn save(&self, path: &PathBuf) -> Result<(), TensorError> {
|
||||
let writer =
|
||||
std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
|
||||
let mut buf_writer = std::io::BufWriter::new(writer);
|
||||
|
||||
self.inner.iter().map(|x| x.clone()).for_each(|x| {
|
||||
let x = x.to_repr();
|
||||
buf_writer.write_all(x.as_ref()).unwrap();
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// load from a file
|
||||
pub fn load(path: &PathBuf) -> Result<Self, TensorError> {
|
||||
let reader =
|
||||
std::fs::File::open(path).map_err(|e| TensorError::FileLoadError(e.to_string()))?;
|
||||
let mut buf_reader = std::io::BufReader::new(reader);
|
||||
|
||||
let mut inner = Vec::new();
|
||||
while let Ok(true) = buf_reader.has_data_left() {
|
||||
let mut repr = T::Repr::default();
|
||||
match buf_reader.read_exact(repr.as_mut()) {
|
||||
Ok(_) => {
|
||||
inner.push(T::from_repr(repr).unwrap());
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(TensorError::FileLoadError(
|
||||
"Failed to read tensor".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Tensor::new(Some(&inner), &[inner.len()]).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Sets (copies) the tensor values to the provided ones.
|
||||
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
|
||||
@@ -937,7 +983,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
|
||||
assert!(source < self.dims.len());
|
||||
assert!(destination < self.dims.len());
|
||||
|
||||
let mut new_dims = self.dims.clone();
|
||||
new_dims.remove(source);
|
||||
new_dims.insert(destination, self.dims[source]);
|
||||
@@ -969,8 +1014,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
old_coord[source - 1] = *c;
|
||||
} else if (i < source && source < destination)
|
||||
|| (i < destination && source > destination)
|
||||
|| (i > source && source > destination)
|
||||
|| (i > destination && source < destination)
|
||||
{
|
||||
old_coord[i] = *c;
|
||||
} else if i > source && source < destination {
|
||||
@@ -983,10 +1026,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let value = self.get(&old_coord);
|
||||
|
||||
output.set(&coord, value);
|
||||
output.set(&coord, self.get(&old_coord));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
|
||||
2116
src/tensor/ops.rs
2116
src/tensor/ops.rs
File diff suppressed because it is too large
Load Diff
@@ -316,12 +316,6 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
|
||||
pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
|
||||
inner.into()
|
||||
}
|
||||
|
||||
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
|
||||
pub fn new_instance(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -879,13 +873,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
/// Calls `pad_spatial_dims` on the inner [Tensor].
|
||||
pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
|
||||
/// Calls `pad` on the inner [Tensor].
|
||||
pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = pad(v, padding, offset)?;
|
||||
*v = pad(v, padding)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
|
||||
40
src/wasm.rs
40
src/wasm.rs
@@ -1,4 +1,3 @@
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::PoseidonChip;
|
||||
use crate::circuit::modules::Module;
|
||||
@@ -148,45 +147,6 @@ pub fn floatToFelt(
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Generate a kzg commitment.
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn kzgCommit(
|
||||
message: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let message: Vec<Fr> = serde_json::from_slice(&message[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(¶ms_ser[..]);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&vk[..]);
|
||||
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
¶ms,
|
||||
);
|
||||
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
|
||||
@@ -899,7 +899,7 @@ mod native_tests {
|
||||
seq!(N in 0..=45 {
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_(test: &str) {
|
||||
fn prove_and_verify_with_overflow_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
// crate::native_tests::init_wasm();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
@@ -912,20 +912,7 @@ mod native_tests {
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_hashed_inputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
// crate::native_tests::init_wasm();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "hashed", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_fixed_params_(test: &str) {
|
||||
fn prove_and_verify_with_overflow_fixed_params_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
// crate::native_tests::init_wasm();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
|
||||
@@ -1,29 +1,21 @@
|
||||
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
|
||||
#[cfg(test)]
|
||||
mod wasm32 {
|
||||
use ezkl::circuit::modules::polycommit::PolyCommitChip;
|
||||
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use ezkl::circuit::modules::poseidon::PoseidonChip;
|
||||
use ezkl::circuit::modules::Module;
|
||||
use ezkl::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use ezkl::graph::GraphCircuit;
|
||||
use ezkl::graph::{GraphSettings, GraphWitness};
|
||||
use ezkl::graph::GraphWitness;
|
||||
use ezkl::pfsys;
|
||||
use ezkl::wasm::{
|
||||
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
|
||||
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
|
||||
kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation,
|
||||
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
|
||||
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
|
||||
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
|
||||
};
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::commitment::CommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::Bn256;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use wasm_bindgen::JsError;
|
||||
#[cfg(feature = "web")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
use wasm_bindgen_test::*;
|
||||
@@ -98,46 +90,6 @@ mod wasm32 {
|
||||
assert_eq!(calldata, reference_calldata);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn verify_kzg_commit() {
|
||||
// create a vector of field elements Vec<Fr> and assign it to the message variable
|
||||
let mut message: Vec<Fr> = vec![];
|
||||
for i in 0..32 {
|
||||
message.push(Fr::from(i as u64));
|
||||
}
|
||||
let message_ser = serde_json::to_vec(&message).unwrap();
|
||||
|
||||
let settings: GraphSettings = serde_json::from_slice(&SETTINGS).unwrap();
|
||||
let mut reader = std::io::BufReader::new(SRS);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).unwrap();
|
||||
let mut reader = std::io::BufReader::new(VK);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
settings.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let commitment_ser = kzgCommit(
|
||||
wasm_bindgen::Clamped(message_ser),
|
||||
wasm_bindgen::Clamped(VK.to_vec()),
|
||||
wasm_bindgen::Clamped(SETTINGS.to_vec()),
|
||||
wasm_bindgen::Clamped(SRS.to_vec()),
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
let commitment: Vec<halo2curves::bn256::G1Affine> =
|
||||
serde_json::from_slice(&commitment_ser[..]).unwrap();
|
||||
let reference_commitment = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
¶ms,
|
||||
);
|
||||
|
||||
assert_eq!(commitment, reference_commitment);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn verify_field_serialization_roundtrip() {
|
||||
for i in 0..32 {
|
||||
|
||||
Reference in New Issue
Block a user