Compare commits

...

26 Commits

Author SHA1 Message Date
github-actions[bot]
7498aaf69d ci: update version string in docs 2024-05-03 00:25:32 +00:00
dante
749e0ba652 chore: update h2 solidity verifier (#787) 2024-05-03 01:25:14 +01:00
dante
d464ddf6b6 chore: medium sized lstm example (#785) 2024-05-01 16:35:11 +01:00
dante
8f6c0aced5 chore: update tract (#784) 2024-04-30 13:31:33 +01:00
dante
860e9700a8 refactor!: swap integer rep to i64 from i128 (#781)
BREAKING CHANGE: may break w/ old compiled circuits
2024-04-26 16:16:55 -04:00
Ethan Cemer
32dd4a854f fix: patch npm package build failure (#782) 2024-04-25 10:38:06 -04:00
dante
924f7c0420 fix: simplify kzg-commit (#780) 2024-04-24 11:57:20 -04:00
dante
ae03b6515b fix: update vis settings on help (#779) 2024-04-22 16:23:19 -04:00
Ethan Cemer
bae2e9e22b feat: kzgCommit wasm method (#778) 2024-04-18 19:22:11 -04:00
dante
4a93d31869 fix: accommodate modules in col-overflow (#777) 2024-04-18 17:13:31 -04:00
dante
88dd83dbe5 fix: default compiled model paths in python (#776) 2024-04-15 12:01:21 -04:00
Ethan Cemer
f05f83481e chore: update eth postgres (#769)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-04-13 08:08:09 -04:00
Ethan Cemer
8aaf518b5e fix: fix @ezkljs/verify ethereumjs deps (#765) 2024-04-12 18:24:59 -04:00
katsumata
1b7b43e073 fix: Improve EZKL installation script reliability (#774) 2024-04-09 16:07:39 -04:00
dante
f78618ec59 feat: full ND conv and pool (#770) 2024-04-06 23:29:30 +01:00
Jseam
0943e534ee docs: automated sphinx documentation for python bindings (#714)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
Co-authored-by: Ethan Cemer <tylercemer@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-04-05 18:33:06 +01:00
dante
316a9a3b40 chore: update tract (#766) 2024-04-04 18:07:08 +01:00
dante
5389012b68 fix: patch large batch ex (#763) 2024-04-03 02:33:57 +01:00
dante
48223cca11 fix: make commitment optional for backwards compat (#762) 2024-04-03 02:26:50 +01:00
dante
32c3a5e159 fix: hold stacked outputs in a separate map 2024-04-02 21:37:20 +01:00
dante
ff563e93a7 fix: bump python version (#761) 2024-04-02 17:08:26 +01:00
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
dante
7fe179b8d4 feat: dictionary of reusable constants (#754) 2024-03-26 13:12:09 +00:00
Ethan Cemer
3be988a6a0 fix: use pnpm in build script for in-browser-evm-verifier (#752) 2024-03-25 23:23:02 +00:00
86 changed files with 8410 additions and 4159 deletions


@@ -1,4 +1,4 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
name: Build and Publish EZKL Engine npm package
on:
workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
@@ -79,6 +79,10 @@ jobs:
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
- name: Add serialize and deserialize methods to nodejs bundle
run: |
echo '
@@ -174,40 +178,3 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}


@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag


@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -128,6 +128,7 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -139,6 +140,20 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
args: --release --out dist --features python-bindings
before-script-linux: |
# If we're running on rhel centos, install needed packages.
if command -v yum &> /dev/null; then
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
# If we're running on i686 we need to symlink libatomic
# in order to build openssl with -latomic flag.
if [[ ! -d "/usr/lib64" ]]; then
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
fi
else
# If we're running on debian-based system.
apt update -y && apt-get install -y libssl-dev openssl pkg-config
fi
- name: Install built wheel
if: matrix.target == 'x86_64'
@@ -162,7 +177,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -214,7 +229,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -249,7 +264,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -273,7 +288,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -345,3 +360,17 @@ jobs:
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}


@@ -184,7 +184,7 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -207,7 +207,7 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -224,7 +224,7 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -281,7 +281,7 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -307,8 +307,8 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
pnpm install --frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -354,7 +354,7 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -380,7 +380,7 @@ jobs:
cache: "pnpm"
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
pnpm install --frozen-lockfile
env:
CI: false
NODE_ENV: development
@@ -394,14 +394,18 @@ jobs:
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (hashed inputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
- name: IPA prove and verify tests
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests single inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
- name: KZG prove and verify tests triple inner col
@@ -412,8 +416,6 @@ jobs:
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public inputs)
@@ -460,7 +462,7 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -495,7 +497,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -512,7 +514,7 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -545,8 +547,6 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: Download MNIST
run: sh data.sh
- name: Examples
run: cargo nextest run --release tests_examples
@@ -557,16 +557,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
@@ -576,12 +578,12 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -592,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -608,11 +610,29 @@ jobs:
python-integration-tests:
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -626,10 +646,16 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +671,3 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1


@@ -14,6 +14,40 @@ jobs:
- uses: actions/checkout@v4
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.1
uses: mathieudutour/github-tag-action@v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- name: Set Cargo.toml version to match github tag for docs
shell: bash
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
mv docs/python/src/conf.py docs/python/src/conf.py.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
rm docs/python/src/conf.py.orig
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
rm docs/python/requirements-docs.txt.orig
- name: Commit files and create tag
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git fetch --tags
git checkout -b release-$RELEASE_TAG
git add .
git commit -m "ci: update version string in docs"
git tag -d $RELEASE_TAG
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:
branch: release-${{ steps.tag_version.outputs.new_tag }}
force: true
tags: true

.github/workflows/verify.yml (vendored, new file, 65 lines changed)

@@ -0,0 +1,65 @@
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name#v }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Fetch integrity
run: |
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@${{ github.ref_name#v }} dist.integrity)
echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
- name: Update pnpm-lock.yaml versions and integrity
run: |
awk -v integrity="$ENGINE_INTEGRITY" -v tag="${{ github.ref_name#v }}" '
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

.gitignore (vendored, 5 lines changed)

@@ -1,6 +1,5 @@
target
pkg
data
*.csv
!examples/notebooks/eth_price.csv
*.ipynb_checkpoints
@@ -48,4 +47,6 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key

.python-version (new file, 1 line changed)

@@ -0,0 +1 @@
3.12.1

.readthedocs.yaml (new file, 26 lines changed)

@@ -0,0 +1,26 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: ./docs/python/src/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: ./docs/python/requirements-docs.txt

Cargo.lock (generated, 146 lines changed)

@@ -655,6 +655,12 @@ dependencies = [
"constant_time_eq",
]
[[package]]
name = "block"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
[[package]]
name = "block-buffer"
version = "0.9.0"
@@ -1011,6 +1017,17 @@ version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "core-graphics-types"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"libc",
]
[[package]]
name = "cpufeatures"
version = "0.2.12"
@@ -1307,6 +1324,12 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
[[package]]
name = "dyn-hash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
[[package]]
name = "ecc"
version = "0.1.0"
@@ -1793,8 +1816,10 @@ dependencies = [
"lazy_static",
"log",
"maybe-rayon",
"metal",
"mnist",
"num",
"objc",
"openssl",
"pg_bigdecimal",
"plotters",
@@ -1932,7 +1957,28 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
"foreign-types-shared 0.1.1",
]
[[package]]
name = "foreign-types"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [
"foreign-types-macros",
"foreign-types-shared 0.3.1",
]
[[package]]
name = "foreign-types-macros"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
]
[[package]]
@@ -1941,6 +1987,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "foreign-types-shared"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
@@ -2224,7 +2276,7 @@ dependencies = [
[[package]]
name = "halo2_solidity_verifier"
version = "0.1.0"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#eb04be1f7d005e5b9dd3ff41efa30aeb5e0c34a3"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#fd74f1da2ce51664e2d4349965987ee606551060"
dependencies = [
"askama",
"blake2b_simd",
@@ -2699,6 +2751,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.10"
@@ -2924,6 +2985,15 @@ dependencies = [
"subtle",
]
[[package]]
name = "malloc_buf"
version = "0.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb"
dependencies = [
"libc",
]
[[package]]
name = "maplit"
version = "1.0.2"
@@ -2968,9 +3038,9 @@ checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
[[package]]
name = "memmap2"
version = "0.5.10"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327"
checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322"
dependencies = [
"libc",
]
@@ -2984,6 +3054,20 @@ dependencies = [
"autocfg",
]
[[package]]
name = "metal"
version = "0.27.0"
source = "git+https://github.com/gfx-rs/metal-rs#ff8fd3d6dc7792852f8a015458d7e6d42d7fb352"
dependencies = [
"bitflags 2.5.0",
"block",
"core-graphics-types",
"foreign-types 0.5.0",
"log",
"objc",
"paste",
]
[[package]]
name = "mime"
version = "0.3.17"
@@ -3196,6 +3280,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "objc"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1"
dependencies = [
"malloc_buf",
]
[[package]]
name = "object"
version = "0.32.2"
@@ -3256,7 +3349,7 @@ checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [
"bitflags 2.5.0",
"cfg-if",
"foreign-types",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
@@ -4601,9 +4694,9 @@ dependencies = [
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",
@@ -4845,12 +4938,12 @@ checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
[[package]]
name = "string-interner"
version = "0.14.0"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648"
checksum = "07f9fdfdd31a0ff38b59deb401be81b73913d76c9cc5b1aed4e1330a223420b9"
dependencies = [
"cfg-if",
"hashbrown 0.11.2",
"hashbrown 0.14.3",
"serde",
]
@@ -5389,8 +5482,8 @@ dependencies = [
[[package]]
name = "tract-core"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"anyhow",
"bit-set",
@@ -5413,12 +5506,14 @@ dependencies = [
[[package]]
name = "tract-data"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"anyhow",
"downcast-rs",
"dyn-hash",
"half 2.2.1",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"maplit",
"ndarray",
@@ -5432,8 +5527,8 @@ dependencies = [
[[package]]
name = "tract-hir"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"derive-new",
"log",
@@ -5442,13 +5537,14 @@ dependencies = [
[[package]]
name = "tract-linalg"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"cc",
"derive-new",
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half 2.2.1",
"lazy_static",
"liquid",
@@ -5466,8 +5562,8 @@ dependencies = [
[[package]]
name = "tract-nnef"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"byteorder",
"flate2",
@@ -5480,8 +5576,8 @@ dependencies = [
[[package]]
name = "tract-onnx"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"bytes",
"derive-new",
@@ -5497,8 +5593,8 @@ dependencies = [
[[package]]
name = "tract-onnx-opl"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
dependencies = [
"getrandom",
"log",


@@ -43,6 +43,7 @@ unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
# evm related deps
@@ -80,9 +81,11 @@ pyo3-asyncio = { version = "0.20.0", features = [
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true }
@@ -95,10 +98,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
@@ -198,6 +201,7 @@ det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
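
The Cargo.toml hunk above wires up an optional Metal backend: new `metal` and `objc` dependencies (matching the new Cargo.lock entries earlier in this compare) behind a `metal = ["dep:metal", "dep:objc"]` feature. A minimal sketch of consuming such a gate, assuming only the feature wiring shown in the diff; the imported item and function are illustrative:

    // Compiled only when built with: cargo build --features metal
    #[cfg(feature = "metal")]
    use metal::Device;

    #[cfg(feature = "metal")]
    fn default_gpu() -> Option<Device> {
        // Picks the system default Metal device, if any.
        Device::system_default()
    }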


@@ -20,9 +20,9 @@
"name": "quantize_data",
"outputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"stateMutability": "pure",
@@ -31,9 +31,9 @@
{
"inputs": [
{
"internalType": "int128[]",
"internalType": "int64[]",
"name": "quantized_data",
"type": "int128[]"
"type": "int64[]"
}
],
"name": "to_field_element",


@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: [(0, 0); 2],
stride: (1, 1),
padding: vec![(0, 0)],
stride: vec![1; 2],
}),
)
.unwrap();
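
This hunk tracks #770 (full ND conv and pool): `PolyOp::Conv` now takes per-spatial-axis vectors for padding and stride instead of fixed two-dimensional values. A hedged sketch of what that buys, extrapolating the same struct to a 3-D convolution; the axis count and values are illustrative, not from the diff:

    // One (before, after) padding pair and one stride per spatial axis,
    // so higher-rank convolutions reuse the same two fields.
    let op = Box::new(PolyOp::Conv {
        padding: vec![(1, 1); 3], // pad all 3 spatial axes by 1 on each side
        stride: vec![2; 3],       // stride 2 along each spatial axis
    });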


@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: [(0, 0); 2],
stride: (1, 1),
kernel_shape: (2, 2),
padding: vec![(0, 0); 2],
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
}),
)


@@ -1,3 +1,5 @@
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
@@ -48,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
}
}


@@ -125,7 +125,7 @@ contract QuantizeData {
}
function to_field_element(
int128[] memory quantized_data
int64[] memory quantized_data
) public pure returns (uint256[] memory output) {
output = new uint256[](quantized_data.length);
for (uint i; i < quantized_data.length; i++) {
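
The int128 -> int64 changes here (and in the ABI diff above) follow breaking change #781, which narrows ezkl's quantized integer representation from i128 to i64. The field encoding itself is unchanged: a negative quantized value x still maps to p - |x| in the BN254 scalar field. A hedged Rust sketch of that mapping, assuming only the halo2curves Fr type; the function name mirrors the Solidity helper for illustration:

    use halo2curves::bn256::Fr;

    // Encode an i64 quantized value as a BN254 scalar: non-negative values
    // map directly, negative values wrap to p - |x| via field negation.
    fn to_field_element(x: i64) -> Fr {
        if x >= 0 {
            Fr::from(x as u64)
        } else {
            -Fr::from(x.unsigned_abs())
        }
    }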

data.sh (deleted, 11 lines changed)

@@ -1,11 +0,0 @@
#! /bin/bash
mkdir data
cd data
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gzip -d *.gz

docs/python/build.sh (new executable file, 2 lines changed)

@@ -0,0 +1,2 @@
#!/bin/sh
sphinx-build ./src build


@@ -0,0 +1,4 @@
ezkl==11.0.3
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

docs/python/src/conf.py (new file, 29 lines changed)

@@ -0,0 +1,29 @@
import ezkl
project = 'ezkl'
release = '11.0.3'
version = release
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
]
autosummary_generate = True
autosummary_imported_members = True
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

docs/python/src/index.rst (new file, 11 lines changed)

@@ -0,0 +1,11 @@
.. extension documentation master file, created by
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
ezkl python bindings
================================================
.. automodule:: ezkl
:members:
:undoc-members:


@@ -42,8 +42,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -66,8 +66,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -90,8 +90,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -203,8 +203,8 @@ where
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let op = PolyOp::Conv {
padding: [(PADDING, PADDING); 2],
stride: (STRIDE, STRIDE),
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
};
let x = config
.layer_config
@@ -308,6 +308,7 @@ pub fn runconv() {
tst_lbl: _,
..
} = MnistBuilder::new()
.base_path("examples/data")
.label_format_digit()
.training_set_length(50_000)
.validation_set_length(10_000)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: i128,
const LOOKUP_MAX: i128,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: i128, const LOOKUP_MAX: i128> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;


@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",


@@ -7,9 +7,9 @@
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [e2pg](https://github.com/indexsupply/x/tree/main/docs/e2pg), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://postgresapp.com/. \n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
@@ -21,23 +21,81 @@
"source": [
"import os\n",
"import getpass\n",
"\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/main/linux/amd64/e2pg\")\n",
"os.system(\"chmod +x e2pg\")\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgresql://\" + getpass.getuser() + \":@localhost:5432/e2pg\"\n",
"os.environ[\"RLPS_URL\"] = \"https://1.rlps.indexsupply.net\"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"os.system(\"echo $RLPS_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 e2pg\")\n",
"# equivalent of nohup ./e2pg -reset -e $RLPS_URL -pg $PG_URL &\n",
"e2pg_process = os.system(\"nohup ./e2pg -e $RLPS_URL -pg $PG_URL &\")\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(5)\n",
"\n"
]
},
@@ -79,11 +137,13 @@
"import json\n",
"import os\n",
"\n",
"# import logging\n",
"# # # uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
@@ -176,6 +236,7 @@
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
@@ -183,9 +244,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 5\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -194,7 +255,7 @@
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_input_file, open(input_filename, 'w' ))\n"
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
@@ -210,9 +271,9 @@
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"e2pg\",\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT value FROM erc20_transfers ORDER BY block_number DESC LIMIT 20\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
@@ -229,22 +290,6 @@
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -253,10 +298,21 @@
},
"outputs": [],
"source": [
"# setup kzg params\n",
"params_path = os.path.join('kzg.params')\n",
"import subprocess\n",
"import os\n",
"\n",
"res = ezkl.get_srs(params_path, settings_filename)"
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
},
{
@@ -306,16 +362,13 @@
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"params_path = os.path.join('kzg.params')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path,\n",
" params_path,\n",
" settings_filename,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
@@ -331,11 +384,14 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
"# generate the witness\n",
"res = ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
@@ -360,73 +416,14 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" params_path,\n",
" \"single\",\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n",
"# verify\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_filename,\n",
" vk_path,\n",
" params_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7tAa-DFAtvS"
},
"source": [
"# Part 2 (Using the ZK Computational Graph Onchain!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Ym91kaVAIB6"
},
"source": [
"**Now How Do We Do It Onchain?????**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 339
},
"id": "fodkNgwS70FM",
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
},
"outputs": [],
"source": [
"# first we need to create evm verifier\n",
"print(vk_path)\n",
"print(params_path)\n",
"print(settings_filename)\n",
"\n",
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" params_path,\n",
" settings_filename,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
"\n"
]
},
{
@@ -435,51 +432,8 @@
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the address from addr_path\n",
"addr = None\n",
"with open(address_path, 'r') as f:\n",
" addr = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.system(\"killall -9 e2pg\");"
"# kill all shovel process \n",
"os.system(\"pkill -f shovel\")"
]
}
],
@@ -501,7 +455,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long


@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",


@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,


@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.


@@ -0,0 +1,9 @@
{
"input_data": [
[
1.514470100402832, 1.519423007965088, 1.5182757377624512,
1.5262789726257324, 1.5298409461975098
]
],
"output_data": [[-0.1862019]]
}

Binary file not shown.


@@ -1,6 +1,6 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"version": "v10.4.2",
"publishConfig": {
"access": "public"
},
@@ -17,19 +17,19 @@
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
"@ethereumjs/common": "4.0.0",
"@ethereumjs/evm": "2.0.0",
"@ethereumjs/statemanager": "2.0.0",
"@ethereumjs/tx": "5.0.0",
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ezkljs/engine": "10.4.2",
"ethers": "6.7.1",
"json-bigint": "1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",


@@ -6,34 +6,34 @@ settings:
dependencies:
'@ethereumjs/common':
specifier: ^4.0.0
specifier: 4.0.0
version: 4.0.0
'@ethereumjs/evm':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/statemanager':
specifier: ^2.0.0
specifier: 2.0.0
version: 2.0.0
'@ethereumjs/tx':
specifier: ^5.0.0
specifier: 5.0.0
version: 5.0.0
'@ethereumjs/util':
specifier: ^9.0.0
specifier: 9.0.0
version: 9.0.0
'@ethereumjs/vm':
specifier: ^7.0.0
specifier: 7.0.0
version: 7.0.0
'@ethersproject/abi':
specifier: ^5.7.0
specifier: 5.7.0
version: 5.7.0
'@ezkljs/engine':
specifier: ^9.4.4
version: 9.4.4
specifier: "10.4.2"
version: "10.4.2"
ethers:
specifier: ^6.7.1
specifier: 6.7.1
version: 6.7.1
json-bigint:
specifier: ^1.0.0
specifier: 1.0.0
version: 1.0.0
devDependencies:
@@ -397,8 +397,8 @@ packages:
'@ethersproject/strings': 5.7.0
dev: false
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
/@ezkljs/engine@10.4.2:
resolution: {integrity: "sha512-1GNB4vChbaQ1ALcYbEbM/AFoh4QWtswpzGCO/g9wL8Ep6NegM2gQP/uWICU7Utl0Lj1DncXomD7PUhFSXhtx8A=="}
dependencies:
'@types/json-bigint': 1.0.2
json-bigint: 1.0.0


@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
exit 1
fi
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi


@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.14,<0.15"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.pytest.ini_options]


@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.1
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -15,6 +15,8 @@ pub use planner::*;
use crate::tensor::{TensorType, ValTensor};
use super::region::ConstantsMap;
/// Module trait used to extend ezkl functionality
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Config
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
/// Layout
fn layout(
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
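
The Module trait now threads a mutable ConstantsMap through both layout entry points, so a field constant assigned once by one module can be reused by every later caller instead of being re-assigned. A minimal sketch of the deduplication idea, using u64 keys in place of field elements (names are illustrative, not ezkl's API):

```rust
use std::collections::HashMap;

// Stand-in for ezkl's ValType: a raw constant, or one that has already
// been assigned to a (hypothetical) cell index.
#[derive(Clone, Debug)]
enum Val {
    Constant(u64),
    AssignedConstant(usize, u64),
}

type ConstantsMap = HashMap<u64, Val>;

// Assign a constant at most once; repeat requests reuse the cached cell.
fn assign_constant(constants: &mut ConstantsMap, next_cell: &mut usize, f: u64) -> Val {
    constants
        .entry(f)
        .or_insert_with(|| {
            let cell = *next_cell; // pretend this allocates a fixed cell
            *next_cell += 1;
            Val::AssignedConstant(cell, f)
        })
        .clone()
}

fn main() {
    let (mut constants, mut next_cell) = (ConstantsMap::new(), 0);
    assign_constant(&mut constants, &mut next_cell, 42);
    assign_constant(&mut constants, &mut next_cell, 42); // cache hit: no new cell
    assert_eq!((constants.len(), next_cell), (1, 1));
}
```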

View File

@@ -4,6 +4,8 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
@@ -13,6 +15,7 @@ use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
@@ -41,12 +44,11 @@ impl PolyCommitChip {
/// Commit to the message using the KZG commitment scheme
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
message: Vec<Scheme::Scalar>,
degree: u32,
num_unusable_rows: u32,
params: &Scheme::ParamsProver,
) -> Vec<G1Affine> {
let k = params.k();
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
let domain = halo2_proofs::poly::EvaluationDomain::new(2, k);
let n = 2_u64.pow(k) - num_unusable_rows as u64;
let num_poly = (message.len() / n as usize) + 1;
let mut poly = vec![domain.empty_lagrange(); num_poly];
@@ -107,6 +109,7 @@ impl Module<Fp> for PolyCommitChip {
&self,
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
Ok(())
}
@@ -119,11 +122,24 @@ impl Module<Fp> for PolyCommitChip {
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
}
@@ -184,7 +200,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
polycommit_chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
);
Ok(())
}
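
With the degree parameter gone, callers of commit no longer choose the evaluation-domain degree; it is fixed at 2, and the message is split across as many Lagrange polynomials as fit in 2^k minus the unusable rows. A hedged usage sketch under the new signature (the 6-row blinding allowance is an assumption, and PolyCommitChip is the chip defined above):

```rust
use halo2_proofs::halo2curves::bn256::{Bn256, Fr as Fp, G1Affine};
use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG};

// Commit to a message with the post-change signature: message, number of
// unusable (blinding) rows, and prover params -- no degree argument.
fn commit_message(message: Vec<Fp>, params: &ParamsKZG<Bn256>) -> Vec<G1Affine> {
    PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(message, 6, params)
}
```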

View File

@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::Module;
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
&self,
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
let local_constants = constants.clone();
let res = layouter.assign_region(
|| "load message",
|mut region| {
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
Ok(res)
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -434,7 +453,7 @@ mod tests {
*,
};
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_proofs::{
@@ -477,7 +496,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
)?;
Ok(())
}
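
One subtlety shared by both layout implementations above: assign_region may invoke its closure more than once depending on the floor planner, so the constants map is snapshotted outside the closure, mutated on a per-invocation copy, and only published back once a pass succeeds. The shape of the pattern in isolation, assuming the module's imports (Layouter, ValTensor, ConstantsMap, VarTensor, Error, Fp) are in scope; ExampleConfig is a hypothetical config holding a VarTensor named inputs:

```rust
fn layout_with_snapshot(
    layouter: &mut impl Layouter<Fp>,
    config: &ExampleConfig,          // hypothetical config type
    input: &ValTensor<Fp>,
    constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
    // Freeze the caller's map so every closure invocation starts from the
    // same baseline instead of compounding mutations across passes.
    let snapshot = constants.clone();
    layouter.assign_region(
        || "example",
        |mut region| {
            let mut local = snapshot.clone(); // fresh copy per invocation
            let res = config.inputs.assign(&mut region, 0, input, &mut local)?;
            *constants = local; // publish only the successful pass
            Ok(res)
        },
    )
}
```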

View File

@@ -24,7 +24,7 @@ use crate::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {
@@ -956,20 +956,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
let res = op.layout(self, region, values)?;
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
if let Some(claimed_output) = &res {
// during key generation this will be unknown vals so we use this as a flag to check
let mut is_assigned = !claimed_output.any_unknowns()?;
for val in values.iter() {
is_assigned = is_assigned && !val.any_unknowns()?;
}
if is_assigned {
op.safe_mode_check(claimed_output, values)?;
}
}
};
Ok(res)
op.layout(self, region, values)
}
}

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::i64_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
dim: usize,
},
SumPool {
padding: [(usize, usize); 2],
stride: (usize, usize),
kernel_shape: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
},
MaxPool2d {
padding: [(usize, usize); 2],
stride: (usize, usize),
pool_dims: (usize, usize),
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
},
ReduceMin {
axes: Vec<usize>,
@@ -46,7 +46,8 @@ pub enum HybridOp {
dim: usize,
},
Softmax {
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
@@ -70,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -84,86 +85,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax { scale, axes } => {
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
}
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater(&x, &y)?
}
HybridOp::GreaterEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater_equal(&x, &y)?
}
HybridOp::Less => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less(&x, &y)?
}
HybridOp::LessEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less_equal(&x, &y)?
}
HybridOp::Equals => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::equals(&x, &y)?
}
};
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
match self {
@@ -193,18 +114,25 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => format!(
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
HybridOp::Softmax { scale, axes } => {
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
@@ -238,9 +166,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values[..].try_into()?,
*padding,
*stride,
*kernel_shape,
padding,
stride,
kernel_shape,
*normalized,
)?,
HybridOp::Recip {
@@ -256,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values[..].try_into()?,
i128_to_felt(input_scale.0 as i128),
i128_to_felt(output_scale.0 as i128),
i64_to_felt(input_scale.0 as i64),
i64_to_felt(output_scale.0 as i64),
)?
} else {
layouts::nonlinearity(
@@ -281,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
i64_to_felt(denom.0 as i64),
)?
} else {
layouts::nonlinearity(
@@ -300,17 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
}
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => layouts::max_pool2d(
} => layouts::max_pool(
config,
region,
values[..].try_into()?,
*padding,
*stride,
*pool_dims,
padding,
stride,
pool_dims,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -324,9 +252,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::ReduceArgMin { dim } => {
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
}
HybridOp::Softmax { scale, axes } => {
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => layouts::softmax_axes(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
@@ -359,8 +296,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
_ => in_scales[0],
};
Ok(scale)
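
Because padding, stride, and pool dimensions are now Vecs, pooling is no longer pinned to two spatial dimensions (this is the "full ND conv and pool" change). A construction sketch with illustrative values:

```rust
// One (before, after) padding pair and one stride/pool entry per spatial
// dim, so a 3-D max pool is just length-3 vectors.
let max_pool_3d = HybridOp::MaxPool {
    padding: vec![(1, 1), (1, 1), (1, 1)],
    stride: vec![2, 2, 2],
    pool_dims: vec![3, 3, 3],
};

// The 2-D case uses length-2 vectors; normalized: true turns SumPool into
// average pooling.
let sum_pool_2d = HybridOp::SumPool {
    padding: vec![(0, 0), (0, 0)],
    stride: vec![1, 1],
    kernel_shape: vec![2, 2],
    normalized: true,
};
```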

File diff suppressed because it is too large.

View File

@@ -4,9 +4,9 @@ use std::error::Error;
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::{felt_to_i64, i64_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
};
use super::Op;
@@ -132,19 +132,16 @@ impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
let range = range as i64;
(-range, range)
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i64(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
@@ -231,10 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
}
}?;
let output = res.map(|x| i128_to_felt(x));
let output = res.map(|x| i64_to_felt(x));
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the name of the operation
fn as_string(&self) -> String {

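A worked example of bit_range under the new i64 representation: for a table with max_len = 131072 (2^17) rows, (131072 - 1) / 2 = 65535.5 truncates to 65535, giving a symmetric lookup range:

```rust
// (max_len - 1) as f64 / 2.0 = 65535.5, truncated by the i64 cast.
assert_eq!(LookupOp::bit_range(131072), (-65535, 65535));
```
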
View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -27,14 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
std::fmt::Debug + Send + Sync + Any
{
/// Returns a string representation of the operation.
fn as_string(&self) -> String;
@@ -69,36 +69,9 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any;
/// Safe mode output check
fn safe_mode_check(
&self,
claimed_output: &ValTensor<F>,
original_values: &[ValTensor<F>],
) -> Result<(), TensorError> {
let felt_evals = original_values
.iter()
.map(|v| {
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
evals.reshape(v.dims())?;
Ok(evals)
})
.collect::<Result<Vec<_>, _>>()?;
let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
let mut output = claimed_output
.get_felt_evals()
.map_err(|_| TensorError::FeltError)?;
output.reshape(claimed_output.dims())?;
assert_eq!(output, ref_op);
Ok(())
}
}
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -149,8 +122,8 @@ impl InputType {
*input = T::from_f64(f64_input).unwrap();
}
InputType::Int | InputType::TDim => {
let int_input = input.clone().to_i128().unwrap();
*input = T::from_i128(int_input).unwrap();
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
}
}
@@ -165,7 +138,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -174,12 +147,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
self
}
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
})
}
fn as_string(&self) -> String {
"Input".into()
}
@@ -226,16 +193,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}
fn as_string(&self) -> String {
"Unknown".into()
@@ -256,7 +220,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
///
pub quantized_values: Tensor<F>,
///
@@ -266,7 +230,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -293,17 +257,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
+ IntoI64,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
format!("CONST (scale={})", self.quantized_values.scale().unwrap())

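The new IntoI64 bound appears on nearly every impl in this change, but its definition is not shown in this diff. As an assumption only, it is roughly a value-to-i64 round trip replacing the old i128 helpers:

```rust
// Assumed shape only -- the real trait lives in ezkl's tensor module and
// its exact surface may differ.
pub trait IntoI64: Sized {
    fn into_i64(&self) -> i64;
    fn from_i64(v: i64) -> Self;
}
```
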
View File

@@ -1,6 +1,5 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -32,8 +31,8 @@ pub enum PolyOp {
equation: String,
},
Conv {
padding: [(usize, usize); 2],
stride: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
},
Downsample {
axis: usize,
@@ -41,9 +40,9 @@ pub enum PolyOp {
modulo: usize,
},
DeConv {
padding: [(usize, usize); 2],
output_padding: (usize, usize),
stride: (usize, usize),
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
},
Add,
Sub,
@@ -58,10 +57,13 @@ pub enum PolyOp {
destination: usize,
},
Flatten(Vec<usize>),
Pad([(usize, usize); 2]),
Pad(Vec<(usize, usize)>),
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -89,8 +91,15 @@ pub enum PolyOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for PolyOp
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
+ IntoI64,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -99,10 +108,28 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -114,15 +141,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -136,146 +174,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
"multibroadcastto inputs".to_string(),
));
}
inputs[0].expand(shape)
}
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
PolyOp::Not => tensor::ops::not(&inputs[0]),
PolyOp::Downsample {
axis,
stride,
modulo,
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::MoveAxis {
source,
destination,
} => inputs[0].move_axis(*source, *destination),
PolyOp::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::Pad(p) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pad inputs".to_string()));
}
tensor::ops::pad(&inputs[0], *p)
}
PolyOp::Add => tensor::ops::add(&inputs),
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult => tensor::ops::mult(&inputs),
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
PolyOp::DeConv {
padding,
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
inputs[0].pow(*u)
}
PolyOp::Sum { axes } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
tensor::ops::sum_axes(&inputs[0], axes)
}
PolyOp::Prod { axes, .. } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("prod inputs".to_string()));
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
PolyOp::Slice { axis, start, end } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;
Ok(ForwardResult { output: res })
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
@@ -286,6 +184,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -312,7 +213,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -364,9 +265,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
config,
region,
values[..].try_into()?,
*padding,
*output_padding,
*stride,
padding,
output_padding,
stride,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -382,7 +283,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
)));
}
let mut input = values[0].clone();
input.pad(*p)?;
input.pad(p.clone(), 0)?;
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -398,6 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

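Conv and DeConv follow the same ND refactor as the pooling ops: one padding pair per spatial axis, with the Vec lengths determining dimensionality. A construction sketch with illustrative values (any type parameters elided):

```rust
// A 2-D convolution after the refactor.
let conv = PolyOp::Conv {
    padding: vec![(1, 1), (1, 1)],
    stride: vec![2, 2],
};

// A 1-D transposed convolution: length-1 vectors throughout.
let deconv = PolyOp::DeConv {
    padding: vec![(0, 0)],
    output_padding: vec![1],
    stride: vec![2],
};
```
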
View File

@@ -2,24 +2,28 @@ use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
collections::HashSet,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,26 +124,47 @@ impl From<Box<dyn std::error::Error>> for RegionError {
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
max_lookup_inputs: i64,
min_lookup_inputs: i64,
max_range_size: i64,
witness_gen: bool,
check_lookup_range: bool,
assigned_constants: ConstantsMap<F>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
///
pub fn increment_total_constants(&mut self, n: usize) {
self.total_constants += n;
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
self.max_lookup_inputs().to_string().green(),
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
}
///
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
&self.assigned_constants
}
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants);
}
///
@@ -163,8 +188,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
pub fn witness_gen(&self) -> bool {
self.witness_gen
}
///
pub fn check_lookup_range(&self) -> bool {
self.check_lookup_range
}
/// Create a new region context
@@ -177,7 +207,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -185,9 +214,23 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: true,
check_lookup_range: true,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_with_constants(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +245,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
@@ -210,7 +252,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: false,
check_lookup_range: false,
assigned_constants: HashMap::new(),
}
}
@@ -218,7 +262,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -228,7 +273,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -236,17 +280,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy_with_constants(
pub fn new_dummy_with_linear_coord(
row: usize,
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -254,7 +300,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -262,7 +307,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
@@ -312,29 +359,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(), RegionError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
self.throw_range_check_error,
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -343,10 +389,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
constants.fetch_add(
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +404,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -410,6 +454,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
Ok(())
}
@@ -435,7 +487,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
}
let range_size = (range.1 - range.0).abs();
@@ -477,7 +529,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the total number of constants
pub fn total_constants(&self) -> usize {
self.total_constants
self.assigned_constants.len()
}
/// Get the dynamic lookup index
@@ -511,40 +563,38 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
pub fn max_lookup_inputs(&self) -> i64 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
pub fn min_lookup_inputs(&self) -> i64 {
self.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i128 {
pub fn max_range_size(&self) -> i64 {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
if let Some(region) = &self.region {
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
Ok(cell.into())
} else {
Ok(value.into())
}
}
/// Assign a valtensor to a vartensor
pub fn assign(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -560,14 +610,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -594,13 +648,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
} else {
self.total_constants += values.num_constants();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -615,24 +676,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len, total_assigned_constants) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((res, len))
} else {
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((values.clone(), len))
}
}
@@ -699,9 +760,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
Ok(())
}
/// increment constants
pub fn increment_constants(&mut self, n: usize) {
self.total_constants += n
}
}
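
The single throw_range_check_error flag has been split into two independent switches, and the hand-maintained constants counter is gone: total_constants is now simply the size of the assigned-constants map. A hedged sketch of the renamed dummy constructor, assuming Fp (the bn256 scalar) is in scope; argument values are illustrative:

```rust
// A sizing pass: no witness generation, no lookup-range errors.
let dry_run = RegionCtx::<Fp>::new_dummy(
    0,     // starting row
    2,     // num_inner_cols
    false, // witness_gen
    false, // check_lookup_range
);
assert_eq!(dry_run.total_constants(), 0); // len() of the assigned-constants map
```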

View File

@@ -11,19 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
circuit::CircuitError,
fieldutils::i128_to_felt,
tensor::{Tensor, TensorType},
fieldutils::i64_to_felt,
tensor::{IntoI64, Tensor, TensorType},
};
use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
pub type Range = (i64, i64);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
@@ -98,26 +96,25 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
i128_to_felt(chunk)
i64_to_felt(chunk)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
let chunk = chunk as i128;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let op_f = Op::<F>::f(
&self.nonlinearity,
&[Tensor::from(vec![first_element].into_iter())],
)
.unwrap();
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
.unwrap();
(first_element, op_f.output[0])
}
@@ -133,12 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
}
///
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
(range_len / (col_size as i64)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -205,8 +202,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -275,12 +272,12 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
}
///
@@ -297,13 +294,13 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
i128_to_felt(chunk)
i64_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
@@ -353,7 +350,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

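A worked example of the chunking arithmetic above, with illustrative numbers: for range = (-65535, 65535) and col_size = 32768, an input that quantizes to 40000 lands in chunk 3, whose first element is 32769:

```rust
let (range_lo, col_size): (i64, i64) = (-65535, 32768);
let input: i64 = 40000;
let chunk = (input - range_lo).abs() / col_size; // (40000 + 65535) / 32768
assert_eq!(chunk, 3);
assert_eq!(chunk * col_size + range_lo, 32769); // first element of chunk 3
```
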
View File

@@ -1048,8 +1048,8 @@ mod conv {
&mut region,
&self.inputs,
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis);
@@ -1911,6 +1911,8 @@ mod add_with_overflow {
#[cfg(test)]
mod add_with_overflow_and_poseidon {
use std::collections::HashMap;
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
@@ -1969,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
let assigned_inputs_a =
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
let assigned_inputs_b =
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
layouter.assign_region(|| "_new_module", |_| Ok(()))?;

View File

@@ -345,7 +345,7 @@ pub enum Commands {
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the observed max lookup value is 2^k, the table is sized for 2^k * lookup_safety_margin. larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
lookup_safety_margin: i128,
lookup_safety_margin: i64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
scales: Option<Vec<crate::Scale>>,
@@ -444,7 +444,7 @@ pub enum Commands {
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
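
The --commitment flag is now an Option, and the .into() calls in the execution path below turn Option<Commitments> into a concrete scheme. The conversion presumably defaults to KZG when the flag is absent; a sketch of the assumed impl:

```rust
// Assumed impl -- the real default lives with ezkl's Commitments type.
impl From<Option<Commitments>> for Commitments {
    fn from(c: Option<Commitments>) -> Self {
        c.unwrap_or(Commitments::KZG)
    }
}
```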

View File

@@ -424,7 +424,7 @@ pub async fn setup_test_contract<M: 'static + Middleware>(
let input = input.to_float() as f32;
let decimal_places = count_decimal_places(input) as u8;
let scaled_by_decimals = input * f32::powf(10., decimal_places.into());
scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i128));
scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i64));
decimals.push(decimal_places);
} else if input.is_field() {
let input = input.to_field(0);

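A worked example of the float encoding above, now cast through i64: 1.25 has two decimal places, so it is sent on-chain as mantissa 125 with decimals = 2 (count_decimal_places returning 2 is the expected behavior, not verified here):

```rust
let input: f32 = 1.25;
let decimal_places: u8 = 2; // what count_decimal_places(1.25) should return
let scaled = input * f32::powf(10., decimal_places.into());
assert_eq!(scaled as i64, 125);
```
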
View File

@@ -24,6 +24,8 @@ use crate::pfsys::{
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
@@ -194,7 +196,6 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
srs_path,
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(model, witness),
#[cfg(not(target_arch = "wasm32"))]
@@ -337,7 +338,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
split_proofs,
disable_selector_compression,
commitment,
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -358,7 +359,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
check_mode,
split_proofs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -382,7 +383,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
reduced_srs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -538,7 +539,7 @@ fn check_srs_hash(
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) {
Some(h) => h,
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
@@ -584,7 +585,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -635,7 +636,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
Ok(String::new())
}
pub(crate) async fn gen_witness(
pub(crate) fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
output: Option<PathBuf>,
@@ -658,33 +659,30 @@ pub(crate) async fn gen_witness(
};
#[cfg(not(target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data).await?;
let mut input = circuit.load_graph_input(&data)?;
#[cfg(target_arch = "wasm32")]
let mut input = circuit.load_graph_input(&data)?;
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
true,
)?
}
Commitments::IPA => {
@@ -692,22 +690,29 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
true,
)?
}
}
} else {
warn!("SRS for poly commit does not exist (will be ignored)");
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
vk.as_ref(),
None,
true,
true,
)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -819,7 +824,15 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let percentage_error = error.enum_map(|i, x| {
// guard against 0/0: if both the original value and the error are 0, the relative error is 0
let res = if original[i] == 0.0 && x == 0.0 {
0.0
} else {
x / original[i]
};
Ok::<f32, TensorError>(res)
})?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error);
@@ -882,12 +895,13 @@ pub(crate) fn calibrate(
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i128,
lookup_safety_margin: i64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -900,9 +914,9 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
info!("running onnx predictions...");
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
@@ -970,10 +984,18 @@ pub(crate) fn calibrate(
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
let mut num_failed = 0;
let mut num_passed = 0;
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (
@@ -989,6 +1011,7 @@ pub(crate) fn calibrate(
param_scale,
scale_rebase_multiplier,
div_rebasing,
lookup_range: (i64::MIN, i64::MAX),
..settings.run_args.clone()
};
@@ -1007,7 +1030,9 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
debug!("circuit creation from run args failed: {:?}", e);
error!("circuit creation from run args failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
};
@@ -1022,7 +1047,13 @@ pub(crate) fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut data.clone(), None, None, true)
.forward::<KZGCommitmentScheme<Bn256>>(
&mut data.clone(),
None,
None,
true,
false,
)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
@@ -1037,9 +1068,11 @@ pub(crate) fn calibrate(
match forward_res {
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
// typically errors will be due to the circuit overflowing the i64 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
}
@@ -1104,8 +1137,10 @@ pub(crate) fn calibrate(
"found settings: \n {}",
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
debug!("calibration failed {}", res.err().unwrap());
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}
pb.inc(1);
@@ -1208,22 +1243,14 @@ pub(crate) fn calibrate(
);
if matches!(target, CalibrationTarget::Resources { col_overflow: true }) {
let lookup_log_rows = ((best_params.run_args.lookup_range.1
- best_params.run_args.lookup_range.0) as f32)
.log2()
.ceil() as u32
+ 1;
let mut reduction = std::cmp::max(
(best_params
.model_instance_shapes
.iter()
.map(|x| x.iter().product::<usize>())
.sum::<usize>() as f32)
.log2()
.ceil() as u32
+ 1,
lookup_log_rows,
);
let lookup_log_rows = best_params.lookup_log_rows_with_blinding();
let module_log_row = best_params.module_constraint_logrows_with_blinding();
let instance_logrows = best_params.log2_total_instances_with_blinding();
let dynamic_lookup_logrows = best_params.dynamic_lookup_and_shuffle_logrows_with_blinding();
let mut reduction = std::cmp::max(lookup_log_rows, module_log_row);
reduction = std::cmp::max(reduction, instance_logrows);
reduction = std::cmp::max(reduction, dynamic_lookup_logrows);
reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS);
info!(
@@ -1278,17 +1305,19 @@ pub(crate) fn create_evm_verifier(
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1322,17 +1351,18 @@ pub(crate) fn create_evm_vk(
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1601,8 +1631,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1711,7 +1742,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1720,7 +1752,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -1879,7 +1911,9 @@ pub(crate) fn mock_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1922,7 +1956,9 @@ pub(crate) fn setup_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG",).into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1983,7 +2019,9 @@ pub(crate) fn aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -2156,8 +2194,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {

View File

@@ -11,8 +11,8 @@ pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
}
}
/// Converts an i128 to a PrimeField element.
pub fn i128_to_felt<F: PrimeField>(x: i128) -> F {
/// Converts an i64 to a PrimeField element.
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
@@ -37,7 +37,7 @@ pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
if x > F::from_u128(i128::MAX as u128) {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -50,18 +50,18 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}
}
/// Converts a PrimeField element to an i128.
pub fn felt_to_i128<F: PrimeField + PartialOrd + Field>(x: F) -> i128 {
if x > F::from_u128(i128::MAX as u128) {
/// Converts a PrimeField element to an i64.
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
-(lower_128 as i128)
-(lower_128 as i64)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
lower_128 as i128
lower_128 as i64
}
}
@@ -79,10 +79,10 @@ mod test {
let res: F = i32_to_felt(2_i32.pow(17));
assert_eq!(res, F::from(131072));
let res: F = i128_to_felt(-15i128);
let res: F = i64_to_felt(-15i64);
assert_eq!(res, -F::from(15));
let res: F = i128_to_felt(2_i128.pow(17));
let res: F = i64_to_felt(2_i64.pow(17));
assert_eq!(res, F::from(131072));
}
@@ -96,10 +96,10 @@ mod test {
}
#[test]
fn felttoi128() {
for x in -(2i128.pow(20))..(2i128.pow(20)) {
let fieldx: F = i128_to_felt::<F>(x);
let xf: i128 = felt_to_i128::<F>(fieldx);
fn felttoi64() {
for x in -(2i64.pow(20))..(2i64.pow(20)) {
let fieldx: F = i64_to_felt::<F>(x);
let xf: i64 = felt_to_i64::<F>(fieldx);
assert_eq!(x, xf);
}
}
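The round trip works because negatives are mapped to the field's upper half; a toy sketch with a small prime standing in for the BN256 scalar field (the modulus here is an assumption for illustration only):

const P: i128 = 2_147_483_647; // toy prime modulus

fn i64_to_toy_felt(x: i64) -> i128 {
    // negative values wrap around the modulus: -x maps to P - x
    ((x as i128) % P + P) % P
}

fn toy_felt_to_i64(f: i128) -> i64 {
    // values in the upper half are read back as negatives
    if f > P / 2 { (f - P) as i64 } else { f as i64 }
}

fn main() {
    for x in [-15_i64, 0, 131072] {
        assert_eq!(toy_felt_to_i64(i64_to_toy_felt(x)), x);
    }
}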

View File

@@ -1,7 +1,7 @@
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
@@ -21,8 +21,6 @@ use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
@@ -130,7 +128,7 @@ impl FileSourceInner {
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
@@ -152,7 +150,7 @@ impl FileSourceInner {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64,
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
}
}
}
@@ -234,21 +232,15 @@ impl PostgresSource {
)
};
let res: Vec<pg_bigdecimal::PgNumeric> = thread::spawn(move || {
let mut client = Client::connect(&config, NoTls).unwrap();
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).unwrap() {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
let mut client = Client::connect(&config, NoTls)?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[])? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
res
})
.join()
.map_err(|_| "failed to fetch data from postgres")?;
}
Ok(vec![res])
}

View File

@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
@@ -38,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -61,13 +62,13 @@ pub use vars::*;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -155,7 +156,7 @@ use std::cell::RefCell;
thread_local!(
/// This is a global variable that holds the settings for the graph
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
);
/// Result from a forward pass
@@ -174,11 +175,11 @@ pub struct GraphWitness {
/// Any hashes of outputs generated during the forward pass
pub processed_outputs: Option<ModuleForwardResult>,
/// max lookup input
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// min lookup input
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// max range check size
pub max_range_size: i128,
pub max_range_size: i64,
}
impl GraphWitness {
@@ -482,7 +483,22 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
/// Calc the number of rows required for lookup tables
pub fn lookup_log_rows(&self) -> u32 {
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32)
.log2()
.ceil() as u32
}
/// Calc the number of rows required for lookup tables (including blinding rows)
pub fn lookup_log_rows_with_blinding(&self) -> u32 {
((self.run_args.lookup_range.1 - self.run_args.lookup_range.0) as f32
+ RESERVED_BLINDING_ROWS as f32)
.log2()
.ceil() as u32
}
fn model_constraint_logrows_with_blinding(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
@@ -494,14 +510,31 @@ impl GraphSettings {
.ceil() as u32
}
/// calculate the number of rows required for the dynamic lookup and shuffle
pub fn dynamic_lookup_and_shuffle_logrows_with_blinding(&self) -> u32 {
(self.total_dynamic_col_size as f64
+ self.total_shuffle_col_size as f64
+ RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
fn module_constraint_logrows(&self) -> u32 {
/// calculate the number of rows required for the module constraints
pub fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
/// calculate the number of rows required for the module constraints (including blinding rows)
pub fn module_constraint_logrows_with_blinding(&self) -> u32 {
(self.module_sizes.max_constraints() as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
.log2()
@@ -528,6 +561,14 @@ impl GraphSettings {
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
/// calculate the log2 of the total number of instances (including blinding rows)
pub fn log2_total_instances_with_blinding(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>() + RESERVED_BLINDING_ROWS;
// max between 1 and the log2 of the sums
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
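Worked example of why the blinding-aware variants matter (a sketch; the exact value of RESERVED_BLINDING_ROWS lives in circuit::table and is assumed small here):

fn lookup_log_rows_with_blinding(lookup_range: (i64, i64), blinding: i64) -> u32 {
    ((lookup_range.1 - lookup_range.0 + blinding) as f32).log2().ceil() as u32
}

fn main() {
    // the default lookup range spans exactly 65536 = 2^16 values, so without
    // blinding 16 log-rows suffice; any nonzero blinding pad pushes the table
    // just past 2^16 and forces 17
    assert_eq!(lookup_log_rows_with_blinding((-32768, 32768), 0), 16);
    assert_eq!(lookup_log_rows_with_blinding((-32768, 32768), 16), 17);
}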
/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
// buf writer
@@ -917,7 +958,7 @@ impl GraphCircuit {
///
#[cfg(not(target_arch = "wasm32"))]
pub async fn load_graph_input(
pub fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
@@ -927,7 +968,6 @@ impl GraphCircuit {
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(target_arch = "wasm32")]
@@ -951,7 +991,7 @@ impl GraphCircuit {
#[cfg(not(target_arch = "wasm32"))]
/// Process the data source for the model
async fn process_data_source(
fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
@@ -964,8 +1004,16 @@ impl GraphCircuit {
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
// start runtime and fetch data
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async {
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
})
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
@@ -1050,16 +1098,14 @@ impl GraphCircuit {
Ok(data)
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let margin = (
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
margin
)
}
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
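A minimal sketch of the safety margin, mirroring calc_safe_lookup_range above with plain tuples:

fn calc_safe_lookup_range(min_max: (i64, i64), margin: i64) -> (i64, i64) {
    (margin * min_max.0, margin * min_max.1)
}

fn main() {
    // with a margin of 2, an observed range of (-1000, 4000)
    // yields a lookup table covering (-2000, 8000)
    assert_eq!(calc_safe_lookup_range((-1000, 4000), 2), (-2000, 8000));
}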
@@ -1067,7 +1113,7 @@ impl GraphCircuit {
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
max_range_size: i64,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick the range with the largest absolute size safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
@@ -1086,9 +1132,9 @@ impl GraphCircuit {
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: i128,
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
lookup_safety_margin: i64,
) -> Result<(), Box<dyn std::error::Error>> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -1127,7 +1173,7 @@ impl GraphCircuit {
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
@@ -1182,7 +1228,7 @@ impl GraphCircuit {
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
max_range_size: i64,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1240,7 +1286,8 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
throw_range_check_error: bool,
witness_gen: bool,
check_lookup: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1289,7 +1336,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1452,7 +1499,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1462,7 +1510,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
@@ -1604,6 +1653,8 @@ impl Circuit<Fp> for GraphCircuit {
let output_vis = &self.settings().run_args.output_visibility;
let mut graph_modules = GraphModules::new();
let mut constants = ConstantsMap::new();
let mut config = config.clone();
let mut inputs = self
@@ -1649,6 +1700,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut input_outlets,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace inputs with the outlets
for (i, outlet) in outlets.iter().enumerate() {
@@ -1661,6 +1713,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut inputs,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
}
@@ -1697,6 +1750,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut flattened_params,
param_visibility,
&mut instance_offset,
&mut constants,
)?;
let shapes = self.model().const_shapes();
@@ -1725,6 +1779,7 @@ impl Circuit<Fp> for GraphCircuit {
&inputs,
&mut vars,
&outputs,
&mut constants,
)
.map_err(|e| {
log::error!("{}", e);
@@ -1749,6 +1804,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut output_outlets,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace outputs with the outlets
@@ -1762,6 +1818,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut outputs,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
}

View File

@@ -5,6 +5,7 @@ use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
@@ -64,11 +65,11 @@ pub struct ForwardResult {
/// The outputs of the forward pass.
pub outputs: Vec<Tensor<Fp>>,
/// The maximum value of any input to a lookup operation.
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// The max range check size
pub max_range_size: i128,
pub max_range_size: i64,
}
impl From<DummyPassRes> for ForwardResult {
@@ -116,11 +117,11 @@ pub struct DummyPassRes {
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
pub max_lookup_inputs: i64,
/// min lookup inputs
pub min_lookup_inputs: i128,
pub min_lookup_inputs: i64,
/// max range check size
pub max_range_size: i128,
pub max_range_size: i64,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -404,7 +405,7 @@ impl ParsedNodes {
.get(input)
.ok_or(GraphError::MissingNode(*input))?;
let input_dims = node.out_dims();
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
inputs.push(input_dim.clone());
}
@@ -514,27 +515,30 @@ impl Model {
instance_shapes.len().to_string().blue(),
"instances".blue()
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
let len = shape.iter().product();
let mut t: ValTensor<Fp> = (0..len)
.map(|_| {
if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs, false)?;
let res = self.dummy_layout(run_args, &inputs, false, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
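// Why random constants instead of a shared Fp::ONE? (an inference from this
// diff, not stated by the authors): the new ConstantsMap deduplicates
// assigned constants, so a single repeated placeholder would collapse every
// fixed cell into one constant and make the dummy layout undercount the
// circuit's constant usage; per-cell random values keep the dummy pass
// representative of a real assignment.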
@@ -577,13 +581,14 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
witness_gen: bool,
check_lookup: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
Ok(res.into())
}
@@ -799,13 +804,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -814,6 +824,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -1071,6 +1082,8 @@ impl Model {
/// * `layouter` - Halo2 Layouter.
/// * `inputs` - The values to feed into the circuit.
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1079,6 +1092,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
info!("model layout...");
@@ -1104,14 +1118,12 @@ impl Model {
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
let mut total_const_size = 0;
let original_constants = constants.clone();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1157,29 +1169,17 @@ impl Model {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
} else if !run_args.output_visibility.is_private() {
for output in &outputs {
thread_safe_region.increment_total_constants(output.num_constants());
}
}
num_rows = thread_safe_region.row();
linear_coord = thread_safe_region.linear_coord();
total_const_size = thread_safe_region.total_constants();
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
"rows".blue(),
linear_coord.to_string().yellow(),
total_const_size.to_string().red()
);
)?;
let duration = start_time.elapsed();
trace!("model layout took: {:?}", duration);
@@ -1201,6 +1201,20 @@ impl Model {
.collect();
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1212,31 +1226,11 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1277,8 +1271,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1310,6 +1304,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1322,25 +1317,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}
@@ -1380,7 +1392,8 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
witness_gen: bool,
check_lookup: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1412,32 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let mut region =
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let default_value = if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
@@ -1432,7 +1448,7 @@ impl Model {
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
region.update_constants(output.create_constants_map());
}
}
@@ -1441,14 +1457,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
"rows".blue(),
region.linear_coord().to_string().yellow(),
region.total_constants().to_string().red()
);
region.debug_report();
let outputs = outputs
.iter()

View File

@@ -2,6 +2,7 @@ use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
@@ -211,12 +212,13 @@ impl GraphModules {
layouter: &mut impl Layouter<Fp>,
x: &mut Vec<ValTensor<Fp>>,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
// reserve module 0 for ... modules
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
x[0] = module
.layout(layouter, &cloned_x, instance_offset.to_owned())
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
.unwrap();
for inc in module.instance_increment_input().iter() {
// increment the instance offset to make way for future module layouts
@@ -234,6 +236,7 @@ impl GraphModules {
values: &mut [ValTensor<Fp>],
element_visibility: &Visibility,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
if element_visibility.is_polycommit() && !values.is_empty() {
// concat values and sk to get the inputs
@@ -248,7 +251,7 @@ impl GraphModules {
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
// increment the current index
self.polycommit_idx += 1;
});
@@ -270,7 +273,7 @@ impl GraphModules {
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
});
// replace the inputs with the outputs
values.iter_mut().enumerate().for_each(|(i, x)| {
@@ -311,7 +314,6 @@ impl GraphModules {
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
let res = PolyCommitChip::commit::<Scheme>(
x.to_vec(),
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
srs,
);

View File

@@ -14,7 +14,6 @@ use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
@@ -61,20 +60,6 @@ impl Op<Fp> for Rescaled {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
if self.scale.len() != x.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}
let mut rescaled_inputs = vec![];
let inputs = &mut x.to_vec();
for (i, ri) in inputs.iter_mut().enumerate() {
let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
let res = (ri.clone() * mult_tensor)?;
rescaled_inputs.push(res);
}
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
}
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
@@ -215,13 +200,6 @@ impl Op<Fp> for RebaseScale {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
Ok(res)
}
fn as_string(&self) -> String {
format!(
@@ -389,13 +367,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
fn f(
&self,
inputs: &[Tensor<Fp>],
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
self.as_op().f(inputs)
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,

View File

@@ -52,16 +52,16 @@ use tract_onnx::tract_hir::{
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i128, TensorError> {
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
return Err(TensorError::SigBitTruncationError);
}
// we parallelize the quantization process as it seems to be quite slow at times
let scaled = (mult * *elem + shift).round() as i128;
let scaled = (mult * *elem + shift).round() as i64;
Ok(scaled)
}
@@ -72,7 +72,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
/// * `scale` - `2^scale` used in the fixed point representation.
/// * `shift` - offset used in the fixed point representation.
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
let int_rep = crate::fieldutils::felt_to_i128(felt);
let int_rep = crate::fieldutils::felt_to_i64(felt);
let multiplier = scale_to_multiplier(scale);
int_rep as f64 / multiplier - shift
}
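A self-contained fixed-point round trip matching the formulas above, with scale_to_multiplier(scale) = 2^scale and shift = 0:

fn quantize(elem: f64, scale: u32) -> i64 {
    (elem * f64::from(2u32.pow(scale))).round() as i64
}

fn dequantize(int_rep: i64, scale: u32) -> f64 {
    int_rep as f64 / f64::from(2u32.pow(scale))
}

fn main() {
    // at scale 7 the multiplier is 128, so 0.3 quantizes to round(38.4) = 38
    assert_eq!(quantize(0.3, 7), 38);
    // dequantizing gives back 38/128 = 0.296875, a quantization error of ~0.003
    assert_eq!(dequantize(38, 7), 0.296875);
}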
@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,6 +734,19 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}
"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1072,8 +1085,12 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
})
}
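Sketch of the softmax rescaling change: the output scale becomes max(global max scale, input scale) instead of the input scale alone (the value standing in for scales.get_max() below is a made-up example):

fn scale_to_multiplier(scale: i32) -> f64 {
    2f64.powi(scale)
}

fn main() {
    let in_scale = 7;
    let global_max_scale = 11; // assumption: e.g. what scales.get_max() returns
    let max_scale = std::cmp::max(global_max_scale, in_scale);
    // input interpreted at 2^7 = 128, output rebased to 2^11 = 2048
    assert_eq!(scale_to_multiplier(in_scale), 128.0);
    assert_eq!(scale_to_multiplier(max_scale), 2048.0);
}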
@@ -1102,17 +1119,7 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1120,26 +1127,10 @@ pub fn new_op_from_onnx(
};
let kernel_shape = &pool_spec.kernel_shape;
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
};
SupportedOp::Hybrid(HybridOp::MaxPool2d {
SupportedOp::Hybrid(HybridOp::MaxPool {
padding,
stride: (stride_h, stride_w),
pool_dims: (kernel_height, kernel_width),
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
@@ -1161,7 +1152,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}
@@ -1201,15 +1192,7 @@ pub fn new_op_from_onnx(
}
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => {
if s.len() == 1 {
(s[0], s[0])
} else if s.len() == 2 {
(s[0], s[1])
} else {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
}
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
@@ -1217,17 +1200,7 @@ pub fn new_op_from_onnx(
let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1282,33 +1255,20 @@ pub fn new_op_from_onnx(
}
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => (s[0], s[1]),
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let output_padding: (usize, usize) =
(deconv_node.adjustments[0], deconv_node.adjustments[1]);
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1327,7 +1287,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::DeConv {
padding,
output_padding,
output_padding: deconv_node.adjustments.to_vec(),
stride,
})
}
@@ -1428,46 +1388,17 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let kernel_shape = &pool_spec.kernel_shape;
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams(
"kernel shape".to_string(),
)));
};
SupportedOp::Hybrid(HybridOp::SumPool {
padding,
stride: (stride_h, stride_w),
kernel_shape: (kernel_height, kernel_width),
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
})
}
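The pooling hyperparameters are now N-dimensional vectors rather than hard-coded 2-D tuples; a sketch of the per-axis padding zip for a hypothetical 3-D pool:

fn main() {
    let before = [1usize, 0, 2];
    let after = [1usize, 1, 2];
    let padding: Vec<(usize, usize)> =
        before.iter().zip(after.iter()).map(|(b, a)| (*b, *a)).collect();
    assert_eq!(padding, vec![(1, 1), (0, 1), (2, 2)]);
    // stride and kernel_shape are likewise passed through as Vec<usize>,
    // e.g. stride: vec![2, 2, 2], pool_dims: vec![3, 3, 3]
}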
@@ -1494,29 +1425,7 @@ pub fn new_op_from_onnx(
)));
}
let padding_len = pad_node.pads.len();
// we only support symmetrical padding that affects the last 2 dims (height and width params)
for (i, pad_params) in pad_node.pads.iter().enumerate() {
if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
return Err(Box::new(GraphError::MisformedParams(
"ezkl currently only supports padding height and width dimensions"
.to_string(),
)));
}
}
let padding = [
(
pad_node.pads[padding_len - 2].0,
pad_node.pads[padding_len - 1].0,
),
(
pad_node.pads[padding_len - 2].1,
pad_node.pads[padding_len - 1].1,
),
];
SupportedOp::Linear(PolyOp::Pad(padding))
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
}
"RmAxis" | "Reshape" | "AddAxis" => {
// Extract the slope layer hyperparams
@@ -1566,7 +1475,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
visibility: &Visibility,
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::i128_to_felt::<F>(quantize_float(
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,

View File

@@ -346,7 +346,7 @@ pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {

View File

@@ -23,7 +23,7 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(round_ties_even)]
#![feature(stmt_expr_attributes)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
@@ -115,6 +115,12 @@ pub enum Commitments {
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
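A minimal sketch of the backwards-compatibility default (that older settings files deserialize to None is an assumption based on the Option type):

#[derive(Debug, PartialEq, Clone, Copy)]
enum Commitments { KZG, IPA }

impl From<Option<Commitments>> for Commitments {
    fn from(value: Option<Commitments>) -> Self {
        value.unwrap_or(Commitments::KZG)
    }
}

fn main() {
    // a settings file with no `commitment` field falls back to KZG
    let old_settings: Option<Commitments> = None;
    assert_eq!(Commitments::from(old_settings), Commitments::KZG);
    assert_eq!(Commitments::from(Some(Commitments::IPA)), Commitments::IPA);
}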
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -184,7 +190,7 @@ pub struct RunArgs {
#[arg(long, default_value = "1")]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
@@ -195,13 +201,13 @@ pub struct RunArgs {
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, hashed
/// Flags whether inputs are public, private, hashed, fixed, kzgcommit
#[arg(long, default_value = "private")]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, hashed
/// Flags whether outputs are public, private, fixed, hashed, kzgcommit
#[arg(long, default_value = "public")]
pub output_visibility: Visibility,
/// Flags whether params are public, private, hashed
/// Flags whether params are fixed, private, hashed, kzgcommit
#[arg(long, default_value = "private")]
pub param_visibility: Visibility,
#[arg(long, default_value = "false")]
@@ -215,7 +221,7 @@ pub struct RunArgs {
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
pub commitment: Option<Commitments>,
}
impl Default for RunArgs {
@@ -235,7 +241,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
commitment: None,
}
}
}
@@ -243,6 +249,12 @@ impl Default for RunArgs {
impl RunArgs {
///
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
.into(),
);
}
if self.scale_rebase_multiplier < 1 {
return Err("scale_rebase_multiplier must be >= 1".into());
}
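// Usage sketch (hypothetical driver code; the field and error message are
// from this diff):
//
//   let mut args = RunArgs::default();
//   args.param_visibility = Visibility::Public;
//   // "params cannot be public instances, you are probably trying to use
//   //  `fixed` or `kzgcommit`"
//   assert!(args.validate().is_err());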

View File

@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,

File diff suppressed because it is too large

Binary file not shown.

View File

@@ -0,0 +1,31 @@
[[kernel]]
void add(
constant long *inA [[buffer(0)]],
constant long *inB [[buffer(1)]],
device long *result [[buffer(2)]],
uint index [[thread_position_in_grid]])
{
result[index] = inA[index] + inB[index];
}
[[kernel]]
void sub(
constant long *inA [[buffer(0)]],
constant long *inB [[buffer(1)]],
device long *result [[buffer(2)]],
uint index [[thread_position_in_grid]])
{
result[index] = inA[index] - inB[index];
}
[[kernel]]
void mul(
constant long *inA [[buffer(0)]],
constant long *inB [[buffer(1)]],
device long *result [[buffer(2)]],
uint index [[thread_position_in_grid]])
{
result[index] = inA[index] * inB[index];
}
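The kernels operate on `long` (64-bit) buffers, consistent with the i128-to-i64 swap in this changeset; a CPU reference of the same element-wise semantics (an editor's sketch, not part of the crate):

fn add(a: &[i64], b: &[i64]) -> Vec<i64> {
    // element-wise sum, matching the `add` kernel above
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

fn main() {
    assert_eq!(add(&[1, 2, 3], &[4, 5, 6]), vec![5, 7, 9]);
}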

Binary file not shown.

View File

@@ -5,7 +5,7 @@ pub mod val;
/// A wrapper around a tensor of Halo2 Value types.
pub mod var;
use halo2curves::ff::PrimeField;
use halo2curves::{bn256::Fr, ff::PrimeField};
use maybe_rayon::{
prelude::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
@@ -17,9 +17,12 @@ use serde::{Deserialize, Serialize};
pub use val::*;
pub use var::*;
#[cfg(feature = "metal")]
use instant::Instant;
use crate::{
circuit::utils,
fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt},
fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
graph::Visibility,
};
@@ -30,12 +33,18 @@ use halo2_proofs::{
poly::Rotation,
};
use itertools::Itertools;
#[cfg(feature = "metal")]
use metal::{Device, MTLResourceOptions, MTLSize};
use std::error::Error;
use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
use thiserror::Error;
#[cfg(feature = "metal")]
use std::collections::HashMap;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
@@ -60,6 +69,31 @@ pub enum TensorError {
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
}
#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
#[cfg(feature = "metal")]
lazy_static::lazy_static! {
static ref DEVICE: Device = Device::system_default().expect("no device found");
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
let mut map = HashMap::new();
for name in ["add", "sub", "mul"] {
let function = LIB.get_function(name, None).unwrap();
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
map.insert(name.to_string(), pipeline);
}
map
};
}
/// The (inner) type of tensor elements.
@@ -142,7 +176,7 @@ impl TensorType for f64 {
}
tensor_type!(bool, Bool, false, true);
tensor_type!(i128, Int128, 0, 1);
tensor_type!(i64, Int64, 0, 1);
tensor_type!(i32, Int32, 0, 1);
tensor_type!(usize, USize, 0, 1);
tensor_type!((), Empty, (), ());
@@ -308,6 +342,94 @@ impl<T: TensorType> DerefMut for Tensor<T> {
self.inner.deref_mut()
}
}
/// Convert to i64 trait
pub trait IntoI64 {
/// Convert to i64
fn into_i64(self) -> i64;
/// From i64
fn from_i64(i: i64) -> Self;
}
impl IntoI64 for i64 {
fn into_i64(self) -> i64 {
self
}
fn from_i64(i: i64) -> i64 {
i
}
}
impl IntoI64 for i32 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as i32
}
}
impl IntoI64 for usize {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as usize
}
}
impl IntoI64 for f32 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as f32
}
}
impl IntoI64 for f64 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as f64
}
}
impl IntoI64 for () {
fn into_i64(self) -> i64 {
0
}
fn from_i64(_: i64) -> Self {
()
}
}
impl IntoI64 for Fr {
fn into_i64(self) -> i64 {
felt_to_i64(self)
}
fn from_i64(i: i64) -> Self {
i64_to_felt::<Fr>(i)
}
}
impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
fn into_i64(self) -> i64 {
let mut res = vec![];
self.map(|x| res.push(x.into_i64()));
if res.len() == 0 {
0
} else {
res[0]
}
}
fn from_i64(i: i64) -> Self {
Value::known(F::from_i64(i))
}
}
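A sketch of using the IntoI64 bridge generically; note that the f32/f64 impls above truncate via `as`, so only integer-like types round-trip exactly:

trait IntoI64 {
    fn into_i64(self) -> i64;
    fn from_i64(i: i64) -> Self;
}

impl IntoI64 for i32 {
    fn into_i64(self) -> i64 { self as i64 }
    fn from_i64(i: i64) -> Self { i as i32 }
}

fn round_trip<T: IntoI64 + Copy>(x: T) -> T {
    T::from_i64(x.into_i64())
}

fn main() {
    assert_eq!(round_trip(-42_i32), -42);
}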
impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
fn eq(&self, other: &Tensor<T>) -> bool {
@@ -378,7 +500,7 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
{
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<Value<F>> {
let mut output = Vec::new();
for (_, x) in value.iter().enumerate() {
for x in value.iter() {
output.push(x.value_field().evaluate());
}
Tensor::new(Some(&output), value.dims()).unwrap()
@@ -424,16 +546,28 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>>
}
}
impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>> {
fn from(t: Tensor<i128>) -> Tensor<Value<F>> {
impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
let mut ta: Tensor<Value<F>> =
Tensor::from((0..t.len()).map(|i| Value::known(i128_to_felt::<F>(t[i]))));
Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
// safe to unwrap as we know the dims are correct
ta.reshape(t.dims()).unwrap();
ta
}
}
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::FromParallelIterator<T> for Tensor<T>
{
fn from_par_iter<I>(par_iter: I) -> Self
where
I: maybe_rayon::iter::IntoParallelIterator<Item = T>,
{
let inner: Vec<T> = par_iter.into_par_iter().collect();
Tensor::new(Some(&inner), &[inner.len()]).unwrap()
}
}
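
A hedged sketch of what this impl enables, mirroring the `par_iter().map(..).collect()` pattern used by the arithmetic fallbacks below (the maybe_rayon iterator traits are assumed in scope, as elsewhere in this module):

let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[4]).unwrap();
let doubled: Tensor<i64> = a.par_iter().map(|x| x * 2).collect();
assert_eq!(doubled, Tensor::new(Some(&[2, 4, 6, 8]), &[4]).unwrap());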
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::IntoParallelIterator for Tensor<T>
{
@@ -922,6 +1056,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());
let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -953,6 +1088,8 @@ impl<T: Clone + TensorType> Tensor<T> {
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -965,7 +1102,10 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
}
output.set(&coord, self.get(&old_coord));
let value = self.get(&old_coord);
output.set(&coord, value);
}
Ok(output)
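
A concrete instance of the index remapping above, assuming NumPy-style moveaxis semantics (for a 2-D tensor, moving axis 0 to position 1 is a transpose):

let mut a = Tensor::<i64>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = a.move_axis(0, 1).unwrap();
assert_eq!(b.dims(), &[3, 2]);
assert_eq!(b, Tensor::new(Some(&[1, 4, 2, 5, 3, 6]), &[3, 2]).unwrap());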
@@ -1196,6 +1336,97 @@ impl<T: Clone + TensorType> Tensor<T> {
}
}
#[cfg(feature = "metal")]
#[allow(unsafe_code)]
/// Perform a tensor operation on the GPU using Metal.
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
v: &Tensor<T>,
w: &Tensor<T>,
op: &str,
) -> Tensor<T> {
assert_eq!(v.dims(), w.dims());
log::trace!("------------------------------------------------");
let start = Instant::now();
let v = v
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
.unwrap();
let w = w
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
.unwrap();
log::trace!("Time to map tensors: {:?}", start.elapsed());
objc::rc::autoreleasepool(|| {
// create function pipeline.
// this compiles the function, so a pipeline shouldn't be created in performance-sensitive code.
let pipeline = &PIPELINES[op];
let length = v.len() as u64;
let size = length * core::mem::size_of::<i64>() as u64;
assert_eq!(v.len(), w.len());
let start = Instant::now();
let buffer_a = DEVICE.new_buffer_with_data(
unsafe { std::mem::transmute(v.as_ptr()) },
size,
MTLResourceOptions::StorageModeShared,
);
let buffer_b = DEVICE.new_buffer_with_data(
unsafe { std::mem::transmute(w.as_ptr()) },
size,
MTLResourceOptions::StorageModeShared,
);
let buffer_result = DEVICE.new_buffer(
size, // the operation returns an array of the same size.
MTLResourceOptions::StorageModeShared,
);
log::trace!("Time to load buffers: {:?}", start.elapsed());
// for sending commands, a command buffer is needed.
let start = Instant::now();
let command_buffer = QUEUE.new_command_buffer();
log::trace!("Time to load command buffer: {:?}", start.elapsed());
// to write commands into a buffer, an encoder is needed; in our case, a compute encoder.
let start = Instant::now();
let compute_encoder = command_buffer.new_compute_command_encoder();
compute_encoder.set_compute_pipeline_state(&pipeline);
compute_encoder.set_buffers(
0,
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
&[0; 3],
);
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
// specify thread count and organization
let start = Instant::now();
let grid_size = MTLSize::new(length, 1, 1);
let threadgroup_size = MTLSize::new(length, 1, 1);
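// note: dispatching a single threadgroup of `length` threads assumes `length` stays within the pipeline's max threads per threadgroup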
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
// end encoding and execute commands
let start = Instant::now();
compute_encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
log::trace!("Time to commit: {:?}", start.elapsed());
let start = Instant::now();
let ptr = buffer_result.contents() as *const i64;
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
log::trace!("Time to get result: {:?}", start.elapsed());
res.map(|x| T::from_i64(x))
})
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
/// Flattens a tensor of tensors
/// ```
@@ -1217,7 +1448,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
}
}
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Adds tensors.
/// # Arguments
@@ -1267,14 +1500,24 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
/// ```
fn add(self, rhs: Self) -> Self::Output {
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() + r;
});
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "add");
Ok(lhs)
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
.zip(rhs)
.map(|(o, r)| o.clone() + r)
.collect();
res.reshape(&broadcasted_shape).unwrap();
res
};
Ok(res)
}
}
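
A minimal usage sketch of the fallible elementwise add (same-shape case, so broadcasting is a no-op):

let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let b = Tensor::<i64>::new(Some(&[10, 20, 30, 40]), &[2, 2]).unwrap();
assert_eq!(
    (a + b).unwrap(),
    Tensor::new(Some(&[11, 22, 33, 44]), &[2, 2]).unwrap()
);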
@@ -1297,6 +1540,7 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
/// ```
fn neg(self) -> Self {
let mut output = self;
output.par_iter_mut().for_each(|x| {
*x = x.clone().neg();
});
@@ -1304,7 +1548,9 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
}
}
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Subtracts tensors.
/// # Arguments
@@ -1355,18 +1601,30 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
/// ```
fn sub(self, rhs: Self) -> Self::Output {
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() - r;
});
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "sub");
Ok(lhs)
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
.zip(rhs)
.map(|(o, r)| o.clone() - r)
.collect();
res.reshape(&broadcasted_shape).unwrap();
res
};
Ok(res)
}
}
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise multiplies tensors.
/// # Arguments
@@ -1415,18 +1673,28 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
/// ```
fn mul(self, rhs: Self) -> Self::Output {
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() * r;
});
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "mul");
Ok(lhs)
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
.zip(rhs)
.map(|(o, r)| o.clone() * r)
.collect();
res.reshape(&broadcasted_shape).unwrap();
res
};
Ok(res)
}
}
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
/// Elementwise raise a tensor to the nth power.
/// # Arguments
///
@@ -1640,4 +1908,66 @@ mod tests {
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_int() {
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_felt() {
use halo2curves::bn256::Fr;
let a = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let b = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
&[2, 2],
)
.unwrap()
);
}
}

File diff suppressed because it is too large.


@@ -1,8 +1,12 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -51,6 +55,24 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
}
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
match self {
ValType::PrevAssigned(cell) => Some(cell.cell()),
ValType::AssignedConstant(cell, _) => Some(cell.cell()),
_ => None,
}
}
/// Returns the assigned cell of the [ValType].
pub fn assigned_cell(&self) -> Option<AssignedCell<F, F>> {
match self {
ValType::PrevAssigned(cell) => Some(cell.clone()),
ValType::AssignedConstant(cell, _) => Some(cell.clone()),
_ => None,
}
}
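
A small sketch of the accessor contract (only assigned variants carry a cell; `Fr` stands in for the concrete field, as elsewhere in the crate):

let v: ValType<Fr> = ValType::Constant(Fr::from(1));
assert!(v.cell().is_none());
assert!(v.assigned_cell().is_none());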
/// Returns true if the value is previously assigned.
pub fn is_prev_assigned(&self) -> bool {
matches!(
@@ -293,7 +315,13 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
inner.into()
}
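
For example (a sketch, with `Fr` standing in for the circuit field):

let t = Tensor::<i64>::new(Some(&[1, -2, 3]), &[3]).unwrap();
let vt: ValTensor<Fr> = ValTensor::from_i64_tensor(t);
assert_eq!(vt.dims(), &[3]);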
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -428,10 +456,37 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Returns the number of constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns a map of the constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}
@@ -466,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Calls `int_evals` on the inner tensor.
pub fn get_int_evals(&self) -> Result<Tensor<i128>, Box<dyn Error>> {
pub fn get_int_evals(&self) -> Result<Tensor<i64>, Box<dyn Error>> {
// finally convert to vector of integers
let mut integer_evals: Vec<i128> = vec![];
let mut integer_evals: Vec<i64> = vec![];
match self {
ValTensor::Value {
inner: v, dims: _, ..
@@ -476,25 +531,25 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
let _ = v.map(|vaf| match vaf {
ValType::Value(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i128(f));
integer_evals.push(crate::fieldutils::felt_to_i64(f));
}),
ValType::AssignedValue(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate()));
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
}),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
v.value_field().map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate()));
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
})
}
ValType::Constant(v) => {
integer_evals.push(crate::fieldutils::felt_to_i128(v));
integer_evals.push(crate::fieldutils::felt_to_i64(v));
Value::unknown()
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
@@ -824,13 +879,13 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
};
Ok(())
}
/// Calls `pad` on the inner [Tensor].
pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
/// Calls `pad_spatial_dims` on the inner [Tensor].
pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = pad(v, padding)?;
*v = pad(v, padding, offset)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {


@@ -2,7 +2,7 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::CheckMode;
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
@@ -289,9 +289,10 @@ impl VarTensor {
&self,
region: &mut Region<F>,
offset: usize,
coord: usize,
constant: F,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset);
let (x, y, z) = self.cartesian_coord(offset + coord);
match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant)
@@ -304,33 +305,28 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
}
ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>(
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok(k);
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?;
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
@@ -340,11 +336,12 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut res: ValTensor<F> = match values {
ValTensor::Instance {
@@ -382,14 +379,7 @@ impl VarTensor {
},
ValTensor::Value { inner: v, .. } => Ok(v
.enum_map(|coord, k| {
let cell = self.assign_value(region, offset, k.clone(), coord)?;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
self.assign_value(region, offset, k.clone(), coord, constants)
})?
.into()),
}?;
@@ -399,13 +389,16 @@ impl VarTensor {
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn dummy_assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
@@ -430,21 +423,24 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
row: usize,
@@ -452,7 +448,8 @@ impl VarTensor {
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
match values {
@@ -494,7 +491,7 @@ impl VarTensor {
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
};
let cell = self.assign_value(region, offset, k.clone(), coord * step)?;
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
if single_inner_col {
if z == 0 {
@@ -502,28 +499,23 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
if let Some(prev_cell) = prev_cell.as_ref() {
region.constrain_equal(prev_cell.cell(),cell.cell())?;
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}}
match k {
ValType::Constant(f) => {
Ok(ValType::AssignedConstant(cell, f))
},
ValType::AssignedConstant(_, f) => {
Ok(ValType::AssignedConstant(cell, f))
},
_ => {
Ok(ValType::PrevAssigned(cell))
}
}
Ok(cell)
})?.into()};
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
@@ -542,42 +534,61 @@ impl VarTensor {
)};
}
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
fn assign_value<F: PrimeField + TensorType + PartialOrd>(
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
k: ValType<F>,
coord: usize,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
match k {
let res = match k {
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice(|| "k", advices[x][y], z, || v)
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
},
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self {
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
v.copy_advice(|| "k", region, advices[x][y], z)
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => {
error!("PrevAssigned is only supported for advice columns");
Err(halo2_proofs::plonk::Error::Synthesis)
_ => unimplemented!(),
},
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
},
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => region
.assign_advice(|| "k", advices[x][y], z, || v)
.map(|a| a.evaluate()),
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => unimplemented!(),
},
ValType::Constant(v) => self.assign_constant(region, offset + coord, v),
}
ValType::Constant(v) => {
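// constants are deduplicated per region: the first occurrence of a field value is assigned to a fresh cell and cached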
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(
self.assign_constant(region, offset, coord, v)?,
v,
);
e.insert(value.clone());
value
} else {
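// this constant was assigned earlier: recursing with the cached AssignedConstant copies that cell, adding an equality constraint instead of a fresh assignment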
let cell = constants.get(&v).unwrap();
self.assign_value(region, offset, cell.clone(), coord, constants)?
}
}
};
Ok(res)
}
}


@@ -1,19 +1,22 @@
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::PoseidonChip;
use crate::circuit::modules::Module;
use crate::fieldutils::felt_to_i128;
use crate::fieldutils::i128_to_felt;
use crate::fieldutils::felt_to_i64;
use crate::fieldutils::i64_to_felt;
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +36,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -111,7 +113,7 @@ pub fn feltToInt(
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_i128(felt))
serde_json::to_vec(&felt_to_i64(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
@@ -125,7 +127,7 @@ pub fn feltToFloat(
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_i128(felt);
let int_rep = felt_to_i64(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
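
As a worked example of the dequantization above (a sketch assuming the multiplier is 2^scale, consistent with the fixed-point encoding used here):

let int_rep: i64 = 12;
let scale = 2u32;
let multiplier = (1u64 << scale) as f64; // 4.0
assert_eq!(int_rep as f64 / multiplier, 3.0);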
@@ -139,13 +141,51 @@ pub fn floatToFelt(
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = i128_to_felt(int_rep);
let felt = i64_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
/// Generate a kzg commitment.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn kzgCommit(
message: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let mut reader = std::io::BufReader::new(&params_ser[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&params,
);
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
@@ -235,7 +275,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false)
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)
@@ -336,8 +376,94 @@ pub fn verify(
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify aggregate proof in browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
@@ -436,8 +562,9 @@ pub fn prove(
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)


@@ -3,7 +3,7 @@
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -122,7 +122,7 @@ mod native_tests {
let settings: GraphSettings = serde_json::from_str(&settings).unwrap();
let logrows = settings.run_args.logrows;
download_srs(logrows, settings.run_args.commitment);
download_srs(logrows, settings.run_args.commitment.into());
}
fn mv_test_(test_dir: &str, test: &str) {
@@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 91] = [
const TESTS: [&str; 93] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -296,6 +296,8 @@ mod native_tests {
"reducel1",
"reducel2", // 89
"1l_lppool",
"lstm_large", // 91
"lstm_medium", // 92
];
const WASM_TESTS: [&str; 46] = [
@@ -534,7 +536,7 @@ mod native_tests {
}
});
seq!(N in 0..=90 {
seq!(N in 0..=92 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -622,7 +624,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
@@ -898,7 +900,7 @@ mod native_tests {
seq!(N in 0..=45 {
#(#[test_case(WASM_TESTS[N])])*
fn prove_and_verify_with_overflow_(test: &str) {
fn kzg_prove_and_verify_with_overflow_(test: &str) {
crate::native_tests::init_binary();
// crate::native_tests::init_wasm();
let test_dir = TempDir::new(test).unwrap();
@@ -907,11 +909,24 @@ mod native_tests {
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm", false);
// test_dir.close().unwrap();
test_dir.close().unwrap();
}
#(#[test_case(WASM_TESTS[N])])*
fn prove_and_verify_with_overflow_fixed_params_(test: &str) {
fn kzg_prove_and_verify_with_overflow_hashed_inputs_(test: &str) {
crate::native_tests::init_binary();
// crate::native_tests::init_wasm();
let test_dir = TempDir::new(test).unwrap();
env_logger::init();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
prove_and_verify(path, test.to_string(), "safe", "hashed", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm", false);
test_dir.close().unwrap();
}
#(#[test_case(WASM_TESTS[N])])*
fn kzg_prove_and_verify_with_overflow_fixed_params_(test: &str) {
crate::native_tests::init_binary();
// crate::native_tests::init_wasm();
let test_dir = TempDir::new(test).unwrap();
@@ -1359,7 +1374,7 @@ mod native_tests {
let witness = witness.clone();
let outputs = witness.outputs.clone();
// get values as i128
// get values as i64
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
@@ -1369,10 +1384,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
i128_to_felt(
(felt_to_i128(*v) as f32
i64_to_felt(
(felt_to_i64(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as i128,
as i64,
)
};
@@ -1382,7 +1397,7 @@ mod native_tests {
})
.collect::<Vec<_>>();
// get values as i128
// get values as i64
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
@@ -1392,10 +1407,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
i128_to_felt(
(felt_to_i128(*v) as f32
i64_to_felt(
(felt_to_i64(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as i128,
as i64,
)
};
*v + perturbation
@@ -1970,7 +1985,7 @@ mod native_tests {
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment);
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([


@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx, as it has a protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, jupyter
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");


@@ -1,21 +1,29 @@
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
#[cfg(test)]
mod wasm32 {
use ezkl::circuit::modules::polycommit::PolyCommitChip;
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::PoseidonChip;
use ezkl::circuit::modules::Module;
use ezkl::graph::modules::POSEIDON_LEN_GRAPH;
use ezkl::graph::GraphWitness;
use ezkl::graph::GraphCircuit;
use ezkl::graph::{GraphSettings, GraphWitness};
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation,
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::Bn256;
use halo2curves::bn256::{Fr, G1Affine};
use snark_verifier::util::arithmetic::PrimeField;
use wasm_bindgen::JsError;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
use wasm_bindgen_test::*;
@@ -27,10 +35,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {
@@ -71,6 +98,45 @@ mod wasm32 {
assert_eq!(calldata, reference_calldata);
}
#[wasm_bindgen_test]
fn verify_kzg_commit() {
// create a vector of field elements Vec<Fr> and assign it to the message variable
let mut message: Vec<Fr> = vec![];
for i in 0..32 {
message.push(Fr::from(i as u64));
}
let message_ser = serde_json::to_vec(&message).unwrap();
let settings: GraphSettings = serde_json::from_slice(&SETTINGS).unwrap();
let mut reader = std::io::BufReader::new(SRS);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).unwrap();
let mut reader = std::io::BufReader::new(VK);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
settings.clone(),
)
.unwrap();
let commitment_ser = kzgCommit(
wasm_bindgen::Clamped(message_ser),
wasm_bindgen::Clamped(VK.to_vec()),
wasm_bindgen::Clamped(SETTINGS.to_vec()),
wasm_bindgen::Clamped(SRS.to_vec()),
)
.map_err(|_| "failed")
.unwrap();
let commitment: Vec<halo2curves::bn256::G1Affine> =
serde_json::from_slice(&commitment_ser[..]).unwrap();
let reference_commitment = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&params,
);
assert_eq!(commitment, reference_commitment);
}
#[wasm_bindgen_test]
async fn verify_field_serialization_roundtrip() {
for i in 0..32 {
@@ -84,10 +150,10 @@ mod wasm32 {
.unwrap();
assert_eq!(floating_point, (i as f64) / 4.0);
let integer: i128 =
let integer: i64 =
serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap())
.unwrap();
assert_eq!(integer, i as i128);
assert_eq!(integer, i as i64);
let hex_string = format!("{:?}", field_element.clone());
let returned_string: String = feltToBigEndian(clamped.clone())

tests/wasm/kzg1.srs: new binary file (contents not shown)

(another binary file changed; contents not shown)

tests/wasm/proof_aggr.json: new file, 3075 additions (diff suppressed because one or more lines are too long)


@@ -24,8 +24,7 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE",
"commitment": "KZG"
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,

tests/wasm/vk_aggr.key: new binary file (contents not shown)