Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 08:17:57 -05:00)
Compare commits: v10.3.3...ac/patch-p (15 commits)
Commits:
- 9bbc89cc89
- 28b65f2639
- 9592d38a8f
- 2cec49dfc3
- 31a1681ca4
- 134b54d32b
- beb5f12376
- 65be3c84bb
- 6f743c57d3
- ddb54c5a73
- 6e1f22a15b
- da97323bde
- 55046feeb6
- d0d0596e58
- b78efdcbf4
@@ -1,4 +1,4 @@
-name: Build and Publish EZKL Engine npm package
+name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
 
 on:
   workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
            "web/ezkl_bg.wasm",
            "web/ezkl.js",
            "web/ezkl.d.ts",
-           "web/snippets/**/*",
+           "web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
            "web/package.json",
            "web/utils.js",
            "ezkl.d.ts"
@@ -79,10 +79,6 @@ jobs:
        run: |
          sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
 
-      - name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
-        run: |
-          find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
-
      - name: Add serialize and deserialize methods to nodejs bundle
        run: |
          echo '
@@ -178,3 +174,40 @@ jobs:
          npm publish
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
+
+  in-browser-evm-ver-publish:
+    name: publish-in-browser-evm-verifier-package
+    needs: ["publish-wasm-bindings"]
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    steps:
+      - uses: actions/checkout@v4
+      - name: Update version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update @ezkljs/engine version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
+        run: |
+          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
+      - name: Set up Node.js
+        uses: actions/setup-node@v3
+        with:
+          node-version: "18.12.1"
+          registry-url: "https://registry.npmjs.org"
+      - name: Publish to npm
+        run: |
+          cd in-browser-evm-verifier
+          npm install
+          npm run build
+          npm ci
+          npm publish
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.github/workflows/pypi.yml (vendored, 15 changed lines)
@@ -128,7 +128,6 @@ jobs:
          mv Cargo.lock Cargo.lock.orig
          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
 
-
      - name: Install required libraries
        shell: bash
        run: |
@@ -360,17 +359,3 @@ jobs:
        with:
          repository-url: https://test.pypi.org/legacy/
          packages-dir: ./
-
-  doc-publish:
-    name: Trigger ReadTheDocs Build
-    runs-on: ubuntu-latest
-    needs: pypi-publish
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Trigger RTDs build
-        uses: dfm/rtds-action@v1
-        with:
-          webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
-          webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
-          commit_ref: ${{ github.ref_name }}
.github/workflows/rust.yml (vendored, 28 changed lines)
@@ -307,8 +307,8 @@ jobs:
        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
      - name: Install dependencies for js tests and in-browser-evm-verifier package
        run: |
-          pnpm install --frozen-lockfile
-          pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
+          pnpm install --no-frozen-lockfile
+          pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
        env:
          CI: false
          NODE_ENV: development
@@ -380,7 +380,7 @@ jobs:
          cache: "pnpm"
      - name: Install dependencies for js tests
        run: |
-          pnpm install --frozen-lockfile
+          pnpm install --no-frozen-lockfile
        env:
          CI: false
          NODE_ENV: development
@@ -610,24 +610,6 @@ jobs:
 
   python-integration-tests:
     runs-on: large-self-hosted
-    services:
-      # Label used to access the service container
-      postgres:
-        # Docker Hub image
-        image: postgres
-        env:
-          POSTGRES_USER: ubuntu
-          POSTGRES_HOST_AUTH_METHOD: trust
-        # Set health checks to wait until postgres has started
-        options: >-
-          --health-cmd pg_isready
-          --health-interval 10s
-          --health-timeout 5s
-          --health-retries 5
-          -v /var/run/postgresql:/var/run/postgresql
-        ports:
-          # Maps tcp port 5432 on service container to the host
-          - 5432:5432
     steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
@@ -652,8 +634,6 @@ jobs:
        run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
      - name: Build python ezkl
        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
-      - name: Postgres tutorials
-        run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
      - name: Tictactoe tutorials
        run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
      # - name: authenticate-kaggle-cli
@@ -671,3 +651,5 @@ jobs:
        run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
      - name: NBEATS tutorial
        run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
+      # - name: Postgres tutorials
+      #   run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
.github/workflows/tagging.yml (vendored, 36 changed lines)
@@ -14,40 +14,6 @@ jobs:
      - uses: actions/checkout@v4
      - name: Bump version and push tag
        id: tag_version
-        uses: mathieudutour/github-tag-action@v6.2
+        uses: mathieudutour/github-tag-action@v6.1
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Set Cargo.toml version to match github tag for docs
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
-        run: |
-          mv docs/python/src/conf.py docs/python/src/conf.py.orig
-          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
-          rm docs/python/src/conf.py.orig
-          mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
-          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
-          rm docs/python/requirements-docs.txt.orig
-
-      - name: Commit files and create tag
-        env:
-          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
-        run: |
-          git config --local user.email "github-actions[bot]@users.noreply.github.com"
-          git config --local user.name "github-actions[bot]"
-          git fetch --tags
-          git checkout -b release-$RELEASE_TAG
-          git add .
-          git commit -m "ci: update version string in docs"
-          git tag -d $RELEASE_TAG
-          git tag $RELEASE_TAG
-
-      - name: Push changes
-        uses: ad-m/github-push-action@master
-        env:
-          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
-        with:
-          branch: release-${{ steps.tag_version.outputs.new_tag }}
-          force: true
-          tags: true
.github/workflows/verify.yml (vendored, 54 changed lines; file deleted)
@@ -1,54 +0,0 @@
-name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
-
-on:
-  workflow_dispatch:
-    inputs:
-      tag:
-        description: "The tag to release"
-        required: true
-  push:
-    tags:
-      - "*"
-
-defaults:
-  run:
-    working-directory: .
-jobs:
-  in-browser-evm-ver-publish:
-    name: publish-in-browser-evm-verifier-package
-    runs-on: ubuntu-latest
-    if: startsWith(github.ref, 'refs/tags/')
-    steps:
-      - uses: actions/checkout@v4
-      - name: Update version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update @ezkljs/engine version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
-        run: |
-          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
-      - name: Use pnpm 8
-        uses: pnpm/action-setup@v2
-        with:
-          version: 8
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
-        with:
-          node-version: "18.12.1"
-          registry-url: "https://registry.npmjs.org"
-      - name: Publish to npm
-        run: |
-          cd in-browser-evm-verifier
-          pnpm install --frozen-lockfile
-          pnpm run build
-          pnpm publish
-        env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.gitignore (vendored, 1 changed line)
@@ -49,5 +49,4 @@ node_modules
 timingData.json
 !tests/wasm/pk.key
 !tests/wasm/vk.key
-docs/python/build
 !tests/wasm/vk_aggr.key
@@ -1 +0,0 @@
-3.12.1
@@ -1,26 +0,0 @@
-# .readthedocs.yaml
-# Read the Docs configuration file
-# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
-
-version: 2
-
-build:
-  os: ubuntu-22.04
-  tools:
-    python: "3.12"
-
-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-  configuration: ./docs/python/src/conf.py
-
-# Optionally build your docs in additional formats such as PDF and ePub
-# formats:
-#   - pdf
-#   - epub
-
-# Optional but recommended, declare the Python requirements required
-# to build your documentation
-# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
-python:
-  install:
-    - requirements: ./docs/python/requirements-docs.txt
Cargo.lock (generated, 164 changed lines)
@@ -644,6 +644,15 @@ dependencies = [
  "wyz",
 ]
 
+[[package]]
+name = "blake2"
+version = "0.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
+dependencies = [
+ "digest 0.10.7",
+]
+
 [[package]]
 name = "blake2b_simd"
 version = "1.0.2"
@@ -1185,6 +1194,15 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "deranged"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
+dependencies = [
+ "powerfmt",
+]
+
 [[package]]
 name = "derivative"
 version = "2.2.0"
@@ -1307,6 +1325,12 @@ version = "1.0.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
 
+[[package]]
+name = "dyn-hash"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
+
 [[package]]
 name = "ecc"
 version = "0.1.0"
@@ -1785,7 +1809,7 @@ dependencies = [
  "halo2_gadgets",
  "halo2_proofs",
  "halo2_solidity_verifier",
- "halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
+ "halo2curves 0.7.0",
  "hex",
  "indicatif",
  "instant",
@@ -1793,6 +1817,7 @@ dependencies = [
  "lazy_static",
  "log",
  "maybe-rayon",
+ "mimalloc",
  "mnist",
  "num",
  "openssl",
@@ -2176,10 +2201,11 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
 
 [[package]]
 name = "half"
-version = "2.2.1"
+version = "2.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0"
+checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
 dependencies = [
  "cfg-if",
  "crunchy",
+ "num-traits",
 ]
@@ -2204,19 +2230,23 @@ dependencies = [
 [[package]]
 name = "halo2_proofs"
 version = "0.3.0"
-source = "git+https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a"
+source = "git+https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03"
 dependencies = [
  "bincode",
  "blake2b_simd",
  "env_logger",
  "ff",
  "group",
- "halo2curves 0.6.1 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c)",
+ "halo2curves 0.7.0",
  "icicle",
  "lazy_static",
  "log",
  "maybe-rayon",
  "rand_chacha",
  "rand_core 0.6.4",
  "rustacuda",
  "rustc-hash",
  "serde",
  "sha3 0.9.1",
  "tracing",
 ]
@@ -2224,7 +2254,7 @@ dependencies = [
 [[package]]
 name = "halo2_solidity_verifier"
 version = "0.1.0"
-source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#eb04be1f7d005e5b9dd3ff41efa30aeb5e0c34a3"
+source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#3082fda94151fc6760a3cb2be4741ddbeef04c03"
 dependencies = [
  "askama",
  "blake2b_simd",
@@ -2300,15 +2330,18 @@ dependencies = [
 
 [[package]]
 name = "halo2curves"
-version = "0.6.1"
-source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=9fff22c#9fff22c5f72cc54fac1ef3a844e1072b08cfecdf"
+version = "0.7.0"
+source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
 dependencies = [
- "blake2b_simd",
+ "blake2",
+ "digest 0.10.7",
  "ff",
  "group",
+ "halo2derive",
  "hex",
  "lazy_static",
  "num-bigint",
  "num-integer",
  "num-traits",
  "pairing",
  "pasta_curves",
@@ -2318,11 +2351,25 @@ dependencies = [
  "rayon",
  "serde",
  "serde_arrays",
  "sha2",
  "static_assertions",
  "subtle",
  "unroll",
 ]
 
+[[package]]
+name = "halo2derive"
+version = "0.1.0"
+source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
+dependencies = [
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "halo2wrong"
 version = "0.1.0"
@@ -2830,6 +2877,16 @@ version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
 
+[[package]]
+name = "libmimalloc-sys"
+version = "0.1.39"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44"
+dependencies = [
+ "cc",
+ "libc",
+]
+
 [[package]]
 name = "libredox"
 version = "0.0.1"
@@ -2993,6 +3050,15 @@ dependencies = [
  "autocfg",
 ]
 
+[[package]]
+name = "mimalloc"
+version = "0.1.43"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633"
+dependencies = [
+ "libmimalloc-sys",
+]
+
 [[package]]
 name = "mime"
 version = "0.3.17"
@@ -3126,6 +3192,12 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-conv"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
+
 [[package]]
 name = "num-integer"
 version = "0.1.46"
@@ -3714,6 +3786,12 @@ dependencies = [
  "postgres-protocol",
 ]
 
+[[package]]
+name = "powerfmt"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.17"
@@ -4058,9 +4136,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
 
 [[package]]
 name = "rayon"
-version = "1.9.0"
+version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
 dependencies = [
  "either",
  "rayon-core",
@@ -4360,6 +4438,12 @@ version = "0.1.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
 
+[[package]]
+name = "rustc-hash"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
+
 [[package]]
 name = "rustc-hex"
 version = "2.1.0"
@@ -4782,11 +4866,11 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
 [[package]]
 name = "snark-verifier"
 version = "0.1.1"
-source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#574b65ea6b4d43eebac5565146519a95b435815c"
+source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
 dependencies = [
  "ecc",
  "halo2_proofs",
- "halo2curves 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "halo2curves 0.6.1",
  "hex",
  "itertools 0.10.5",
  "lazy_static",
@@ -5127,11 +5211,14 @@ dependencies = [
 
 [[package]]
 name = "time"
-version = "0.3.23"
+version = "0.3.36"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
+checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
 dependencies = [
+ "deranged",
  "itoa",
+ "num-conv",
+ "powerfmt",
  "serde",
  "time-core",
  "time-macros",
@@ -5139,16 +5226,17 @@ dependencies = [
 
 [[package]]
 name = "time-core"
-version = "0.1.1"
+version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
+checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
 
 [[package]]
 name = "time-macros"
-version = "0.2.10"
+version = "0.2.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
+checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
 dependencies = [
+ "num-conv",
  "time-core",
 ]
@@ -5398,8 +5486,8 @@ dependencies = [
 
 [[package]]
 name = "tract-core"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "anyhow",
  "bit-set",
@@ -5422,11 +5510,14 @@ dependencies = [
 
 [[package]]
 name = "tract-data"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "anyhow",
- "half 2.2.1",
+ "downcast-rs",
+ "dyn-clone",
+ "dyn-hash",
+ "half 2.4.1",
  "itertools 0.12.1",
  "lazy_static",
  "maplit",
@@ -5441,8 +5532,8 @@ dependencies = [
 
 [[package]]
 name = "tract-hir"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "derive-new",
  "log",
@@ -5451,20 +5542,23 @@ dependencies = [
 
 [[package]]
 name = "tract-linalg"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "byteorder",
  "cc",
  "derive-new",
  "downcast-rs",
  "dyn-clone",
- "half 2.2.1",
+ "dyn-hash",
+ "half 2.4.1",
  "lazy_static",
  "liquid",
  "liquid-core",
  "log",
  "num-traits",
  "paste",
  "rayon",
  "scan_fmt",
  "smallvec",
  "time",
@@ -5475,8 +5569,8 @@ dependencies = [
 
 [[package]]
 name = "tract-nnef"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "byteorder",
  "flate2",
@@ -5489,8 +5583,8 @@ dependencies = [
 
 [[package]]
 name = "tract-onnx"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "bytes",
  "derive-new",
@@ -5506,8 +5600,8 @@ dependencies = [
 
 [[package]]
 name = "tract-onnx-opl"
-version = "0.21.3"
-source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
+version = "0.21.6-pre"
+source = "git+https://github.com/sonos/tract/?rev=40c64319291184814d9fea5fdf4fa16f5a4f7116#40c64319291184814d9fea5fdf4fa16f5a4f7116"
 dependencies = [
  "getrandom",
  "log",
Cargo.toml (19 changed lines)
@@ -15,9 +15,10 @@ crate-type = ["cdylib", "rlib"]
 
 
 [dependencies]
+mimalloc = "0.1"
 halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
-halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
-halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
+halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
+halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
     "derive_serde",
 ] }
 rand = { version = "0.8", default_features = false }
@@ -32,10 +33,10 @@ log = { version = "0.4.17", default_features = false, optional = true }
 thiserror = { version = "1.0.38", default_features = false }
 hex = { version = "0.4.3", default_features = false }
 halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
-snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
+snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
     "derive_serde",
 ] }
-halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
+halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
 maybe-rayon = { version = "0.1.1", default_features = false }
 bincode = { version = "1.3.3", default_features = false }
 ark-std = { version = "^0.3.0", default-features = false }
@@ -80,7 +81,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
     "tokio-runtime",
 ], default_features = false, optional = true }
 pyo3-log = { version = "0.9.0", default_features = false, optional = true }
-tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
+tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
 tabled = { version = "0.12.0", optional = true }
 
 
@@ -175,7 +176,7 @@ required-features = ["ezkl"]
 
 [features]
 web = ["wasm-bindgen-rayon"]
-default = ["ezkl", "mv-lookup"]
+default = ["ezkl", "mv-lookup", "precompute-coset"]
 onnx = ["dep:tract-onnx"]
 python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
 ezkl = [
@@ -194,6 +195,8 @@ mv-lookup = [
     "snark-verifier/mv-lookup",
     "halo2_solidity_verifier/mv-lookup",
 ]
+asm = ["halo2curves/asm", "halo2_proofs/asm"]
+precompute-coset = ["halo2_proofs/precompute-coset"]
 det-prove = []
 icicle = ["halo2_proofs/icicle_gpu"]
 empty-cmd = []
@@ -204,8 +207,8 @@ no-banner = []
 icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
 
 [patch.'https://github.com/zkonduit/halo2']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
 
 [profile.release]
 rustflags = ["-C", "relocation-model=pic"]
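Aside: the `precompute-coset` feature added to the default set above is a plain passthrough to `halo2_proofs/precompute-coset`. As a hedged illustration of how such a cargo feature typically gates code at compile time (the function names below are hypothetical, not ezkl's):

// Hypothetical illustration of compile-time gating by a cargo feature such
// as `precompute-coset`; both paths return the same values, the feature
// only selects the strategy. Not ezkl's actual internals.
fn coset_table(n: u64, shift: u64) -> Vec<u64> {
    (0..n).map(|i| i.wrapping_mul(shift)).collect()
}

#[cfg(feature = "precompute-coset")]
fn cosets(n: u64, shift: u64) -> Vec<u64> {
    // Feature enabled: build the full table up front and reuse it.
    coset_table(n, shift)
}

#[cfg(not(feature = "precompute-coset"))]
fn cosets(n: u64, shift: u64) -> Vec<u64> {
    // Feature disabled: same values, computed on demand here.
    (0..n).map(|i| i.wrapping_mul(shift)).collect()
}

fn main() {
    assert_eq!(cosets(4, 5), vec![0, 5, 10, 15]);
}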
@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
             &mut region,
             &[self.image.clone(), self.kernel.clone(), self.bias.clone()],
             Box::new(PolyOp::Conv {
-                padding: vec![(0, 0)],
-                stride: vec![1; 2],
+                padding: [(0, 0); 2],
+                stride: (1, 1),
             }),
         )
         .unwrap();
@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
             &mut region,
             &[self.image.clone()],
             Box::new(HybridOp::SumPool {
-                padding: vec![(0, 0); 2],
-                stride: vec![1, 1],
-                kernel_shape: vec![2, 2],
+                padding: [(0, 0); 2],
+                stride: (1, 1),
+                kernel_shape: (2, 2),
                 normalized: false,
             }),
         )
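Both example circuits above move their 2D hyperparameters from heap-allocated Vecs to fixed-size arrays and tuples, so a wrong number of spatial axes becomes a compile-time error instead of a runtime one. A self-contained sketch of the shape change, using a toy enum that mirrors the new field types rather than ezkl's real ops:

// Toy mirror of the hyperparameter change in the hunks above; 2D conv and
// pooling params are rank-checked by the type system. Not ezkl's real types.
#[derive(Debug)]
enum Op2d {
    Conv {
        padding: [(usize, usize); 2],
        stride: (usize, usize),
    },
    SumPool {
        padding: [(usize, usize); 2],
        stride: (usize, usize),
        kernel_shape: (usize, usize),
        normalized: bool,
    },
}

fn main() {
    // The compiler now rejects a padding list with the wrong number of axes.
    let ops = vec![
        Op2d::Conv {
            padding: [(0, 0); 2],
            stride: (1, 1),
        },
        Op2d::SumPool {
            padding: [(0, 0); 2],
            stride: (1, 1),
            kernel_shape: (2, 2),
            normalized: false,
        },
    ];
    for op in &ops {
        println!("{op:?}");
    }
}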
@@ -1,2 +0,0 @@
-#!/bin/sh
-sphinx-build ./src build
@@ -1,4 +0,0 @@
-ezkl==10.3.3
-sphinx
-sphinx-rtd-theme
-sphinxcontrib-napoleon
@@ -1,29 +0,0 @@
-import ezkl
-
-project = 'ezkl'
-release = '10.3.3'
-version = release
-
-
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.todo',
-    'sphinx.ext.inheritance_diagram',
-    'sphinx.ext.autosectionlabel',
-    'sphinx.ext.napoleon',
-    'sphinx_rtd_theme',
-]
-
-autosummary_generate = True
-autosummary_imported_members = True
-
-templates_path = ['_templates']
-exclude_patterns = []
-
-# -- Options for HTML output -------------------------------------------------
-# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
@@ -1,11 +0,0 @@
-.. extension documentation master file, created by
-   sphinx-quickstart on Mon Jun 19 15:02:05 2023.
-   You can adapt this file completely to your liking, but it should at least
-   contain the root `toctree` directive.
-
-ezkl python bindings
-================================================
-
-.. automodule:: ezkl
-    :members:
-    :undoc-members:
@@ -203,8 +203,8 @@ where
         let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
 
         let op = PolyOp::Conv {
-            padding: vec![(PADDING, PADDING); 2],
-            stride: vec![STRIDE; 2],
+            padding: [(PADDING, PADDING); 2],
+            stride: (STRIDE, STRIDE),
         };
         let x = config
             .layer_config
@@ -7,9 +7,9 @@
     "## Mean of ERC20 transfer amounts\n",
     "\n",
     "This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
-    "The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
+    "The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
     "\n",
-    "Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
+    "Make sure you install postgres if needed https://postgresapp.com/. \n",
     "\n"
    ]
   },
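As a hedged aside, reading the indexed rows back out of Postgres is a plain SQL query. The sketch below uses Rust's `postgres` crate against the `usdc` table and `v` (value) column defined by the indexer config in the next hunk; the connection parameters and table names are assumptions taken from the notebook's defaults and may differ in your setup:

// Sketch: read indexed ERC20 transfer values back from Postgres and average
// them. Assumptions (not from this diff): the `postgres` crate is available,
// and the `usdc(v numeric, block_num numeric)` table exists as created by
// the indexer config shown in this notebook.
use postgres::{Client, NoTls};

fn main() -> Result<(), postgres::Error> {
    let mut client = Client::connect("host=localhost port=5432 user=ubuntu dbname=shovel", NoTls)?;
    // `v` is stored as numeric; cast to text and parse to avoid a decimal dependency.
    let rows = client.query("SELECT v::text FROM usdc ORDER BY block_num DESC LIMIT 20", &[])?;
    let values: Vec<f64> = rows
        .iter()
        .filter_map(|row| row.get::<_, String>(0).parse::<f64>().ok())
        .collect();
    if !values.is_empty() {
        println!("mean transfer amount: {}", values.iter().sum::<f64>() / values.len() as f64);
    }
    Ok(())
}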
@@ -21,81 +21,23 @@
    "source": [
     "import os\n",
     "import getpass\n",
-    "import json\n",
-    "import time\n",
-    "import subprocess\n",
     "\n",
     "\n",
     "# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
-    "os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
-    "os.system(\"chmod +x shovel\")\n",
+    "os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
+    "os.system(\"chmod +x e2pg\")\n",
     "\n",
     "\n",
-    "os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
-    "\n",
-    "# create a config.json file with the following contents\n",
-    "config = {\n",
-    "    \"pg_url\": \"$PG_URL\",\n",
-    "    \"eth_sources\": [\n",
-    "        {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
-    "        {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
-    "    ],\n",
-    "    \"integrations\": [{\n",
-    "        \"name\": \"usdc_transfer\",\n",
-    "        \"enabled\": True,\n",
-    "        \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
-    "        \"table\": {\n",
-    "            \"name\": \"usdc\",\n",
-    "            \"columns\": [\n",
-    "                {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
-    "                {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
-    "                {\"name\": \"f\", \"type\": \"bytea\"},\n",
-    "                {\"name\": \"t\", \"type\": \"bytea\"},\n",
-    "                {\"name\": \"v\", \"type\": \"numeric\"}\n",
-    "            ]\n",
-    "        },\n",
-    "        \"block\": [\n",
-    "            {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
-    "            {\n",
-    "                \"name\": \"log_addr\",\n",
-    "                \"column\": \"log_addr\",\n",
-    "                \"filter_op\": \"contains\",\n",
-    "                \"filter_arg\": [\n",
-    "                    \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
-    "                    \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
-    "                ]\n",
-    "            }\n",
-    "        ],\n",
-    "        \"event\": {\n",
-    "            \"name\": \"Transfer\",\n",
-    "            \"type\": \"event\",\n",
-    "            \"anonymous\": False,\n",
-    "            \"inputs\": [\n",
-    "                {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
-    "                {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
-    "                {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
-    "            ]\n",
-    "        }\n",
-    "    }]\n",
-    "}\n",
-    "\n",
-    "# write the config to a file\n",
-    "with open(\"config.json\", \"w\") as f:\n",
-    "    f.write(json.dumps(config))\n",
-    "\n",
+    "os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
+    "os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
+    "\n",
+    "# print the two env variables\n",
+    "os.system(\"echo $PG_URL\")\n",
+    "os.system(\"echo $RLPS_URL\")\n",
     "\n",
-    "os.system(\"createdb -h localhost -p 5432 shovel\")\n",
-    "\n",
-    "os.system(\"echo shovel is now installed. starting:\")\n",
-    "\n",
-    "command = [\"./shovel\", \"-config\", \"config.json\"]\n",
-    "subprocess.Popen(command)\n",
-    "\n",
-    "os.system(\"echo shovel started.\")\n",
-    "\n",
-    "time.sleep(5)\n",
+    "os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
+    "# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
+    "e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
     "\n"
    ]
   },
@@ -137,13 +79,11 @@
    "source": [
     "import json\n",
     "import os\n",
     "\n",
-    "import logging\n",
-    "# # uncomment for more descriptive logging \n",
-    "FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
-    "logging.basicConfig(format=FORMAT)\n",
-    "logging.getLogger().setLevel(logging.DEBUG)\n",
-    "\n",
-    "print(\"ezkl version: \", ezkl.__version__)"
+    "# import logging\n",
+    "# # # uncomment for more descriptive logging \n",
+    "# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
+    "# logging.basicConfig(format=FORMAT)\n",
+    "# logging.getLogger().setLevel(logging.DEBUG)"
    ]
   },
   {
@@ -236,7 +176,6 @@
   },
   "outputs": [],
   "source": [
-    "import getpass\n",
    "# make an input.json file from the df above\n",
    "input_filename = os.path.join('input.json')\n",
    "\n",
@@ -244,9 +183,9 @@
     "    \"host\": \"localhost\",\n",
     "    # make sure you replace this with your own username\n",
     "    \"user\": getpass.getuser(),\n",
-    "    \"dbname\": \"shovel\",\n",
+    "    \"dbname\": \"e2pg\",\n",
     "    \"password\": \"\",\n",
-    "    \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
+    "    \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
     "    \"port\": \"5432\",\n",
     "})\n",
     "\n",
@@ -255,7 +194,7 @@
     "\n",
     "\n",
     "    # Serialize data into file:\n",
-    "json.dump(pg_input_file, open(input_filename, 'w' ))\n"
+    "json.dump( pg_input_file, open(input_filename, 'w' ))\n"
    ]
   },
   {
@@ -271,9 +210,9 @@
     "    \"host\": \"localhost\",\n",
     "    # make sure you replace this with your own username\n",
     "    \"user\": getpass.getuser(),\n",
-    "    \"dbname\": \"shovel\",\n",
+    "    \"dbname\": \"e2pg\",\n",
     "    \"password\": \"\",\n",
-    "    \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
+    "    \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
     "    \"port\": \"5432\",\n",
     "})\n",
     "\n",
@@ -290,6 +229,22 @@
    "**EZKL Workflow**"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "onnx_filename = os.path.join('lol.onnx')\n",
+    "compiled_filename = os.path.join('lol.compiled')\n",
+    "settings_filename = os.path.join('settings.json')\n",
+    "\n",
+    "ezkl.gen_settings(onnx_filename, settings_filename)\n",
+    "\n",
+    "ezkl.calibrate_settings(\n",
+    "    input_filename, onnx_filename, settings_filename, \"resources\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -298,21 +253,10 @@
   },
   "outputs": [],
   "source": [
-    "import subprocess\n",
-    "import os\n",
    "# setup kzg params\n",
    "params_path = os.path.join('kzg.params')\n",
    "\n",
-    "onnx_filename = os.path.join('lol.onnx')\n",
-    "compiled_filename = os.path.join('lol.compiled')\n",
-    "settings_filename = os.path.join('settings.json')\n",
-    "\n",
-    "# Generate settings using ezkl\n",
-    "res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
-    "\n",
-    "assert res == True\n",
-    "\n",
-    "res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
-    "\n",
-    "assert res == True"
+    "res = ezkl.get_srs(params_path, settings_filename)"
    ]
   },
   {
@@ -362,13 +306,16 @@
    "source": [
     "pk_path = os.path.join('test.pk')\n",
     "vk_path = os.path.join('test.vk')\n",
+    "params_path = os.path.join('kzg.params')\n",
     "\n",
     "\n",
     "# setup the proof\n",
     "res = ezkl.setup(\n",
     "    compiled_filename,\n",
     "    vk_path,\n",
-    "    pk_path\n",
+    "    pk_path,\n",
+    "    params_path,\n",
+    "    settings_filename,\n",
     "    )\n",
     "\n",
     "assert res == True\n",
@@ -384,14 +331,11 @@
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "witness_path = \"witness.json\"\n",
    "\n",
    "# generate the witness\n",
-    "res = ezkl.gen_witness(\n",
-    "    input_filename,\n",
-    "    compiled_filename,\n",
-    "    witness_path\n",
-    "    )\n"
+    "res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
+    "assert os.path.isfile(witness_path)"
    ]
   },
   {
@@ -416,14 +360,73 @@
     "    compiled_filename,\n",
     "    pk_path,\n",
     "    proof_path,\n",
-    "    \"single\"\n",
+    "    params_path,\n",
+    "    \"single\",\n",
     "    )\n",
     "\n",
     "\n",
     "print(\"proved\")\n",
     "\n",
     "assert os.path.isfile(proof_path)\n",
-    "\n"
+    "\n",
+    "# verify\n",
+    "res = ezkl.verify(\n",
+    "    proof_path,\n",
+    "    settings_filename,\n",
+    "    vk_path,\n",
+    "    params_path,\n",
+    "    )\n",
+    "\n",
+    "assert res == True\n",
+    "print(\"verified\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "W7tAa-DFAtvS"
+   },
+   "source": [
+    "# Part 2 (Using the ZK Computational Graph Onchain!)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "8Ym91kaVAIB6"
+   },
+   "source": [
+    "**Now How Do We Do It Onchain?????**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 339
+    },
+    "id": "fodkNgwS70FM",
+    "outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
+   },
+   "outputs": [],
+   "source": [
+    "# first we need to create evm verifier\n",
+    "print(vk_path)\n",
+    "print(params_path)\n",
+    "print(settings_filename)\n",
+    "\n",
+    "\n",
+    "abi_path = 'test.abi'\n",
+    "sol_code_path = 'test.sol'\n",
+    "\n",
+    "res = ezkl.create_evm_verifier(\n",
+    "    vk_path,\n",
+    "    params_path,\n",
+    "    settings_filename,\n",
+    "    sol_code_path,\n",
+    "    abi_path,\n",
+    "    )\n",
+    "assert res == True"
+   ]
+  },
   {
@@ -432,8 +435,51 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# kill all shovel process \n",
-    "os.system(\"pkill -f shovel\")"
+    "# Make sure anvil is running locally first\n",
+    "# run with $ anvil -p 3030\n",
+    "# we use the default anvil node here\n",
+    "import json\n",
+    "\n",
+    "address_path = os.path.join(\"address.json\")\n",
+    "\n",
+    "res = ezkl.deploy_evm(\n",
+    "    address_path,\n",
+    "    sol_code_path,\n",
+    "    'http://127.0.0.1:3030'\n",
+    ")\n",
+    "\n",
+    "assert res == True\n",
+    "\n",
+    "with open(address_path, 'r') as file:\n",
+    "    addr = file.read().rstrip()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# read the address from addr_path\n",
+    "addr = None\n",
+    "with open(address_path, 'r') as f:\n",
+    "    addr = f.read()\n",
+    "\n",
+    "res = ezkl.verify_evm(\n",
+    "    addr,\n",
+    "    proof_path,\n",
+    "    \"http://127.0.0.1:3030\"\n",
+    ")\n",
+    "assert res == True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "os.system(\"killall -9 e2pg\");"
+   ]
+  }
  ],
@@ -455,7 +501,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.2"
+   "version": "3.9.13"
   }
  },
 "nbformat": 4,
@@ -20,16 +20,16 @@
     "build": "npm run clean && npm run build:commonjs && npm run build:esm"
   },
   "dependencies": {
-    "@ethereumjs/common": "4.0.0",
-    "@ethereumjs/evm": "2.0.0",
-    "@ethereumjs/statemanager": "2.0.0",
-    "@ethereumjs/tx": "5.0.0",
-    "@ethereumjs/util": "9.0.0",
-    "@ethereumjs/vm": "7.0.0",
-    "@ethersproject/abi": "5.7.0",
+    "@ethereumjs/common": "^4.0.0",
+    "@ethereumjs/evm": "^2.0.0",
+    "@ethereumjs/statemanager": "^2.0.0",
+    "@ethereumjs/tx": "^5.0.0",
+    "@ethereumjs/util": "^9.0.0",
+    "@ethereumjs/vm": "^7.0.0",
+    "@ethersproject/abi": "^5.7.0",
     "@ezkljs/engine": "^9.4.4",
-    "ethers": "6.7.1",
-    "json-bigint": "1.0.0"
+    "ethers": "^6.7.1",
+    "json-bigint": "^1.0.0"
   },
   "devDependencies": {
     "@types/node": "^20.8.3",
in-browser-evm-verifier/pnpm-lock.yaml (generated, 18 changed lines)
@@ -6,34 +6,34 @@ settings:
 
 dependencies:
   '@ethereumjs/common':
-    specifier: 4.0.0
+    specifier: ^4.0.0
     version: 4.0.0
   '@ethereumjs/evm':
-    specifier: 2.0.0
+    specifier: ^2.0.0
     version: 2.0.0
   '@ethereumjs/statemanager':
-    specifier: 2.0.0
+    specifier: ^2.0.0
     version: 2.0.0
   '@ethereumjs/tx':
-    specifier: 5.0.0
+    specifier: ^5.0.0
     version: 5.0.0
   '@ethereumjs/util':
-    specifier: 9.0.0
+    specifier: ^9.0.0
     version: 9.0.0
   '@ethereumjs/vm':
-    specifier: 7.0.0
+    specifier: ^7.0.0
     version: 7.0.0
   '@ethersproject/abi':
-    specifier: 5.7.0
+    specifier: ^5.7.0
    version: 5.7.0
   '@ezkljs/engine':
    specifier: ^9.4.4
    version: 9.4.4
   ethers:
-    specifier: 6.7.1
+    specifier: ^6.7.1
    version: 6.7.1
   json-bigint:
-    specifier: 1.0.0
+    specifier: ^1.0.0
    version: 1.0.0
 
 devDependencies:
@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; then
     exit 1
 fi
 
-if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
+if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
     # Add the ezkl directory to the path and ensure the old PATH variables remain.
     echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
 fi
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["maturin>=1.0,<2.0"]
+requires = ["maturin>=0.14,<0.15"]
 build-backend = "maturin"
 
 [tool.pytest.ini_options]
@@ -2,7 +2,7 @@ attrs==23.2.0
 exceptiongroup==1.2.0
 importlib-metadata==7.1.0
 iniconfig==2.0.0
-maturin==1.5.1
+maturin==1.5.0
 packaging==24.0
 pluggy==1.4.0
 pytest==8.1.1
@@ -11,4 +11,4 @@ typing-extensions==4.10.0
 zipp==3.18.1
 onnx==1.15.0
 onnxruntime==1.17.1
-numpy==1.26.4
+numpy==1.26.4
@@ -1,4 +1,6 @@
 // ignore file if compiling for wasm
+#[global_allocator]
+static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 
 #[cfg(not(target_arch = "wasm32"))]
 use clap::Parser;
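The hunk above registers mimalloc as the process-wide allocator. A minimal, self-contained sketch of the same pattern for any binary crate that depends on `mimalloc = "0.1"` (as the Cargo.toml hunk earlier adds):

// Minimal sketch: route all heap allocations through mimalloc.
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

fn main() {
    // Every Box/Vec/String allocation below now goes through mimalloc.
    let v: Vec<u64> = (0..1_000_000).collect();
    println!("allocated {} elements via mimalloc", v.len());
}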
@@ -956,6 +956,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
         values: &[ValTensor<F>],
         op: Box<dyn Op<F>>,
     ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
-        op.layout(self, region, values)
+        let res = op.layout(self, region, values)?;
+
+        if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
+            if let Some(claimed_output) = &res {
+                // during key generation this will be unknown vals so we use this as a flag to check
+                let mut is_assigned = !claimed_output.any_unknowns()?;
+                for val in values.iter() {
+                    is_assigned = is_assigned && !val.any_unknowns()?;
+                }
+                if is_assigned {
+                    op.safe_mode_check(claimed_output, values)?;
+                }
+            }
+        };
+        Ok(res)
     }
 }
@@ -1,9 +1,9 @@
 use super::*;
 use crate::{
     circuit::{layouts, utils, Tolerance},
-    fieldutils::i128_to_felt,
+    fieldutils::{felt_to_i128, i128_to_felt},
     graph::multiplier_to_scale,
-    tensor::{self, Tensor, TensorType, ValTensor},
+    tensor::{self, Tensor, TensorError, TensorType, ValTensor},
 };
 use halo2curves::ff::PrimeField;
 use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
         dim: usize,
     },
     SumPool {
-        padding: Vec<(usize, usize)>,
-        stride: Vec<usize>,
-        kernel_shape: Vec<usize>,
+        padding: [(usize, usize); 2],
+        stride: (usize, usize),
+        kernel_shape: (usize, usize),
         normalized: bool,
     },
-    MaxPool {
-        padding: Vec<(usize, usize)>,
-        stride: Vec<usize>,
-        pool_dims: Vec<usize>,
+    MaxPool2d {
+        padding: [(usize, usize); 2],
+        stride: (usize, usize),
+        pool_dims: (usize, usize),
     },
     ReduceMin {
         axes: Vec<usize>,
@@ -85,6 +85,93 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
     fn as_any(&self) -> &dyn Any {
         self
     }
+    /// Matches a [Op] to an operation in the `tensor::ops` module.
+    fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
+        let x = inputs[0].clone().map(|x| felt_to_i128(x));
+
+        let res = match &self {
+            HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
+            HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
+            HybridOp::Div { denom, .. } => {
+                crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
+            }
+            HybridOp::Recip {
+                input_scale,
+                output_scale,
+                ..
+            } => crate::tensor::ops::nonlinearities::recip(
+                &x,
+                input_scale.0 as f64,
+                output_scale.0 as f64,
+            ),
+            HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
+            HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
+            HybridOp::Gather { dim, constant_idx } => {
+                if let Some(idx) = constant_idx {
+                    tensor::ops::gather(&x, idx, *dim)?
+                } else {
+                    let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                    tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
+                }
+            }
+            HybridOp::OneHot { dim, num_classes } => {
+                tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
+            }
+
+            HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
+            HybridOp::MaxPool2d {
+                padding,
+                stride,
+                pool_dims,
+                ..
+            } => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
+            HybridOp::SumPool {
+                padding,
+                stride,
+                kernel_shape,
+                normalized,
+            } => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
+            HybridOp::Softmax {
+                input_scale,
+                output_scale,
+                axes,
+            } => tensor::ops::nonlinearities::softmax_axes(
+                &x,
+                input_scale.into(),
+                output_scale.into(),
+                axes,
+            ),
+            HybridOp::RangeCheck(tol) => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
+            }
+            HybridOp::Greater => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::greater(&x, &y)?
+            }
+            HybridOp::GreaterEqual => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::greater_equal(&x, &y)?
+            }
+            HybridOp::Less => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::less(&x, &y)?
+            }
+            HybridOp::LessEqual => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::less_equal(&x, &y)?
+            }
+            HybridOp::Equals => {
+                let y = inputs[1].clone().map(|x| felt_to_i128(x));
+                tensor::ops::equals(&x, &y)?
+            }
+        };
+
+        // convert back to felt
+        let output = res.map(|x| i128_to_felt(x));
+
+        Ok(ForwardResult { output })
+    }
 
     fn as_string(&self) -> String {
         match self {
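The forward pass added above follows a single shape: decode field elements to signed i128 integers, run the plain tensor op, then re-encode. A hedged, standalone sketch of that roundtrip with toy conversions (ezkl's real ones live in `fieldutils` and account for negative values relative to the field modulus):

// Toy version of the felt -> i128 -> felt roundtrip used by HybridOp::f.
// `Felt` is a stand-in type; the modulus handling is omitted for brevity.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Felt(u128);

fn felt_to_i128(f: Felt) -> i128 {
    f.0 as i128
}

fn i128_to_felt(x: i128) -> Felt {
    Felt(x as u128)
}

/// Run an integer-domain op on field-encoded inputs and re-encode the result.
fn forward(inputs: &[Felt], op: impl Fn(&[i128]) -> Vec<i128>) -> Vec<Felt> {
    let ints: Vec<i128> = inputs.iter().copied().map(felt_to_i128).collect();
    op(&ints).into_iter().map(i128_to_felt).collect()
}

fn main() {
    // e.g. an elementwise "greater than zero" style op in the integer domain
    let out = forward(&[Felt(3), Felt(0)], |xs| {
        xs.iter().map(|&x| (x > 0) as i128).collect()
    });
    assert_eq!(out, vec![Felt(1), Felt(0)]);
}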
@@ -114,12 +201,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
             ),
             HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
             HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
-            HybridOp::MaxPool {
+            HybridOp::MaxPool2d {
                 padding,
                 stride,
                 pool_dims,
             } => format!(
-                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
+                "MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
                 padding, stride, pool_dims
             ),
             HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
@@ -166,9 +253,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
                 config,
                 region,
                 values[..].try_into()?,
-                padding,
-                stride,
-                kernel_shape,
+                *padding,
+                *stride,
+                *kernel_shape,
                 *normalized,
             )?,
             HybridOp::Recip {
@@ -228,17 +315,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
                 }
             }
 
-            HybridOp::MaxPool {
+            HybridOp::MaxPool2d {
                 padding,
                 stride,
                 pool_dims,
-            } => layouts::max_pool(
+            } => layouts::max_pool2d(
                 config,
                 region,
                 values[..].try_into()?,
-                padding,
-                stride,
-                pool_dims,
+                *padding,
+                *stride,
+                *pool_dims,
             )?,
             HybridOp::ReduceMax { axes } => {
                 layouts::max_axes(config, region, values[..].try_into()?, axes)?
(File diff suppressed because it is too large.)
@@ -136,11 +136,61 @@ impl LookupOp {
         (-range, range)
     }
 
+    /// as path
+    pub fn as_path(&self) -> String {
+        match self {
+            LookupOp::Abs => "abs".into(),
+            LookupOp::Ceil { scale } => format!("ceil_{}", scale),
+            LookupOp::Floor { scale } => format!("floor_{}", scale),
+            LookupOp::Round { scale } => format!("round_{}", scale),
+            LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
+            LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
+            LookupOp::KroneckerDelta => "kronecker_delta".into(),
+            LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
+            LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
+            LookupOp::Sign => "sign".into(),
+            LookupOp::LessThan { a } => format!("less_than_{}", a),
+            LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
+            LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
+            LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
+            LookupOp::Div { denom } => format!("div_{}", denom),
+            LookupOp::Cast { scale } => format!("cast_{}", scale),
+            LookupOp::Recip {
+                input_scale,
+                output_scale,
+            } => format!("recip_{}_{}", input_scale, output_scale),
+            LookupOp::ReLU => "relu".to_string(),
+            LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
+            LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
+            LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
+            LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
+            LookupOp::Erf { scale } => format!("erf_{}", scale),
+            LookupOp::Exp { scale } => format!("exp_{}", scale),
+            LookupOp::Ln { scale } => format!("ln_{}", scale),
+            LookupOp::Cos { scale } => format!("cos_{}", scale),
+            LookupOp::ACos { scale } => format!("acos_{}", scale),
+            LookupOp::Cosh { scale } => format!("cosh_{}", scale),
+            LookupOp::ACosh { scale } => format!("acosh_{}", scale),
+            LookupOp::Sin { scale } => format!("sin_{}", scale),
+            LookupOp::ASin { scale } => format!("asin_{}", scale),
+            LookupOp::Sinh { scale } => format!("sinh_{}", scale),
+            LookupOp::ASinh { scale } => format!("asinh_{}", scale),
+            LookupOp::Tan { scale } => format!("tan_{}", scale),
+            LookupOp::ATan { scale } => format!("atan_{}", scale),
+            LookupOp::ATanh { scale } => format!("atanh_{}", scale),
+            LookupOp::Tanh { scale } => format!("tanh_{}", scale),
+            LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
+        }
+    }
+}
+
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
+    /// Returns a reference to the Any trait.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
     /// Matches a [Op] to an operation in the `tensor::ops` module.
-    pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
-        &self,
-        x: &[Tensor<F>],
-    ) -> Result<ForwardResult<F>, TensorError> {
+    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
         let x = x[0].clone().map(|x| felt_to_i128(x));
         let res = match &self {
             LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
@@ -232,13 +282,6 @@ impl LookupOp {
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
|
||||
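`as_path` gives every lookup op a deterministic, parameter-encoding name; the table cache further down in this changeset uses it as a directory key. An illustrative sketch (the concrete variant fields and their `Display` output are assumptions, not confirmed by the diff):

```rust
// Illustrative only: the naming scheme, not a verified call signature.
fn cache_key(op: &LookupOp) -> String {
    op.as_path() // e.g. "sigmoid_128", "recip_4_8", "abs"
}
```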
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};

use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;

@@ -35,6 +35,8 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
fn as_string(&self) -> String;

@@ -69,6 +71,33 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:

/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any;

/// Safe mode output check
fn safe_mode_check(
&self,
claimed_output: &ValTensor<F>,
original_values: &[ValTensor<F>],
) -> Result<(), TensorError> {
let felt_evals = original_values
.iter()
.map(|v| {
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
evals.reshape(v.dims())?;
Ok(evals)
})
.collect::<Result<Vec<_>, _>>()?;

let ref_op: Tensor<F> = self.f(&felt_evals)?.output;

let mut output = claimed_output
.get_felt_evals()
.map_err(|_| TensorError::FeltError)?;
output.reshape(claimed_output.dims())?;

assert_eq!(output, ref_op);

Ok(())
}
}

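`safe_mode_check` recomputes the op over the felt evaluations of its inputs and asserts the claimed circuit output matches. A hedged usage sketch; the surrounding names (`op`, `claimed`, `inputs`) and the `CheckMode::SAFE` gate are assumptions about the call site, not shown in the diff:

```rust
// Hypothetical call-site sketch: in safe mode, cross-check the witnessed
// output of a laid-out op against a cleartext re-execution of Op::f.
if check_mode == CheckMode::SAFE {
    op.safe_mode_check(&claimed, &inputs)?; // panics (assert_eq!) on mismatch
}
```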
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
@@ -147,6 +176,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
self
}

fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
})
}

fn as_string(&self) -> String {
"Input".into()
}
@@ -200,6 +235,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}

fn as_string(&self) -> String {
"Unknown".into()
@@ -269,6 +307,11 @@ impl<
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();

Ok(ForwardResult { output })
}

fn as_string(&self) -> String {
format!("CONST (scale={})", self.quantized_values.scale().unwrap())

@@ -1,5 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};

@@ -31,8 +32,8 @@ pub enum PolyOp {
equation: String,
},
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
padding: [(usize, usize); 2],
stride: (usize, usize),
},
Downsample {
axis: usize,
@@ -40,9 +41,9 @@ pub enum PolyOp {
modulo: usize,
},
DeConv {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
padding: [(usize, usize); 2],
output_padding: (usize, usize),
stride: (usize, usize),
},
Add,
Sub,
@@ -57,13 +58,10 @@ pub enum PolyOp {
destination: usize,
},
Flatten(Vec<usize>),
Pad(Vec<(usize, usize)>),
Pad([(usize, usize); 2]),
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
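Note the type migration running through this file: `Conv`, `DeConv`, and `Pad` drop `Vec`-shaped hyperparameters for fixed-size arrays and tuples, committing these ops to two spatial dimensions and allowing the `*padding` / `*stride` copies seen later in the layout code. A construction sketch mirroring the updated circuit tests further down:

```rust
// Sketch: the updated Conv variant as used in the tests below.
let conv = PolyOp::Conv {
    padding: [(1, 1); 2], // fixed-size: exactly two (usize, usize) pairs
    stride: (2, 2),
};
```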
@@ -107,28 +105,10 @@ impl<

fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -140,26 +120,15 @@ impl<
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
PolyOp::Pad(_) => "PAD".into(),
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -173,6 +142,146 @@ impl<
}
}

/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
"multibroadcastto inputs".to_string(),
));
}
inputs[0].expand(shape)
}
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
PolyOp::Not => tensor::ops::not(&inputs[0]),
PolyOp::Downsample {
axis,
stride,
modulo,
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::MoveAxis {
source,
destination,
} => inputs[0].move_axis(*source, *destination),
PolyOp::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::Pad(p) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pad inputs".to_string()));
}
tensor::ops::pad(&inputs[0], *p)
}
PolyOp::Add => tensor::ops::add(&inputs),
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult => tensor::ops::mult(&inputs),
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
PolyOp::DeConv {
padding,
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
inputs[0].pow(*u)
}
PolyOp::Sum { axes } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
tensor::ops::sum_axes(&inputs[0], axes)
}
PolyOp::Prod { axes, .. } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("prod inputs".to_string()));
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
PolyOp::Slice { axis, start, end } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();

let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};

let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}

PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;

Ok(ForwardResult { output: res })
}

fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
@@ -183,9 +292,6 @@ impl<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -212,7 +318,7 @@ impl<
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -264,9 +370,9 @@ impl<
config,
region,
values[..].try_into()?,
padding,
output_padding,
stride,
*padding,
*output_padding,
*stride,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -282,7 +388,7 @@ impl<
)));
}
let mut input = values[0].clone();
input.pad(p.clone(), 0)?;
input.pad(*p)?;
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -298,7 +404,6 @@ impl<

fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

@@ -17,6 +17,8 @@ use crate::{

use crate::circuit::lookup::LookupOp;

use super::Op;

/// The range of the lookup table.
pub type Range = (i128, i128);

@@ -25,6 +27,13 @@ pub const RANGE_MULTIPLIER: i128 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;

lazy_static::lazy_static! {
/// an optional directory to read and write the lookup table cache
static ref LOOKUP_CACHE: Option<std::path::PathBuf> = std::env::var("LOOKUP_CACHE")
.ok()
.map(std::path::PathBuf::from);
}

#[derive(Debug, Clone)]
///
pub struct SelectorConstructor<F: PrimeField> {
@@ -111,10 +120,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
.unwrap();
let op_f = Op::<F>::f(
&self.nonlinearity,
&[Tensor::from(vec![first_element].into_iter())],
)
.unwrap();
(first_element, op_f.output[0])
}

@@ -202,8 +212,46 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;

let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
.par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(i128_to_felt(x)))?;
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
Ok((inputs, evals.output))
};

let (inputs, evals) = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.nonlinearity.as_path());
let input_path = cache_path.join("inputs");
let output_path = cache_path.join("outputs");
if cache_path.exists() {
log::info!("Loading lookup table from cache: {:?}", cache_path);
let (input_cache, output_cache) =
(Tensor::load(&input_path)?, Tensor::load(&output_path)?);
(input_cache, output_cache)
} else {
log::info!(
"Generating lookup table and saving to cache: {:?}",
cache_path
);

// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;

let (inputs, evals) = gen_table()?;
inputs.save(&input_path)?;
evals.save(&output_path)?;

(inputs, evals)
}
} else {
log::info!(
"Generating lookup table {} without cache",
self.nonlinearity.as_path()
);

gen_table()?
};

let chunked_inputs = inputs.chunks(self.col_size);

self.is_assigned = true;
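Combined with `as_path`, the cache layout is one directory per op under the `LOOKUP_CACHE` root, with serialized `inputs`/`outputs` tensors inside. A hedged sketch of enabling it (the path is illustrative):

```rust
// Sketch: opt in to the on-disk lookup-table cache. The diff reads the
// LOOKUP_CACHE env var once via lazy_static, so set it before first use.
std::env::set_var("LOOKUP_CACHE", "/tmp/ezkl-lookup-cache");
// Expected layout after assignment, per the code above:
//   /tmp/ezkl-lookup-cache/<op as_path>/inputs
//   /tmp/ezkl-lookup-cache/<op as_path>/outputs
//   /tmp/ezkl-lookup-cache/rangecheck_<lo>_<hi>/inputs   (range checks, below)
```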
@@ -235,7 +283,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
)?;
}

let output = evals.output[row_offset];
let output = evals[row_offset];

table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
@@ -273,6 +321,11 @@ pub struct RangeCheck<F: PrimeField> {
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// as path
pub fn as_path(&self) -> String {
format!("rangecheck_{}_{}", self.range.0, self.range.1)
}

/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -350,7 +403,31 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;

let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.as_path());
let input_path = cache_path.join("inputs");
if cache_path.exists() {
log::info!("Loading range check table from cache: {:?}", cache_path);
Tensor::load(&input_path)?
} else {
log::info!(
"Generating range check table and saving to cache: {:?}",
cache_path
);

// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;

let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
inputs.save(&input_path)?;
inputs
}
} else {
log::info!("Generating range check {} without cache", self.as_path());

Tensor::from(smallest..=largest).map(|x| i128_to_felt(x))
};

let chunked_inputs = inputs.chunks(self.col_size);

self.is_assigned = true;

@@ -1048,8 +1048,8 @@ mod conv {
&mut region,
&self.inputs,
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis);

@@ -455,7 +455,7 @@ pub async fn verify_proof_with_data_attestation(

for val in flattened_instances.clone() {
let bytes = val.to_repr();
let u = U256::from_little_endian(bytes.as_slice());
let u = U256::from_little_endian(bytes.inner());
public_inputs.push(u);
}

@@ -196,6 +196,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
srs_path,
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
@@ -636,7 +637,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
Ok(String::new())
}

pub(crate) fn gen_witness(
pub(crate) async fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
output: Option<PathBuf>,
@@ -659,7 +660,7 @@ pub(crate) fn gen_witness(
};

#[cfg(not(target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data)?;
let mut input = circuit.load_graph_input(&data).await?;
#[cfg(target_arch = "wasm32")]
let mut input = circuit.load_graph_input(&data)?;

@@ -21,6 +21,8 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
@@ -232,15 +234,21 @@ impl PostgresSource {
)
};

let mut client = Client::connect(&config, NoTls)?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[])? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
let mut client = Client::connect(&config, NoTls).unwrap();
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).unwrap() {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
}
}
res
})
.join()
.map_err(|_| "failed to fetch data from postgres")?;

Ok(vec![res])
}

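Since `gen_witness` is now `async`, the blocking `postgres` client above is moved onto a dedicated OS thread instead of being called directly from the async context. The pattern in miniature (illustrative names; the diff's version returns its rows through `join()` the same way):

```rust
// Sketch of the blocking-work-on-a-thread pattern used above: run a
// synchronous client off the async executor and join for the result.
use std::thread;

fn fetch_blocking(query: String) -> Result<Vec<String>, &'static str> {
    thread::spawn(move || {
        // ... connect with a blocking client and collect rows ...
        vec![query] // placeholder result
    })
    .join()
    .map_err(|_| "failed to fetch data from postgres")
}
```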
@@ -918,7 +918,7 @@ impl GraphCircuit {

///
#[cfg(not(target_arch = "wasm32"))]
pub fn load_graph_input(
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
@@ -928,6 +928,7 @@ impl GraphCircuit {
debug!("input scales: {:?}", scales);

self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}

#[cfg(target_arch = "wasm32")]
@@ -951,7 +952,7 @@ impl GraphCircuit {

#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
async fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
@@ -964,16 +965,8 @@ impl GraphCircuit {
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}

// start runtime and fetch data
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;

runtime.block_on(async {
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
})
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)

@@ -1200,20 +1200,6 @@ impl Model {
.collect();

for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);

let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1225,11 +1211,25 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());

debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!(
"input dims {:?}",
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);

match &node {
NodeType::Node(n) => {

@@ -14,6 +14,7 @@ use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
@@ -60,6 +61,20 @@ impl Op<Fp> for Rescaled {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
if self.scale.len() != x.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}

let mut rescaled_inputs = vec![];
let inputs = &mut x.to_vec();
for (i, ri) in inputs.iter_mut().enumerate() {
let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
let res = (ri.clone() * mult_tensor)?;
rescaled_inputs.push(res);
}
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
}

fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
@@ -200,6 +215,13 @@ impl Op<Fp> for RebaseScale {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;

Ok(res)
}

fn as_string(&self) -> String {
format!(
@@ -367,6 +389,13 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}

impl Op<Fp> for SupportedOp {
fn f(
&self,
inputs: &[Tensor<Fp>],
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
self.as_op().f(inputs)
}

fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,

@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,19 +734,6 @@ pub fn new_op_from_onnx(

SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();

SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}

"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1119,7 +1106,17 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1127,10 +1124,26 @@ pub fn new_op_from_onnx(
};
let kernel_shape = &pool_spec.kernel_shape;

SupportedOp::Hybrid(HybridOp::MaxPool {
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};

let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
};

SupportedOp::Hybrid(HybridOp::MaxPool2d {
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
stride: (stride_h, stride_w),
pool_dims: (kernel_height, kernel_width),
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
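The repeated `b.len()`/`a.len()` ladder above normalizes tract's before/after padding vectors into the fixed `[(usize, usize); 2]` form; the same logic recurs for Conv, DeConv, and SumPool below. A standalone sketch of it as a helper (hypothetical; the diff inlines the ladder each time):

```rust
// Hypothetical helper equivalent to the inlined ladder: pads one-element
// before/after vectors out to two dimensions, mirroring the diff's cases.
fn normalize_padding(b: &[usize], a: &[usize]) -> Option<[(usize, usize); 2]> {
    match (b.len(), a.len()) {
        (2, 2) => Some([(b[0], b[1]), (a[0], a[1])]),
        (1, 1) => Some([(b[0], b[0]), (a[0], a[0])]),
        (1, 2) => Some([(b[0], b[0]), (a[0], a[1])]),
        (2, 1) => Some([(b[0], b[1]), (a[0], a[0])]),
        _ => None, // callers map this to GraphError::MissingParams("padding")
    }
}
```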
@@ -1152,7 +1165,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
deleted_indices.push(inputs.len() - 1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}
@@ -1192,7 +1205,15 @@ pub fn new_op_from_onnx(
}

let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
Some(s) => {
if s.len() == 1 {
(s[0], s[0])
} else if s.len() == 2 {
(s[0], s[1])
} else {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
}
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
@@ -1200,7 +1221,17 @@ pub fn new_op_from_onnx(

let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1255,20 +1286,33 @@ pub fn new_op_from_onnx(
}

let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
Some(s) => (s[0], s[1]),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};

let output_padding: (usize, usize) =
(deconv_node.adjustments[0], deconv_node.adjustments[1]);

// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1287,7 +1331,7 @@ pub fn new_op_from_onnx(

SupportedOp::Linear(PolyOp::DeConv {
padding,
output_padding: deconv_node.adjustments.to_vec(),
output_padding,
stride,
})
}
@@ -1388,17 +1432,46 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let kernel_shape = &pool_spec.kernel_shape;

let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};

let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams(
"kernel shape".to_string(),
)));
};

SupportedOp::Hybrid(HybridOp::SumPool {
padding,
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
stride: (stride_h, stride_w),
kernel_shape: (kernel_height, kernel_width),
normalized: sumpool_node.normalize,
})
}
@@ -1425,7 +1498,29 @@ pub fn new_op_from_onnx(
)));
}

SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
let padding_len = pad_node.pads.len();

// we only support symmetrical padding that affects the last 2 dims (height and width params)
for (i, pad_params) in pad_node.pads.iter().enumerate() {
if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
return Err(Box::new(GraphError::MisformedParams(
"ezkl currently only supports padding height and width dimensions"
.to_string(),
)));
}
}

let padding = [
(
pad_node.pads[padding_len - 2].0,
pad_node.pads[padding_len - 1].0,
),
(
pad_node.pads[padding_len - 2].1,
pad_node.pads[padding_len - 1].1,
),
];
SupportedOp::Linear(PolyOp::Pad(padding))
}
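As a concrete check of the index arithmetic above: for ONNX pads over a 4-D tensor such as `[(0, 0), (0, 0), (1, 1), (2, 2)]`, the last two entries are the (height, width) pads, so the circuit padding becomes `[(1, 2), (1, 2)]`. A worked example with illustrative values:

```rust
// Worked example of the pad conversion above (values are illustrative).
let pads: Vec<(usize, usize)> = vec![(0, 0), (0, 0), (1, 1), (2, 2)];
let n = pads.len();
let padding = [
    (pads[n - 2].0, pads[n - 1].0), // (pad_top, pad_left)
    (pads[n - 2].1, pads[n - 1].1), // (pad_bottom, pad_right)
];
assert_eq!(padding, [(1, 2), (1, 2)]);
```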
"RmAxis" | "Reshape" | "AddAxis" => {
|
||||
// Extract the slope layer hyperparams
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
)]
|
||||
// we allow this for our dynamic range based indexing scheme
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
|
||||
#![feature(buf_read_has_data_left)]
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ pub fn init_logger() {
|
||||
prefix_token(&record.level()),
|
||||
// pretty print UTC time
|
||||
chrono::Utc::now()
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.format("%Y-%m-%d %H:%M:%S:%3f")
|
||||
.to_string()
|
||||
.bright_magenta(),
|
||||
record.metadata().target(),
|
||||
|
||||
@@ -550,7 +550,8 @@ where
|
||||
+ PrimeField
|
||||
+ FromUniformBytes<64>
|
||||
+ WithSmallOrderMulGroup<3>,
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
Scheme::Curve: Serialize + DeserializeOwned + SerdeObject,
|
||||
Scheme::ParamsProver: Send + Sync,
|
||||
{
|
||||
let strategy = Strategy::new(params.verifier_params());
|
||||
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
|
||||
688 src/python.rs
@@ -33,7 +33,6 @@ use tokio::runtime::Runtime;

type PyFelt = String;

/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
enum PyTestDataSource {
@@ -52,17 +51,14 @@ impl From<PyTestDataSource> for TestDataSource {
}
}

/// pyclass containing the struct used for G1, this is mostly a helper class
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
x: PyFelt,
#[pyo3(get, set)]
/// Field Element representing y
y: PyFelt,
/// Field Element representing y
#[pyo3(get, set)]
z: PyFelt,
}
@@ -138,59 +134,39 @@ impl pyo3::ToPyObject for PyG1Affine {
}
}

/// Python class containing the struct used for run_args
///
/// Returns
/// -------
/// PyRunArgs
///
/// pyclass containing the struct used for run_args
#[pyclass]
#[derive(Clone)]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
/// list[int]: The min and max elements in the lookup table input column
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
/// int: The log_2 number of rows
pub logrows: u32,
#[pyo3(get, set)]
/// int: The number of inner columns used for the lookup table
pub num_inner_cols: usize,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub input_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub output_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub param_visibility: Visibility,
#[pyo3(get, set)]
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
}

@@ -250,7 +226,7 @@ impl Into<PyRunArgs> for RunArgs {

#[pyclass]
#[derive(Debug, Clone)]
/// pyclass representing an enum, denoting the type of commitment
/// Pyclass marking the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
@@ -297,19 +273,7 @@ impl FromStr for PyCommitments {
}
}

/// Converts a field element hex string to big endian
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
///
/// Returns
/// -------
/// str
/// field element represented as a string
///
/// Converts a felt to big endian
#[pyfunction(signature = (
felt,
))]
@@ -319,45 +283,22 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
}

/// Converts a field element hex string to an integer
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// Returns
/// -------
/// int
///
#[pyfunction(signature = (
felt,
array,
))]
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}

/// Converts a field element hex string to a floating point number
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// scale: float
/// The scaling factor used to convert the field element into a floating point representation
///
/// Returns
/// -------
/// float
///
/// Converts a field element hex string to a floating point number
#[pyfunction(signature = (
felt,
array,
scale
))]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
@@ -365,23 +306,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
}

/// Converts a floating point element to a field element hex string
///
/// Arguments
/// -------
/// input: float
/// The field element represented as a string
///
/// scale: float
/// The scaling factor used to quantize the float into a field element
///
/// Returns
/// -------
/// str
/// The field element represented as a string
///
#[pyfunction(signature = (
input,
scale
input,
scale
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
@@ -391,20 +318,9 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
}

/// Converts a buffer to vector of field elements
///
/// Arguments
/// -------
/// buffer: list[int]
/// List of integers representing a buffer
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
buffer
))]
))]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -463,20 +379,9 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
}

/// Generate a poseidon hash.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
message,
))]
))]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -497,31 +402,12 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
}

/// Generate a kzg commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -556,31 +442,12 @@ fn kzg_commit(
}

/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -615,19 +482,10 @@ fn ipa_commit(
}

/// Swap the commitments in a proof
///
/// Arguments
/// -------
/// proof_path: str
/// Path to the proof file
///
/// witness_path: str
/// Path to the witness file
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -636,27 +494,11 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
}

/// Generates a vk from a pk for a model circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// circuit_settings_path: str
/// Path to the witness file
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -678,22 +520,10 @@ fn gen_vk_from_pk_single(
}

/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -708,17 +538,6 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
}

/// Displays the table as a string in python
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// Returns
/// ---------
/// str
/// Table of the nodes in the onnx file
///
#[pyfunction(signature = (
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
@@ -734,16 +553,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
}
}

/// Generates the Structured Reference String (SRS), use this only for testing purposes
///
/// Arguments
/// ---------
/// srs_path: str
/// Path to the create the SRS file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// generates the srs
#[pyfunction(signature = (
srs_path,
logrows,
@@ -754,26 +564,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
Ok(())
}

/// Gets a public srs
///
/// Arguments
/// ---------
/// settings_path: str
/// Path to the settings file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// srs_path: str
/// Path to the create the SRS file
///
/// commitment: str
/// Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
/// bool
///
/// gets a public srs
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
@@ -806,23 +597,7 @@ fn get_srs(
Ok(true)
}

/// Generates the circuit settings
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the settings file
///
/// py_run_args: PyRunArgs
/// PyRunArgs object to initialize the settings
///
/// Returns
/// -------
/// bool
///
/// generates the circuit settings
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
@@ -843,38 +618,7 @@ fn gen_settings(
Ok(true)
}

/// Calibrates the circuit settings
///
/// Arguments
/// ---------
/// data: str
/// Path to the calibration data
///
/// model: str
/// Path to the onnx file
///
/// settings: str
/// Path to the settings file
///
/// lookup_safety_margin: int
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
///
/// scales: list[int]
/// Optional scales to specifically try for calibration
///
/// scale_rebase_multiplier: list[int]
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
///
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
/// bool
///
/// calibrates the circuit settings
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
@@ -916,30 +660,7 @@ fn calibrate_settings(
|
||||
Ok(true)
|
||||
}

/// Runs the forward pass operation to generate a witness
///
/// Arguments
/// ---------
/// data: str
///     Path to the data file
///
/// model: str
///     Path to the compiled model file
///
/// output: str
///     Path to create the witness file
///
/// vk_path: str
///     Path to the verification key
///
/// srs_path: str
///     Path to the SRS file
///
/// Returns
/// -------
/// dict
///     Python object containing the witness values
///
/// runs the forward pass operation
#[pyfunction(signature = (
    data=PathBuf::from(DEFAULT_DATA),
    model=PathBuf::from(DEFAULT_MODEL),
@@ -954,28 +675,19 @@ fn gen_witness(
    vk_path: Option<PathBuf>,
    srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
    let output =
        crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
    let output = Runtime::new()
        .unwrap()
        .block_on(crate::execute::gen_witness(
            model, data, output, vk_path, srs_path,
        ))
        .map_err(|e| {
            let err_str = format!("Failed to run generate witness: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;
    Python::with_gil(|py| Ok(output.to_object(py)))
}
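
From Python the async wrapping above is invisible; the call blocks and returns the witness. A hedged sketch (paths illustrative):

```python
import ezkl

# Run the forward pass through the compiled circuit; the witness is returned
# as a dict and also written to the output path.
witness = ezkl.gen_witness(
    data="input.json",
    model="network.compiled",
    output="witness.json",
)
```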

/// Mocks the prover
///
/// Arguments
/// ---------
/// witness: str
///     Path to the witness file
///
/// model: str
///     Path to the compiled model file
///
/// Returns
/// -------
/// bool
///
/// mocks the prover
#[pyfunction(signature = (
    witness=PathBuf::from(DEFAULT_WITNESS),
    model=PathBuf::from(DEFAULT_MODEL),
@@ -988,23 +700,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
    Ok(true)
}
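
A quick sanity check, as sketched below (paths illustrative):

```python
import ezkl

# Check the witness against the circuit constraints without a real proof.
assert ezkl.mock(witness="witness.json", model="network.compiled")
```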

/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the relevant proof files
///
/// logrows: int
///     Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
///     Indicates whether the accumulated proofs are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
/// mocks the aggregate prover
#[pyfunction(signature = (
    aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
@@ -1023,32 +719,7 @@ fn mock_aggregate(
    Ok(true)
}

/// Runs the setup process
///
/// Arguments
/// ---------
/// model: str
///     Path to the compiled model file
///
/// vk_path: str
///     Path to create the verification key file
///
/// pk_path: str
///     Path to create the proving key file
///
/// srs_path: str
///     Path to the SRS file
///
/// witness_path: str
///     Path to the witness file
///
/// disable_selector_compression: bool
///     Whether to disable selector compression
///
/// Returns
/// -------
/// bool
///
/// runs the setup process
#[pyfunction(signature = (
    model=PathBuf::from(DEFAULT_MODEL),
    vk_path=PathBuf::from(DEFAULT_VK),
@@ -1081,32 +752,7 @@ fn setup(
    Ok(true)
}
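
A sketch of key generation with the keyword names from the docstring (paths illustrative):

```python
import ezkl

# Produce the proving and verification keys for the compiled circuit.
ezkl.setup(
    model="network.compiled",
    vk_path="vk.key",
    pk_path="pk.key",
    srs_path="kzg17.srs",
)
```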

/// Runs the prover on a set of inputs
///
/// Arguments
/// ---------
/// witness: str
///     Path to the witness file
///
/// model: str
///     Path to the compiled model file
///
/// pk_path: str
///     Path to the proving key file
///
/// proof_path: str
///     Path to create the proof file
///
/// proof_type: str
///     Accepts `single`, `for-aggr`
///
/// srs_path: str
///     Path to the SRS file
///
/// Returns
/// -------
/// dict
///     Python object containing the proof
///
/// runs the prover on a set of inputs
#[pyfunction(signature = (
    witness=PathBuf::from(DEFAULT_WITNESS),
    model=PathBuf::from(DEFAULT_MODEL),
@@ -1140,29 +786,7 @@ fn prove(
    Python::with_gil(|py| Ok(snark.to_object(py)))
}
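
A hedged proving sketch (paths illustrative; the returned object mirrors what is written to `proof_path`):

```python
import ezkl

# Generate a proof from the witness and proving key.
snark = ezkl.prove(
    witness="witness.json",
    model="network.compiled",
    pk_path="pk.key",
    proof_path="proof.json",
    srs_path="kzg17.srs",
)
```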

/// Verifies a given proof
///
/// Arguments
/// ---------
/// proof_path: str
///     Path to the proof file
///
/// settings_path: str
///     Path to the settings file
///
/// vk_path: str
///     Path to the verification key file
///
/// srs_path: str
///     Path to the SRS file
///
/// non_reduced_srs: bool
///     Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// Returns
/// -------
/// bool
///
/// verifies a given proof
#[pyfunction(signature = (
    proof_path=PathBuf::from(DEFAULT_PROOF),
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1192,38 +816,6 @@ fn verify(
    Ok(true)
}
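
The matching verification call, sketched with the same illustrative paths:

```python
import ezkl

# Verify the proof against the settings and verification key.
assert ezkl.verify(
    proof_path="proof.json",
    settings_path="settings.json",
    vk_path="vk.key",
    srs_path="kzg17.srs",
)
```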

/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
///     List of paths to the various proofs
///
/// vk_path: str
///     Path to create the aggregated VK
///
/// pk_path: str
///     Path to create the aggregated PK
///
/// logrows: int
///     Number of logrows to use
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger proof
///
/// srs_path: str
///     Path to the SRS file
///
/// disable_selector_compression: bool
///     Whether to disable selector compression
///
/// commitment: str
///     Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
@@ -1262,23 +854,6 @@ fn setup_aggregate(
    Ok(true)
}

/// Compiles the circuit for use in other steps
///
/// Arguments
/// ---------
/// model: str
///     Path to the onnx model file
///
/// compiled_circuit: str
///     Path to output the compiled circuit
///
/// settings_path: str
///     Path to the settings file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    model=PathBuf::from(DEFAULT_MODEL),
    compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
@@ -1297,41 +872,7 @@ fn compile_circuit(
    Ok(true)
}
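
A sketch of the compilation step that precedes witness generation and setup (paths illustrative):

```python
import ezkl

# Lower the ONNX model into ezkl's internal circuit representation.
ezkl.compile_circuit(
    model="network.onnx",
    compiled_circuit="network.compiled",
    settings_path="settings.json",
)
```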

/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the various proofs
///
/// proof_path: str
///     Path to output the aggregated proof
///
/// vk_path: str
///     Path to the VK file
///
/// transcript: str
///     Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows: int
///     Logrows used for aggregation circuit
///
/// check_mode: str
///     Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
///     Path to the SRS used
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
/// creates an aggregated proof
#[pyfunction(signature = (
    aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
@@ -1374,32 +915,7 @@ fn aggregate(
    Ok(true)
}

/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
///     The path to the proof file
///
/// vk_path: str
///     The path to the verification key file
///
/// logrows: int
///     Logrows used for aggregation circuit
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
///     Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
///     The path to the SRS file
///
/// Returns
/// -------
/// bool
///
/// verifies an aggregate proof
#[pyfunction(signature = (
    proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
    vk_path=PathBuf::from(DEFAULT_VK),
@@ -1432,32 +948,7 @@ fn verify_aggr(
    Ok(true)
}
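
Putting `setup_aggregate`, `aggregate`, and `verify_aggr` together, a hedged end-to-end sketch; paths and logrows are illustrative, keyword names follow the docstrings above, and the input proofs are assumed to have been generated with `proof_type="for-aggr"`:

```python
import ezkl

snarks = ["proof_0.json", "proof_1.json"]  # proofs generated with proof_type="for-aggr"

# Keys for the aggregation circuit, then the aggregated proof, then verification.
ezkl.setup_aggregate(sample_snarks=snarks, vk_path="vk_aggr.key",
                     pk_path="pk_aggr.key", logrows=23, srs_path="kzg23.srs")
ezkl.aggregate(aggregation_snarks=snarks, proof_path="proof_aggr.json",
               vk_path="vk_aggr.key", logrows=23, srs_path="kzg23.srs")
assert ezkl.verify_aggr(proof_path="proof_aggr.json", vk_path="vk_aggr.key",
                        logrows=23, srs_path="kzg23.srs")
```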

/// Creates an EVM compatible verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// vk_path: str
///     The path to the verification key file
///
/// settings_path: str
///     The path to the settings file
///
/// sol_code_path: str
///     The path to create the Solidity verifier
///
/// abi_path: str
///     The path to create the ABI for the Solidity verifier
///
/// srs_path: str
///     The path to the SRS file
///
/// render_vk_separately: bool
///     Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
/// creates an EVM compatible verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
    vk_path=PathBuf::from(DEFAULT_VK),
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1490,26 +981,7 @@ fn create_evm_verifier(
    Ok(true)
}
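
A sketch of rendering the on-chain verifier (paths illustrative; requires solc on the PATH as noted above):

```python
import ezkl

# Render a Solidity verifier contract and its ABI for the circuit.
ezkl.create_evm_verifier(
    vk_path="vk.key",
    settings_path="settings.json",
    sol_code_path="Verifier.sol",
    abi_path="Verifier.abi",
    srs_path="kzg17.srs",
)
```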

/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
///     The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
///     The path to the settings file
///
/// sol_code_path: str
///     The path to create the Solidity verifier
///
/// abi_path: str
///     The path to create the ABI for the Solidity verifier
///
/// Returns
/// -------
/// bool
///
/// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
    input_data=PathBuf::from(DEFAULT_DATA),
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1531,32 +1003,6 @@ fn create_evm_data_attestation(
    Ok(true)
}

/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
///     The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
///     The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
///     For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_source: str
///     Where the input data comes from
///
/// output_source: str
///     Where the output data comes from
///
/// rpc_url: str
///     RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    data_path,
    compiled_circuit_path,
@@ -1591,7 +1037,6 @@ fn setup_test_evm_witness(
    Ok(true)
}

/// deploys the solidity verifier
#[pyfunction(signature = (
    addr_path,
    sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
@@ -1624,7 +1069,6 @@ fn deploy_evm(
    Ok(true)
}

/// deploys the solidity vk verifier
#[pyfunction(signature = (
    addr_path,
    sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
@@ -1657,7 +1101,6 @@ fn deploy_vk_evm(
    Ok(true)
}

/// deploys the solidity da verifier
#[pyfunction(signature = (
    addr_path,
    input_data,
@@ -1695,27 +1138,6 @@ fn deploy_da_evm(
    Ok(true)
}
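
A hedged deployment sketch for the plain verifier; `addr_path` and `sol_code_path` appear in the signature above, while the `rpc_url` keyword and endpoint are assumptions mirroring the other EVM helpers in this file:

```python
import ezkl

# Deploy the rendered verifier to a local node and record its address.
ezkl.deploy_evm(
    addr_path="verifier_addr.txt",
    sol_code_path="Verifier.sol",
    rpc_url="http://localhost:8545",  # assumed keyword; Anvil endpoint illustrative
)
```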
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// addr_verifier: str
///     The verifier contract's address
///
/// proof_path: str
///     The path to the proof file (generated using the prove command)
///
/// rpc_url: str
///     RPC URL for an Ethereum node; if None, will use Anvil but WON'T persist state
///
/// addr_da: str
///     The address of the data attestation contract, if the verifier uses data attestation
///
/// addr_vk: str
///     The address of the separately deployed verification key contract, if one was rendered
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    addr_verifier,
    proof_path=PathBuf::from(DEFAULT_PROOF),
@@ -1761,35 +1183,7 @@ fn verify_evm(
    Ok(true)
}
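
A sketch of on-chain verification against the address recorded by the deploy step (paths illustrative; `addr_da` and `addr_vk` stay unset unless those contracts were deployed separately):

```python
import ezkl

# Read back the address written by deploy_evm, then verify on-chain.
with open("verifier_addr.txt") as f:
    addr = f.read().strip()

assert ezkl.verify_evm(addr_verifier=addr, proof_path="proof.json")
```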

/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
///     Path to the settings file
///
/// vk_path: str
///     The path to load the desired verification key file
///
/// sol_code_path: str
///     The path to the Solidity code
///
/// abi_path: str
///     The path to output the Solidity verifier ABI
///
/// logrows: int
///     Number of logrows used during aggregated setup
///
/// srs_path: str
///     The path to the SRS file
///
/// render_vk_separately: bool
///     Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
#[pyfunction(signature = (
    aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
    vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),

src/tensor/mod.rs
@@ -14,6 +14,9 @@ use maybe_rayon::{
    slice::ParallelSliceMut,
};
use serde::{Deserialize, Serialize};
use std::io::BufRead;
use std::io::Write;
use std::path::PathBuf;
pub use val::*;
pub use var::*;

@@ -32,6 +35,7 @@ use halo2_proofs::{
use itertools::Itertools;
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
@@ -60,9 +64,12 @@ pub enum TensorError {
    /// Unsupported operation
    #[error("Unsupported operation on a tensor type")]
    Unsupported,
    /// Overflow
    #[error("Unsigned integer overflow or underflow error in op: {0}")]
    Overflow(String),
    /// File save error
    #[error("save error: {0}")]
    FileSaveError(String),
    /// File load error
    #[error("load error: {0}")]
    FileLoadError(String),
}

/// The (inner) type of tensor elements.
@@ -469,6 +476,45 @@ impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
    }
}

impl<T: Clone + TensorType + PrimeField> Tensor<T> {
    /// save to a file
    pub fn save(&self, path: &PathBuf) -> Result<(), TensorError> {
        let writer =
            std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
        let mut buf_writer = std::io::BufWriter::new(writer);

        self.inner.iter().map(|x| x.clone()).for_each(|x| {
            let x = x.to_repr();
            buf_writer.write_all(x.as_ref()).unwrap();
        });

        Ok(())
    }

    /// load from a file
    pub fn load(path: &PathBuf) -> Result<Self, TensorError> {
        let reader =
            std::fs::File::open(path).map_err(|e| TensorError::FileLoadError(e.to_string()))?;
        let mut buf_reader = std::io::BufReader::new(reader);

        let mut inner = Vec::new();
        while let Ok(true) = buf_reader.has_data_left() {
            let mut repr = T::Repr::default();
            match buf_reader.read_exact(repr.as_mut()) {
                Ok(_) => {
                    inner.push(T::from_repr(repr).unwrap());
                }
                Err(_) => {
                    return Err(TensorError::FileLoadError(
                        "Failed to read tensor".to_string(),
                    ))
                }
            }
        }
        Ok(Tensor::new(Some(&inner), &[inner.len()]).unwrap())
    }
}

impl<T: Clone + TensorType> Tensor<T> {
    /// Sets (copies) the tensor values to the provided ones.
    pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -937,7 +983,6 @@ impl<T: Clone + TensorType> Tensor<T> {
    pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
        assert!(source < self.dims.len());
        assert!(destination < self.dims.len());

        let mut new_dims = self.dims.clone();
        new_dims.remove(source);
        new_dims.insert(destination, self.dims[source]);
@@ -969,8 +1014,6 @@ impl<T: Clone + TensorType> Tensor<T> {
                old_coord[source - 1] = *c;
            } else if (i < source && source < destination)
                || (i < destination && source > destination)
                || (i > source && source > destination)
                || (i > destination && source < destination)
            {
                old_coord[i] = *c;
            } else if i > source && source < destination {
@@ -983,10 +1026,7 @@ impl<T: Clone + TensorType> Tensor<T> {
                    ));
                }
            }

            let value = self.get(&old_coord);

            output.set(&coord, value);
            output.set(&coord, self.get(&old_coord));
        }

        Ok(output)

2116 src/tensor/ops.rs
File diff suppressed because it is too large

src/tensor/val.rs
@@ -316,12 +316,6 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
    /// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
    pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
        let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
        inner.into()
    }

    /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
    pub fn new_instance(
        cs: &mut ConstraintSystem<F>,
@@ -879,13 +873,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
        };
        Ok(())
    }
    /// Calls `pad_spatial_dims` on the inner [Tensor].
    pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
    /// Calls `pad` on the inner [Tensor].
    pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
        match self {
            ValTensor::Value {
                inner: v, dims: d, ..
            } => {
                *v = pad(v, padding, offset)?;
                *v = pad(v, padding)?;
                *d = v.dims().to_vec();
            }
            ValTensor::Instance { .. } => {